# One evaluation step followed by the final score summary.
# NOTE(review): the enclosing loop header (which defines `step`) and the
# initialisation of `dice_cumul` / `iou_cumul` are outside this excerpt.
batch_x, batch_y, path = data_train(opt.batchSize)
    # The same path list is passed for both the 'A' and the 'B' side.
    data = {'A': batch_x, 'A_paths': path, 'B': batch_y, 'B_paths': path}
    model.set_input(data)
    model.test()

    # Take the per-pixel argmax over the softmax layer output to obtain the
    # segmentation predicted by the network, then expand it into one binary
    # mask per class (0 = background, 1 = myelin, 2 = axon).
    out = torch.argmax(model.fake_B2[0], dim=0)
    background = (out == 0).float()
    myelin = (out == 1).float()
    axon = (out == 2).float()
    # NOTE(review): `map` shadows the Python builtin of the same name.
    map = torch.stack((background, myelin, axon))

    # Compute the evaluation scores against batch_y[0, 0]
    # (presumably the first sample's label map -- TODO confirm) and
    # accumulate them for the final averages.
    score_dice = dice(map, batch_y[0, 0])
    score_iou = iou(map, batch_y[0, 0])
    dice_cumul += score_dice
    iou_cumul += score_iou

    # Visualization via visdom, every `opt.display_step` steps
    if step % opt.display_step == 0:
        print("Step %d | Dice: %f, IoU: %f" % (step, score_dice, score_iou))
        visualizer.display_current_results(getVisuals(model), 1, False)

# Compute the average scores over `dataset_size`
dice_complete = dice_cumul / dataset_size
iou_complete = iou_cumul / dataset_size
print("--------------------------------------------------------")
print('End score:')
print("Dice: %f, IoU: %f" % (dice_complete, iou_complete))
Beispiel #2
0
                    for k, v in loss_dict.items()
                }
                t = (time.time() - iter_start_time) / opt.print_freq
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)
                #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])

            ### display output images
            if save_fake:
                visuals = OrderedDict([
                    ('input_label',
                     util.tensor2label(data['label'][0], opt.label_nc)),
                    ('synthesized_image', util.tensor2im(generated.data[0])),
                    ('real_image', util.tensor2im(data['image'][0]))
                ])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.module.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        iter_end_time = time.time()
Beispiel #3
0
def run(opt):
    """Run the complete training loop described by ``opt``.

    Creates the dataloader, trainer, inference manager, iteration counter and
    visualizer, then trains epoch by epoch. Inside each epoch the generator is
    stepped on every ``opt.D_steps_per_G``-th iteration and the discriminator
    on every iteration; printing, image display, checkpointing and metric
    evaluation are triggered by the iteration counter.

    Interrupts and other exceptions are caught and their traceback printed.
    In every case the latest model is saved and the current iteration is
    recorded before the function returns.

    Args:
        opt: Parsed experiment options object.
    """
    print("Number of GPUs used: {}".format(torch.cuda.device_count()))
    print("Current Experiment Name: {}".format(opt.name))

    # Source of the training batches.
    train_loader = data.create_dataloader(opt)

    trainer = TrainerManager(opt)
    inference_manager = InferenceManager(
        num_samples=opt.num_evaluation_samples,
        opt=opt,
        cuda=len(opt.gpu_ids) > 0,
        write_details=False,
        save_images=False)

    # Bookkeeping for logging and visualisation.
    counter = IterationCounter(opt, len(train_loader))
    visualizer = Visualizer(opt)

    if not opt.debug:
        # Snapshot the current source tree next to the experiment's checkpoints.
        experiment_dir = os.path.join(opt.checkpoints_dir, opt.name)
        copy_src(path_from="./", path_to=experiment_dir)

    def record_and_plot(result, split):
        """Record FID and metrics for ``split``, plot them, return the log text."""
        fid_text = counter.record_fid(
            result["FID"],
            split=split,
            num_samples=opt.num_evaluation_samples)
        metrics_text = counter.record_metrics(result, split=split)
        visualizer.plot_current_errors(
            result, counter.total_steps_so_far, split=split + "/")
        return fid_text + os.linesep + metrics_text

    # The whole loop lives inside try/finally so that a checkpoint is written
    # even when training is interrupted (e.g. with Ctrl+C) or fails.
    try:
        for epoch in counter.training_epochs():
            counter.record_epoch_start(epoch)
            for step, batch in enumerate(train_loader,
                                         start=counter.epoch_iter):

                # Generator update, only on every D_steps_per_G-th iteration.
                if step % opt.D_steps_per_G == 0:
                    trainer.run_generator_one_step(batch)

                # Discriminator update on every iteration.
                trainer.run_discriminator_one_step(batch)

                counter.record_one_iteration()

                # Textual and plotted loss logging.
                if counter.needs_printing():
                    latest_losses = trainer.get_latest_losses()
                    visualizer.print_current_errors(
                        epoch, counter.epoch_iter, latest_losses,
                        counter.time_per_iter,
                        counter.total_time_so_far,
                        counter.total_steps_so_far)
                    visualizer.plot_current_errors(
                        latest_losses, counter.total_steps_so_far)

                # Image display.
                if counter.needs_displaying():
                    logs = trainer.get_logs()
                    panels = [
                        ('input_label', batch['label']),
                        ('out_train', trainer.get_latest_generated()),
                        ('real_train', batch['image']),
                    ]
                    if opt.guiding_style_image:
                        panels.append(
                            ('guiding_image', batch['guiding_image']))
                        panels.append(
                            ('guiding_input_label', batch['guiding_label']))
                    if opt.evaluate_val_set:
                        panels.extend(
                            inference_validation(trainer.sr_model,
                                                 inference_manager, opt))
                    visualizer.display_current_results(
                        OrderedDict(panels), epoch,
                        counter.total_steps_so_far, logs)

                # Periodic "latest" checkpoint.
                if counter.needs_saving():
                    print(
                        'Saving the latest model (epoch %d, total_steps %d)' %
                        (epoch, counter.total_steps_so_far))
                    trainer.save('latest')
                    counter.record_current_iter()

                # Periodic metric evaluation.
                if counter.needs_evaluation():
                    # Training set first ...
                    result_train = evaluate_training_set(
                        inference_manager, trainer.sr_model_on_one_gpu,
                        train_loader)
                    summary = record_and_plot(result_train, "train")

                    # ... then, optionally, the validation set.
                    if opt.evaluate_val_set:
                        result_val = evaluate_validation_set(
                            inference_manager, trainer.sr_model_on_one_gpu,
                            opt)
                        summary += os.linesep + record_and_plot(
                            result_val, "validation")

            trainer.update_learning_rate(epoch)
            counter.record_epoch_end()

            is_last_epoch = epoch == counter.total_epochs
            if epoch % opt.save_epoch_freq == 0 or is_last_epoch:
                print('Saving the model at the end of epoch %d, iters %d' %
                      (epoch, counter.total_steps_so_far))
                trainer.save('latest')
                trainer.save(epoch)
                counter.record_current_iter()

        print('Training was successfully finished.')
    except (KeyboardInterrupt, SystemExit):
        print("KeyboardInterrupt. Shutting down.")
        print(traceback.format_exc())
    except Exception:
        print(traceback.format_exc())
    finally:
        print('Saving the model before quitting')
        trainer.save('latest')
        counter.record_current_iter()
Beispiel #4
0
        trainer.run_discriminator_one_step(data_i)

        # Visualizations
        if iter_counter.needs_printing():
            losses = trainer.get_latest_losses()
            visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                            losses, iter_counter.time_per_iter)
            visualizer.plot_current_errors(losses,
                                           iter_counter.total_steps_so_far)

        if iter_counter.needs_displaying():
            visuals = OrderedDict([('input_label', data_i['label']),
                                   ('synthesized_image',
                                    trainer.get_latest_generated()),
                                   ('real_image', data_i['image'])])
            visualizer.display_current_results(visuals, epoch,
                                               iter_counter.total_steps_so_far)

        if iter_counter.needs_saving():
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            iter_counter.record_current_iter()

    trainer.update_learning_rate(epoch)
    iter_counter.record_epoch_end()

    if epoch % opt.save_epoch_freq == 0 or \
       epoch == iter_counter.total_epochs:
        print('saving the model at the end of epoch %d, iters %d' %
              (epoch, iter_counter.total_steps_so_far))
        trainer.save('latest')
Beispiel #5
0
#################################################
# Main training loop over the paired dataset.
# NOTE(review): `opt`, `model`, `visualizer`, `dataset_paired` and
# `paired_dataset_size` are defined outside this excerpt.
print('step 3')
# Running count of processed samples (advanced by opt.batchSize per batch).
total_steps = 0
for epoch in range(1, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()
    # You can use paired and unpaired data to train. Here we only use paired samples to train.
    for i, (images_a, images_b) in enumerate(dataset_paired):
        iter_start_time = time.time()
        total_steps += opt.batchSize
        # Samples processed so far within the current epoch.
        epoch_iter = total_steps - paired_dataset_size * (epoch - 1)
        model.set_input(images_a, images_b)
        model.optimize_parameters()

        # Show the current visuals every opt.display_freq samples.
        if total_steps % opt.display_freq == 0:
            visuals = model.get_current_visuals()
            visualizer.display_current_results(visuals, epoch)

        # Print (and optionally plot) the current errors every opt.print_freq samples.
        if total_steps % opt.print_freq == 0:
            errors = model.get_current_errors()
            # NOTE(review): the absolute start timestamp is passed here, not an
            # elapsed time -- confirm this is what the visualizer expects.
            visualizer.print_current_errors(epoch, epoch_iter, errors,
                                            iter_start_time)
            if opt.display_id > 0:
                # The second argument is the fraction of the epoch completed.
                visualizer.plot_current_errors(
                    epoch,
                    float(epoch_iter) / paired_dataset_size, opt, errors)

        # Save a 'latest' checkpoint every opt.save_latest_freq samples.
        if total_steps % opt.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, total_steps))
            model.save('latest')
Beispiel #6
0
def main_worker(gpu, world_size, idx_server, opt):
    """Run distributed training in one worker process bound to one GPU.

    Args:
        gpu: Index of the GPU this process uses; stored on ``opt.gpu``.
        world_size: Used as the number of GPUs per node when computing the
            rank; the process-group world size is taken from
            ``opt.world_size``.
        idx_server: Index of this node, used to compute the rank.
        opt: Training options object.
    """
    print('Use GPU: {} for training'.format(gpu))
    # The incoming ``world_size`` is the per-node GPU count; the process-group
    # size comes from the options.
    gpus_per_node = world_size
    world_size = opt.world_size
    rank = idx_server * gpus_per_node + gpu
    opt.gpu = gpu
    dist.init_process_group(backend='nccl',
                            init_method=opt.dist_url,
                            world_size=world_size,
                            rank=rank)
    torch.cuda.set_device(opt.gpu)

    # Dataset, trainer, iteration bookkeeping and visualisation helpers.
    dataloader = data.create_dataloader(opt, world_size, rank)
    trainer = Pix2PixTrainer(opt)
    counter = IterationCounter(opt, len(dataloader), world_size, rank)
    visualizer = Visualizer(opt, rank)

    # Only the rank-0 process writes checkpoints.
    is_master = rank == 0

    for epoch in counter.training_epochs():
        # Tell the data sampler which epoch is running.
        dataloader.sampler.set_epoch(epoch)

        counter.record_epoch_start(epoch)

        for _, data_i in enumerate(dataloader, start=counter.epoch_iter):
            counter.record_one_iteration()

            # One generator step followed by one discriminator step.
            trainer.run_generator_one_step(data_i)
            trainer.run_discriminator_one_step(data_i)

            # Loss printing / plotting.
            if counter.needs_printing():
                latest_losses = trainer.get_latest_losses()
                visualizer.print_current_errors(epoch, counter.epoch_iter,
                                                latest_losses,
                                                counter.time_per_iter)
                visualizer.plot_current_errors(latest_losses,
                                               counter.total_steps_so_far)

        # Once per epoch: display the last batch and its generated output.
        panels = OrderedDict()
        panels['input_label'] = data_i['label']
        panels['synthesized_image'] = trainer.get_latest_generated()
        panels['real_image'] = data_i['image']
        visualizer.display_current_results(panels, epoch,
                                           counter.total_steps_so_far)

        if is_master:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, counter.total_steps_so_far))
            trainer.save('latest')
            counter.record_current_iter()

        trainer.update_learning_rate(epoch)
        counter.record_epoch_end()

        epoch_checkpoint_due = (epoch % opt.save_epoch_freq == 0
                                or epoch == counter.total_epochs)
        if epoch_checkpoint_due and is_master:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, counter.total_steps_so_far))
            trainer.save(epoch)

    print('Training was successfully finished.')
Beispiel #7
0
model = create_model(opt)
visualizer = Visualizer(opt)

total_steps = 0

for epoch in range(1, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()
    for i, data in enumerate(dataset):
        iter_start_time = time.time()
        total_steps += opt.batchSize
        epoch_iter = total_steps - dataset_size * (epoch - 1)
        model.set_input(data)  
        model.optimize_parameters()

        if total_steps % opt.display_freq == 0:
            visualizer.display_current_results(model.get_current_img(), epoch)

        if total_steps % opt.print_freq == 0:
            errors = model.get_current_errors()
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            # if opt.display_id > 0:
            #     visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)

        if total_steps % opt.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, total_steps))
            model.save('latest')

    if epoch % opt.save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' %
Beispiel #8
0
    for i, data in enumerate(dataset):
        iter_start_time = time.time()
        if total_steps % opt.print_freq == 0:
            t_data = iter_start_time - iter_data_time
        
        visualizer.reset()
        total_steps += opt.batchSize
        epoch_iter += opt.batchSize
        model.set_input(data)
        model.optimize_ipman()
        
        if total_steps % opt.display_freq == 0:
            visuals = model.get_current_visuals()
            # del visuals['real_B']
            save_result = total_steps % opt.update_html_freq == 0
            visualizer.display_current_results(visuals, epoch, save_result)

        if total_steps % opt.print_freq == 0:
            errors = model.get_current_ipman_errors()
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t,
                                            t_data)
            if opt.display_id > 0:
                visualizer.plot_current_errors(
                    epoch,
                    float(epoch_iter) / dataset_size, opt, errors)

        if total_steps % opt.save_latest_freq == 0:
            print('saving the latest model (epoch {}, total_steps {})'.format(
                epoch, total_steps))
            model.save('optim_latest')
Beispiel #9
0
        generated = model.inference(data['dp_target'][0], data['source_frame'],
                                    data['source_frame'],
                                    data['grid_source'][0],
                                    data['grid_source'][0])

        img_path = data['path'][0]
        frame_number = str(0)
        print generated.size()
        print('process image... %s' % img_path + "   " + str(0))
        visualizer.save_images(webpage,
                               util.tensor2im(generated.squeeze(dim=0)),
                               img_path, frame_number)
        visuals = OrderedDict([('synthesized_image',
                                util.tensor2im(generated.squeeze(dim=0)))])
        visualizer.display_current_results(visuals, 100, 12345)

        for i in range(1, data["dp_target"].shape[0]):
            if opt.prev_frame_num == 0:
                generated = model.inference(data['dp_target'][i],
                                            data['source_frame'],
                                            data['source_frame'],
                                            data['grid_source'][i],
                                            data['grid_source'][i])
            else:
                generated = model.inference(data['dp_target'][i],
                                            data['source_frame'], generated,
                                            data['grid_source'][i],
                                            data['grid'][i - 1])

            img_path = data['path'][0]
Beispiel #10
0
def _latest_saved_epoch(checkpoint_dir):
    """Return the highest epoch number that has a ``<epoch>_net_G_A.pth`` file.

    ``glob.glob`` returns paths in arbitrary order, so the epoch numbers are
    parsed and compared numerically. Files whose prefix is not a plain number
    (e.g. ``iter_500_net_G_A.pth`` or ``latest_net_G_A.pth``) are ignored.

    Args:
        checkpoint_dir: Directory that holds the saved network files.

    Returns:
        The largest epoch number found, or ``None`` when there is none.
    """
    epochs = []
    pattern = os.path.join(checkpoint_dir, '*[0-9]_net_G_A.pth')
    for checkpoint_path in glob.glob(pattern):
        prefix = os.path.basename(checkpoint_path).split('_')[0]
        if prefix.isdigit():
            epochs.append(int(prefix))
    return max(epochs) if epochs else None


def train_main(raw_args=None):
    """Parse the training options and run the full training loop.

    When ``opt.restart_training`` is set and epoch checkpoints exist for the
    experiment, training resumes after the most recent saved epoch.

    Args:
        raw_args: Optional argument list handed to ``TrainOptions().parse``;
            ``None`` leaves argument handling to the parser.
    """
    opt = TrainOptions().parse(raw_args)  # get training options
    if opt.debug_mode:
        import multiprocessing
        multiprocessing.set_start_method('spawn', True)
        opt.num_threads = 0

    # Create a dataset given opt.dataset_mode and other options.
    dataset = create_dataset(opt)
    dataset_size = len(dataset)  # number of images in the dataset
    print('The number of training images = %d' % dataset_size)

    # Resume from the numerically latest saved epoch, if requested.
    if opt.restart_training:
        latest_epoch = _latest_saved_epoch(
            os.path.join(opt.checkpoints_dir, opt.name))
        if latest_epoch is not None:
            opt.epoch = latest_epoch
            opt.epoch_count = latest_epoch + 1

    plot_losses_from_log_files(opt,
                               dataset_size,
                               domain=['A', 'B'],
                               specified=['G', 'D', 'cycle'])

    # Create a model given opt.model and other options.
    model = create_model(opt)
    # Regular setup: load and print networks; create schedulers.
    model.setup(opt)
    # Visualizer that displays/saves images and plots.
    visualizer = Visualizer(opt)
    total_iters = 0  # the total number of training iterations

    # Outer loop over epochs.
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()  # timer for data loading per iteration
        epoch_iter = 0  # training iterations in current epoch, reset every epoch

        # Inner loop within one epoch.
        for data in dataset:
            iter_start_time = time.time()  # timer for computation per iteration
            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
            visualizer.reset()
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            # Unpack data from dataset and apply preprocessing.
            model.set_input(data)
            # Calculate loss functions, get gradients, update network weights.
            model.optimize_parameters()

            # Display images on visdom and save images to a HTML file.
            if total_iters % opt.display_freq == 0:
                save_result = total_iters % opt.update_html_freq == 0
                model.compute_visuals()
                visualizer.display_current_results(model.get_current_visuals(),
                                                   epoch, save_result)

            # Print training losses and save logging information to the disk.
            if total_iters % opt.print_freq == 0:
                losses = model.get_current_losses()
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, losses,
                                                t_comp, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(
                        epoch,
                        float(epoch_iter) / dataset_size, losses)

            # Cache our latest model every <save_latest_freq> iterations.
            if total_iters % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_iters %d)' %
                      (epoch, total_iters))
                save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
                model.save_networks(save_suffix)

            iter_data_time = time.time()

        # Cache our model every <save_epoch_freq> epochs.
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_iters))
            model.save_networks('latest')
            model.save_networks(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))
        # Update learning rates at the end of every epoch.
        model.update_learning_rate()
Beispiel #11
0
                    ('input_label',
                     util.tensor2label(data['label'][0], opt.label_nc)),
                    ('synthesized_image', util.tensor2im(generated.data[0])),
                    ('real_image', util.tensor2im(data['image'][0])),
                    ('sketch', util.tensor2im(data['inst'][0]))
                ])
            else:
                visuals = OrderedDict([
                    ('input_label',
                     util.tensor2label(data['label'][0], opt.label_nc)),
                    ('synthesized_image', util.tensor2im(generated.data[0])),
                    ('real_image', util.tensor2im(data['image'][0]))
                ])

            visualizer.display_current_results(
                visuals, epoch, total_steps,
                [data['size'][0][0], data['size'][1][0]])

        ### save latest model
        if total_steps % opt.save_latest_freq == save_delta:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, total_steps))
            model.module.save('latest')
            np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')

        if epoch_iter >= dataset_size:
            break

    # end of epoch
    iter_end_time = time.time()
    print('End of epoch %d / %d \t Time Taken: %d sec' %
Beispiel #12
0
def train():
    """Train the generator/discriminator pair on paired 3D volumes.

    Options are read from the command line (run with ``--help`` for the
    list). Every batch computes the discriminator loss, but the discriminator
    weights are only stepped when its accuracy on that batch is at or below
    ``--d_threshold``; the generator is stepped on every batch. Progress is
    written to stdout, image slices and losses are sent to the visualizer,
    and sample volumes are written under ``<images_dir>/<dataset_name>/``.
    """
    np.set_printoptions(threshold=sys.maxsize)

    parser = argparse.ArgumentParser()
    parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
    parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
    parser.add_argument("--dataset_name", type=str, default="leftkidney_3d", help="name of the dataset")
    parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
    parser.add_argument("--glr", type=float, default=0.0002, help="adam: generator learning rate")
    parser.add_argument("--dlr", type=float, default=0.0002, help="adam: discriminator learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
    parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
    parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
    parser.add_argument("--img_height", type=int, default=128, help="size of image height")
    parser.add_argument("--img_width", type=int, default=128, help="size of image width")
    parser.add_argument("--img_depth", type=int, default=128, help="size of image depth")
    parser.add_argument("--channels", type=int, default=1, help="number of image channels")
    parser.add_argument("--disc_update", type=int, default=5, help="only update discriminator every n iter")
    # The threshold is a fraction in [0, 1]; it must be parsed as a float,
    # otherwise any fractional value given on the command line is rejected.
    parser.add_argument("--d_threshold", type=float, default=0.8, help="discriminator threshold")
    parser.add_argument("--threshold", type=int, default=-1, help="threshold during sampling, -1: No thresholding")
    parser.add_argument(
        "--sample_interval", type=int, default=1, help="interval between sampling of images from generators"
    )
    parser.add_argument("--checkpoint_interval", type=int, default=-1, help="interval between model checkpoints")
    parser.add_argument('--checkpoints_dir', type=str, default='./saved_models', help='models are saved here')
    parser.add_argument('--images_dir', type=str, default='./images', help='images are saved here')
    parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
    parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
    parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
    parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
    parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
    opt = parser.parse_args()
    print(opt)

    os.makedirs(f"{opt.images_dir}/{opt.dataset_name}", exist_ok=True)
    os.makedirs(f"{opt.checkpoints_dir}/{opt.dataset_name}", exist_ok=True)

    cuda = torch.cuda.is_available()

    visualizer = Visualizer(opt) # Create visualizer to display and save training images

    # Loss functions
    criterion_GAN = torch.nn.MSELoss()
    criterion_voxelwise = diceloss()

    # Weight of the voxel-wise (dice) loss between translated and real image
    lambda_voxel = 100

    # Shape of the discriminator output used for the adversarial targets:
    # each spatial dimension divided by 2 ** 4
    patch = (1, opt.img_height // 2 ** 4, opt.img_width // 2 ** 4, opt.img_depth // 2 ** 4)

    # Initialize generator and discriminator
    generator = GeneratorUNet()
    discriminator = Discriminator()

    if cuda:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        criterion_GAN.cuda()
        criterion_voxelwise.cuda()

    if opt.epoch != 0:
        # Load pretrained models
        # NOTE(review): these paths are hard-coded to "saved_models/" and do not
        # follow --checkpoints_dir; confirm against the checkpoint-saving code.
        generator.load_state_dict(torch.load("saved_models/%s/generator_%d.pth" % (opt.dataset_name, opt.epoch)))
        discriminator.load_state_dict(torch.load("saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, opt.epoch)))
    else:
        # Initialize weights
        generator.apply(weights_init_normal)
        discriminator.apply(weights_init_normal)

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.glr, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.dlr, betas=(opt.b1, opt.b2))

    # Configure dataloaders
    transforms_ = transforms.Compose([
        # transforms.Resize((opt.img_height, opt.img_width, opt.img_depth), Image.BICUBIC),
        transforms.ToTensor(),
        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    dataloader = DataLoader(
        CTDataset("./data/%s/train/" % opt.dataset_name, transforms_=transforms_),
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_cpu,
    )

    # TODO: Once we're using actual validation data, change this to reflect the data
    val_dataloader = DataLoader(
        CTDataset("./data/%s/train/" % opt.dataset_name, transforms_=transforms_),
        batch_size=1,
        shuffle=True,
        num_workers=1,
    )

    # Tensor type
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor


    def sample_voxel_volumes(epoch):
        """Save one real_A / real_B / fake_B triple from the validation loader.

        The volumes are written as HDF5 files (dataset name ``data``) named
        ``epoch_<epoch>_<kind>.vox`` inside ``<images_dir>/<dataset_name>/``,
        the directory created at start-up.
        """
        imgs = next(iter(val_dataloader))
        real_A = Variable(imgs["A"].unsqueeze_(1).type(Tensor))
        real_B = Variable(imgs["B"].unsqueeze_(1).type(Tensor))
        # Sampling needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            fake_B = generator(real_A)

        image_folder = "%s/%s/epoch_%s_" % (opt.images_dir, opt.dataset_name, epoch)

        volumes = (("real_A", real_A), ("real_B", real_B), ("fake_B", fake_B))
        for kind, volume in volumes:
            # The context manager closes (and flushes) each file.
            with h5py.File(image_folder + kind + ".vox", "w") as hf:
                hf.create_dataset("data", data=volume.cpu().detach().numpy())

    # ----------
    #  Training
    # ----------

    prev_time = time.time()
    discriminator_update = 'False'
    for epoch in range(opt.epoch, opt.n_epochs):
        visualizer.reset() # reset the visualizer: make sure results are saved once per epoch

        for i, batch in enumerate(dataloader):

            # Model inputs
            real_A = Variable(batch["A"].unsqueeze_(1).type(Tensor))
            real_B = Variable(batch["B"].unsqueeze_(1).type(Tensor))

            # Adversarial ground truths
            valid = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False)
            fake = Variable(Tensor(np.zeros((real_A.size(0), *patch))), requires_grad=False)


            # ---------------------
            #  Train Discriminator; its weights are only stepped while its
            #  accuracy on this batch is <= opt.d_threshold
            # ---------------------
            # Real loss
            fake_B = generator(real_A)
            pred_real = discriminator(real_B, real_A)
            loss_real = criterion_GAN(pred_real, valid)

            # Fake loss
            pred_fake = discriminator(fake_B.detach(), real_A)
            loss_fake = criterion_GAN(pred_fake, fake)
            # Total loss
            loss_D = 0.5 * (loss_real + loss_fake)

            # Accuracy: real predictions count as correct when >= 0.5,
            # fake predictions when <= 0.5
            d_real_acu = torch.ge(pred_real.squeeze(), 0.5).float()
            d_fake_acu = torch.le(pred_fake.squeeze(), 0.5).float()
            d_total_acu = torch.mean(torch.cat((d_real_acu, d_fake_acu), 0))

            if d_total_acu <= opt.d_threshold:
                optimizer_D.zero_grad()
                loss_D.backward()
                optimizer_D.step()
                discriminator_update = 'True'

            # ------------------
            #  Train Generators
            # ------------------
            optimizer_D.zero_grad()
            optimizer_G.zero_grad()

            # GAN loss
            fake_B = generator(real_A)
            pred_fake = discriminator(fake_B, real_A)
            loss_GAN = criterion_GAN(pred_fake, valid)
            # Voxel-wise loss
            loss_voxel = criterion_voxelwise(fake_B, real_B)

            # Total loss
            loss_G = loss_GAN + lambda_voxel * loss_voxel

            loss_G.backward()

            optimizer_G.step()

            batches_done = epoch * len(dataloader) + i

            # --------------
            #  Log Progress
            # --------------

            # Determine approximate time left
            batches_left = opt.n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f, D accuracy: %f, D update: %s] [G loss: %f, voxel: %f, adv: %f] ETA: %s"
                % (
                    epoch,
                    opt.n_epochs,
                    i,
                    len(dataloader),
                    loss_D.item(),
                    d_total_acu,
                    discriminator_update,
                    loss_G.item(),
                    loss_voxel.item(),
                    loss_GAN.item(),
                    time_left,
                )
            )
            # If at sample interval save image
            if batches_done % (opt.sample_interval*len(dataloader)) == 0:
                sample_voxel_volumes(epoch)
                print('*****volumes sampled*****')

            discriminator_update = 'False'

            # --------------------------
            #  Display images in visdom
            # --------------------------
            visuals = {}
            vis_axis = 'z'
            visuals['Real A'] = getSlice(real_A, vis_axis)
            visuals['Real B'] = getSlice(real_B, vis_axis)
            visuals['Fake B'] = getSlice(fake_B, vis_axis)
            visualizer.display_current_results(visuals, epoch, False)

            # --------------------------
            #  Display losses in visdom
            # --------------------------
            losses = {}
            losses['D accuracy'] = float(d_total_acu)
            losses['D loss'] = loss_D.item()
            losses['G loss'] = loss_G.item()
            visualizer.plot_current_losses(epoch, float(i) / len(dataloader), losses)

        if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
            # Save model checkpoints
            torch.save(generator.state_dict(), "saved_models/%s/generator_%d.pth" % (opt.dataset_name, epoch))
            torch.save(discriminator.state_dict(), "saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, epoch))
            )  # timer for computation per iteration
            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time

            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(
                data)  # unpack data from dataset and apply preprocessing
            model.optimize_parameters(
            )  # calculate loss functions, get gradients, update network weights

            if total_iters % opt.display_freq == 0:  # display images on visdom and save images to a HTML file
                save_result = total_iters % opt.update_html_freq == 0
                model.compute_visuals()
                visualizer.display_current_results(
                    model.get_current_visuals(), epoch, save_result,
                    total_iters / opt.update_html_freq)

            if total_iters % opt.print_freq == 0:  # print training losses and save logging information to the disk
                losses = model.get_current_losses()
                for k, v in losses.items():
                    epoch_dic[k].append(v)  #save losses into dictionary
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, losses,
                                                t_comp, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(
                        epoch,
                        float(epoch_iter) / dataset_size, losses)

            if total_iters % opt.save_latest_freq == 0:  # cache our latest model every <save_latest_freq> iterations
Beispiel #14
0
# Training driver script: builds the model and visualizer, then runs the
# epoch / iteration loop. `opt`, `dataset` and `dataset_size` are not defined
# in this span and must already exist when it runs.
model = ComboGANModel(opt)
visualizer = Visualizer(opt)
total_steps = 0  # global counter, advanced by opt.batchSize on every iteration

# Epochs run from opt.which_epoch + 1 up to and including
# opt.niter + opt.niter_decay.
for epoch in range(opt.which_epoch + 1, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()
    epoch_iter = 0  # per-epoch counter, also advanced by opt.batchSize
    for i, data in enumerate(dataset):
        iter_start_time = time.time()
        total_steps += opt.batchSize
        epoch_iter += opt.batchSize
        # Feed the batch to the model and run one optimization step.
        model.set_input(data)
        model.optimize_parameters()

        # Push the current visuals to the visualizer every display_freq steps.
        if total_steps % opt.display_freq == 0:
            visualizer.display_current_results(model.get_current_visuals(), epoch)

        # Report the current errors every print_freq steps.
        if total_steps % opt.print_freq == 0:
            errors = model.get_current_errors()
            # Elapsed time of this iteration divided by the batch size.
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            if opt.display_id > 0:
                # Second argument is the fraction of the epoch completed so far.
                visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)

    # Periodic checkpoint, keyed by the epoch number.
    if epoch % opt.save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
        model.save(epoch)

    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
Beispiel #15
0
# Inference script: runs the model over the dataloader, shows each batch in
# the visualizer and writes every generated image to `single_save_url`.
# `single_save_url`, `dataloader`, `model`, `visualizer` and `opt` are not
# defined in this span and must already exist when it runs.
if not os.path.exists(single_save_url):
    os.makedirs(single_save_url)

for i, data_i in enumerate(dataloader):
    # Stop once at least opt.how_many samples have been processed.
    if i * opt.batchSize >= opt.how_many:
        break

    generated = model(data_i, mode='inference')

    img_path = data_i['path']

    visuals = OrderedDict([('input_label', data_i['label']),
                           ('input_image', data_i['image']),
                           ('synthesized_image', generated)])

    visualizer.display_current_results(visuals, None, i)

    ### Also save the single output image

    # One output file per element along the first dimension of `generated`.
    for b in range(generated.shape[0]):

        # Keep only the file name (text after the last '/') of the source path.
        img_name = img_path[b].split('/')[-1]
        save_img_url = os.path.join(single_save_url, img_name)

        print('save image... %s' % save_img_url)

        # (x + 1) / 2 rescales the output before saving
        # (presumably from [-1, 1] to [0, 1] -- TODO confirm the generator's output range).
        vutils.save_image((generated[b] + 1) / 2, save_img_url)

    # for b in range(generated.shape[0]):
    #     print('process image... %s' % img_path[b])
    #     visuals = OrderedDict([('input_label', data_i['label'][b]),
def sanity_check(opt):
    """Run a short smoke-test training loop on a single data item.

    Overrides several fields of ``opt`` in place (dataset sizes, batch size,
    logging frequencies, learning rate, ...), builds the data loaders, the
    model and the visualizer, and then optimizes the model 5000 times on the
    first item yielded by the training dataset.

    The run can be aborted through a "kill" file whose name ends with the
    first GPU id, or with ``cpu`` when ``opt.gpu_ids`` is empty. If the file
    exists at start-up it is removed and the process exits; if it appears
    later the process exits at the start of the next iteration.

    Args:
        opt: Options object; mutated in place.

    Raises:
        SystemExit: When the abort file is present.
    """
    # The suffix must be chosen *before* concatenation. Without the
    # intermediate variable the conditional expression binds looser than
    # ``+``, so on CPU the abort file would be the relative path "cpu" and an
    # unrelated file of that name in the working directory would be deleted.
    device_suffix = str(opt.gpu_ids[0]) if len(opt.gpu_ids) > 0 else "cpu"
    abort_file = "/mnt/raid/patrickradner/kill" + device_suffix

    if os.path.exists(abort_file):
        os.remove(abort_file)
        exit("Abort using file: " + abort_file)

    # Shrink everything so the loop runs on exactly one sample and reports often.
    opt.max_dataset_size = 1
    opt.max_val_dataset_size = 1
    freq = 10
    opt.batch_size = 1
    opt.print_freq = freq
    opt.display_freq = freq
    opt.update_html_freq = freq
    opt.validation_freq = 50
    opt.niter = 500
    opt.niter_decay = 0
    opt.display_env = "sanity_check"
    opt.num_display_frames = 10
    opt.train_mode = "frame"
    #opt.reparse_data=True
    opt.lr = 0.004
    opt.pretrain_epochs = 0

    opt.verbose = True

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    # NOTE(review): t_min / t_max are never updated, so this always prints the
    # initial sentinel values. Kept to leave the console output unchanged.
    t_min = 100000
    t_max = 0

    print(f"Length: Min: {t_min} Max: {t_max}")

    if opt.validation_freq > 0:
        # Temporarily switch the phase so the loader is built for the
        # validation split, then restore it.
        phase = opt.phase
        opt.phase = opt.validation_set
        validation_loader = CreateDataLoader(opt)
        # NOTE(review): the loaded validation set is not used below.
        validation_set = validation_loader.load_data()
        opt.phase = phase

        validation_size = len(validation_loader)
        print('#training samples = %d' % dataset_size)
        print('#validation samples = %d' % validation_size)

    model = create_model(opt)
    model.setup(opt)

    visualizer = Visualizer(opt)
    total_steps = 0

    # Every iteration below trains on this same first item.
    data = next(iter(dataset))

    for epoch in range(5000):
        # training loop

        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0
        losses = {}

        if os.path.exists(abort_file):
            exit("Abort using file: " + abort_file)

        iter_start_time = time.time()
        if total_steps % opt.print_freq == 0:
            t_data = iter_start_time - iter_data_time
        visualizer.reset()
        total_steps += opt.batch_size
        epoch_iter += opt.batch_size

        model.set_input(data)
        model.optimize_parameters(epoch, verbose=opt.verbose)

        if total_steps % opt.display_freq == 0:
            save_result = total_steps % opt.update_html_freq == 0
            visualizer.display_current_results(model.get_current_visuals(),
                                               epoch, save_result)

        if total_steps % opt.print_freq == 0:
            losses = model.get_current_losses()
            t = (time.time() - iter_start_time) / opt.batch_size
            visualizer.print_current_losses(epoch, epoch_iter, losses, t,
                                            t_data)
            if opt.display_id > 0:
                visualizer.plot_current_losses(
                    epoch,
                    float(epoch_iter) / dataset_size, opt, losses)

        iter_data_time = time.time()

        if epoch % 50 == 0:
            print('End of sanity_check epoch %d / %d \t Time Taken: %d sec' %
                  (epoch, opt.niter + opt.niter_decay,
                   time.time() - epoch_start_time))

    print("SANITY CHECK DONE")
Beispiel #17
0
            iter_start_time = time.time(
            )  # timer for computation per iteration
            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
            visualizer.reset()
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(
                data)  # unpack data from dataset and apply preprocessing
            model.optimize_parameters(
            )  # calculate loss functions, get gradients, update network weights

            if total_iters % opt.display_freq == 0:  # display images on visdom and save images to a HTML file
                save_result = total_iters % opt.update_html_freq == 0
                model.compute_visuals()
                visualizer.display_current_results(model.get_current_visuals(),
                                                   total_iters, save_result)

            if total_iters % opt.print_freq == 0:  # print training losses and save logging information to the disk
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(epoch, epoch_iter,
                                                model.get_current_losses(),
                                                t_comp, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(
                        epoch,
                        float(epoch_iter) / dataset_size,
                        model.get_current_losses())
            iter_data_time = time.time()

            if total_iters % opt.save_latest_freq == 0:  # cache our latest model every <save_latest_freq> iterations
                print('saving the latest model (epoch %d, total_iters %d)' %
Beispiel #18
0
def train():
    """Train the generator / discriminator pair on video sequences.

    Parses the training options, builds the data loader and the three
    networks returned by ``create_model`` (generator, discriminator and a
    flow network), optionally restores the epoch / iteration counters from
    ``<checkpoints_dir>/<name>/iter.txt``, and then runs the training loop.

    Each loaded sequence is processed in chunks of ``n_frames_load`` frames.
    For every chunk the generator runs first, then the per-frame
    discriminator, then one temporal discriminator per temporal scale; the
    generator and all discriminators are updated from the collected losses.
    Errors, visuals and checkpoints are emitted at the frequencies configured
    in the options.
    """
    opt = TrainOptions().parse()
    if opt.debug:
        opt.display_freq = 1
        opt.print_freq = 1
        opt.nThreads = 1

    ### initialize dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    if opt.dataset_mode == 'pose':
        print('#training frames = %d' % dataset_size)
    else:
        print('#training videos = %d' % dataset_size)

    ### initialize models
    modelG, modelD, flowNet = create_model(opt)
    visualizer = Visualizer(opt)

    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    ### if continue training, recover previous states
    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                 delimiter=',',
                                                 dtype=int)
        except Exception:
            # Best effort: a missing or unreadable state file means "start
            # from scratch". A bare ``except`` here would also swallow
            # KeyboardInterrupt / SystemExit.
            start_epoch, epoch_iter = 1, 0
        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
        # Re-apply the schedule changes that earlier epochs would have made.
        if start_epoch > opt.niter:
            modelG.module.update_learning_rate(start_epoch - 1)
            modelD.module.update_learning_rate(start_epoch - 1)
        if (opt.n_scales_spatial > 1) and (opt.niter_fix_global != 0) and (
                start_epoch > opt.niter_fix_global):
            modelG.module.update_fixed_params()
        if start_epoch > opt.niter_step:
            data_loader.dataset.update_training_batch(
                (start_epoch - 1) // opt.niter_step)
            modelG.module.update_training_batch(
                (start_epoch - 1) // opt.niter_step)
    else:
        start_epoch, epoch_iter = 1, 0

    ### set parameters
    n_gpus = opt.n_gpus_gen // opt.batchSize  # number of gpus used for generator for each batch
    tG, tD = opt.n_frames_G, opt.n_frames_D
    tDB = tD * opt.output_nc
    s_scales = opt.n_scales_spatial
    t_scales = opt.n_scales_temporal
    input_nc = 1 if opt.label_nc != 0 else opt.input_nc
    output_nc = opt.output_nc

    # Make print_freq a multiple of the batch size and round total_steps down
    # to a multiple of print_freq, so the ``% opt.print_freq == 0`` checks
    # below can actually fire.
    opt.print_freq = lcm(opt.print_freq, opt.batchSize)
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    total_steps = total_steps // opt.print_freq * opt.print_freq

    ### real training starts here
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for _, data in enumerate(dataset, start=epoch_iter):
            if total_steps % opt.print_freq == 0:
                iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == 0

            _, n_frames_total, height, width = data['B'].size(
            )  # n_frames_total = n_frames_load * n_loadings + tG - 1
            n_frames_total = n_frames_total // opt.output_nc
            n_frames_load = opt.max_frames_per_gpu * n_gpus  # number of total frames loaded into GPU at a time for each batch
            n_frames_load = min(n_frames_load, n_frames_total - tG + 1)
            t_len = n_frames_load + tG - 1  # number of loaded frames plus previous frames

            fake_B_last = None  # the last generated frame from previous training batch (which becomes input to the next batch)
            real_B_all, fake_B_all, flow_ref_all, conf_ref_all = None, None, None, None  # all real/generated frames so far
            real_B_skipped = [None] * t_scales  # temporally subsampled frames
            fake_B_skipped = [None] * t_scales
            flow_ref_skipped = [None] * t_scales  # temporally subsampled flows
            conf_ref_skipped = [None] * t_scales

            for i in range(0, n_frames_total - t_len + 1, n_frames_load):
                # 5D tensor: batchSize, # of frames, # of channels, height, width
                input_A = Variable(
                    data['A'][:, i * input_nc:(i + t_len) * input_nc,
                              ...]).view(-1, t_len, input_nc, height, width)
                input_B = Variable(
                    data['B'][:, i * output_nc:(i + t_len) * output_nc,
                              ...]).view(-1, t_len, output_nc, height, width)
                inst_A = Variable(data['inst'][:, i:i + t_len, ...]).view(
                    -1, t_len, 1, height,
                    width) if len(data['inst'].size()) > 2 else None

                ###################################### Forward Pass ##########################
                ####### generator
                fake_B, fake_B_raw, flow, weight, real_A, real_Bp, fake_B_last = modelG(
                    input_A, input_B, inst_A, fake_B_last)

                if i == 0:
                    # the first generated image in this sequence
                    fake_B_first = fake_B[0, 0]
                # the collection of previous and current real frames
                real_B_prev, real_B = real_Bp[:, :-1], real_Bp[:, 1:]

                ####### discriminator
                ### individual frame discriminator
                flow_ref, conf_ref = flowNet(
                    real_B[:, :, :3, ...],
                    real_B_prev[:, :, :3,
                                ...])  # reference flows and confidences
                fake_B_prev = (real_B_prev[:, 0:1] if fake_B_last is None
                               else fake_B_last[0][:, -1:])
                if fake_B.size()[1] > 1:
                    fake_B_prev = torch.cat(
                        [fake_B_prev, fake_B[:, :-1].detach()], dim=1)

                losses = modelD(
                    0,
                    reshape([
                        real_B, fake_B, fake_B_raw, real_A, real_B_prev,
                        fake_B_prev, flow, weight, flow_ref, conf_ref
                    ]))
                # Missing (None) losses count as 0 so the sums below stay valid.
                losses = [
                    torch.mean(x) if x is not None else 0 for x in losses
                ]
                loss_dict = dict(zip(modelD.module.loss_names, losses))

                ### temporal discriminator
                loss_dict_T = []
                # get skipped frames for each temporal scale
                if t_scales > 0:
                    real_B_all, real_B_skipped = get_skipped_frames(
                        real_B_all, real_B, t_scales, tD)
                    fake_B_all, fake_B_skipped = get_skipped_frames(
                        fake_B_all, fake_B, t_scales, tD)
                    flow_ref_all, conf_ref_all, flow_ref_skipped, conf_ref_skipped = get_skipped_flows(
                        flowNet, flow_ref_all, conf_ref_all, real_B_skipped,
                        flow_ref, conf_ref, t_scales, tD)

                # run discriminator for each temporal scale
                for s in range(t_scales):
                    # Only scales that have accumulated exactly tD frames are evaluated.
                    if real_B_skipped[s] is not None and real_B_skipped[
                            s].size()[1] == tD:
                        losses = modelD(s + 1, [
                            real_B_skipped[s], fake_B_skipped[s],
                            flow_ref_skipped[s], conf_ref_skipped[s]
                        ])
                        losses = [
                            torch.mean(x) if not isinstance(x, int) else x
                            for x in losses
                        ]
                        loss_dict_T.append(
                            dict(zip(modelD.module.loss_names_T, losses)))

                # collect losses
                loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
                loss_G = loss_dict['G_GAN'] + loss_dict[
                    'G_GAN_Feat'] + loss_dict['G_VGG']
                loss_G += loss_dict['G_Warp'] + loss_dict[
                    'F_Flow'] + loss_dict['F_Warp'] + loss_dict['W']
                if opt.add_face_disc:
                    loss_G += loss_dict['G_f_GAN'] + loss_dict['G_f_GAN_Feat']
                    loss_D += (loss_dict['D_f_fake'] +
                               loss_dict['D_f_real']) * 0.5

                # collect temporal losses
                loss_D_T = []
                t_scales_act = min(t_scales, len(loss_dict_T))
                for s in range(t_scales_act):
                    loss_G += loss_dict_T[s]['G_T_GAN'] + loss_dict_T[s][
                        'G_T_GAN_Feat'] + loss_dict_T[s]['G_T_Warp']
                    loss_D_T.append((loss_dict_T[s]['D_T_fake'] +
                                     loss_dict_T[s]['D_T_real']) * 0.5)

                ###################################### Backward Pass #################################
                optimizer_G = modelG.module.optimizer_G
                optimizer_D = modelD.module.optimizer_D
                # update generator weights
                optimizer_G.zero_grad()
                loss_G.backward()
                optimizer_G.step()

                # update discriminator weights
                # individual frame discriminator
                optimizer_D.zero_grad()
                loss_D.backward()
                optimizer_D.step()
                # temporal discriminator
                for s in range(t_scales_act):
                    optimizer_D_T = getattr(modelD.module,
                                            'optimizer_D_T' + str(s))
                    optimizer_D_T.zero_grad()
                    loss_D_T[s].backward()
                    optimizer_D_T.step()

            if opt.debug:
                call([
                    "nvidia-smi", "--format=csv",
                    "--query-gpu=memory.used,memory.free"
                ])

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == 0:
                t = (time.time() - iter_start_time) / opt.print_freq
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                # Temporal losses get the scale index appended to their name.
                for s in range(len(loss_dict_T)):
                    errors.update({
                        k + str(s):
                        v.data.item() if not isinstance(v, int) else v
                        for k, v in loss_dict_T[s].items()
                    })
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                # NOTE(review): `input_image` is built (and decorated) below
                # but is not what ends up in `visual_list`, which converts
                # real_A[0, -1] directly. Confirm which one is intended.
                if opt.label_nc != 0:
                    input_image = util.tensor2label(real_A[0, -1],
                                                    opt.label_nc)
                elif opt.dataset_mode == 'pose':
                    input_image = util.tensor2im(real_A[0, -1, :3],
                                                 normalize=False)
                    if real_A.size()[2] == 6:
                        input_image2 = util.tensor2im(real_A[0, -1, 3:],
                                                      normalize=False)
                        input_image[input_image2 != 0] = input_image2[
                            input_image2 != 0]
                else:
                    c = 3 if opt.input_nc == 3 else 1
                    input_image = util.tensor2im(real_A[0, -1, :c],
                                                 normalize=False)
                if opt.use_instance:
                    edges = util.tensor2im(real_A[0, -1, -1:, ...],
                                           normalize=False)
                    input_image += edges[:, :, np.newaxis]

                if opt.add_face_disc:
                    ys, ye, xs, xe = modelD.module.get_face_region(real_A[0,
                                                                          -1:])
                    if ys is not None:
                        # Draw the four edges of the face box in white.
                        input_image[ys, xs:xe, :] = input_image[
                            ye, xs:xe, :] = input_image[
                                ys:ye, xs, :] = input_image[ys:ye, xe, :] = 255

                visual_list = [
                    ('input_image', util.tensor2im(real_A[0, -1])),
                    ('fake_image', util.tensor2im(fake_B[0, -1])),
                    ('fake_first_image', util.tensor2im(fake_B_first)),
                    ('fake_raw_image', util.tensor2im(fake_B_raw[0, -1])),
                    ('real_image', util.tensor2im(real_B[0, -1])),
                    ('flow_ref', util.tensor2flow(flow_ref[0, -1])),
                    ('conf_ref',
                     util.tensor2im(conf_ref[0, -1], normalize=False))
                ]
                if flow is not None:
                    visual_list += [('flow', util.tensor2flow(flow[0, -1])),
                                    ('weight',
                                     util.tensor2im(weight[0, -1],
                                                    normalize=False))]
                visuals = OrderedDict(visual_list)
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == 0:
                visualizer.vis_print(
                    'saving the latest model (epoch %d, total_steps %d)' %
                    (epoch, total_steps))
                modelG.module.save('latest')
                modelD.module.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

            # Leave the loop once fewer than a full batch remains in the
            # epoch, and reset the counter for the next one.
            if epoch_iter > dataset_size - opt.batchSize:
                epoch_iter = 0
                break

        # end of epoch
        visualizer.vis_print('End of epoch %d / %d \t Time Taken: %d sec' %
                             (epoch, opt.niter + opt.niter_decay,
                              time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            visualizer.vis_print(
                'saving the model at the end of epoch %d, iters %d' %
                (epoch, total_steps))
            modelG.module.save('latest')
            modelD.module.save('latest')
            modelG.module.save(epoch)
            modelD.module.save(epoch)
            # Record that a resumed run should start at the next epoch.
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            modelG.module.update_learning_rate(epoch)
            modelD.module.update_learning_rate(epoch)

        ### gradually grow training sequence length
        if (epoch % opt.niter_step) == 0:
            data_loader.dataset.update_training_batch(epoch // opt.niter_step)
            modelG.module.update_training_batch(epoch // opt.niter_step)

        ### finetune all scales
        if (opt.n_scales_spatial > 1) and (opt.niter_fix_global != 0) and (
                epoch == opt.niter_fix_global):
            modelG.module.update_fixed_params()
def train_pose2vid(target_dir, run_name, temporal_smoothing=False):
    """Train the pose-to-video model.

    Loads the training options module, specializes it through ``update_opt``,
    builds the data loader, model and visualizer, and runs the epoch loop.
    Resume information (next epoch and position within the epoch) is kept in
    ``<checkpoints_dir>/<name>/iter.json``; it is read at start-up when
    ``opt.load_pretrain`` is non-empty and rewritten on every checkpoint.

    Args:
        target_dir: Forwarded to ``update_opt``.
        run_name: Forwarded to ``update_opt``.
        temporal_smoothing: When true, the previous label and previous image
            of each sample are passed to the model as two extra inputs.
    """
    import src.config.train_opt as opt

    opt = update_opt(opt, target_dir, run_name, temporal_smoothing)

    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.json')
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    # A fresh run starts at epoch 1; otherwise pick up the stored position.
    if opt.load_pretrain == '':
        iter_json = {'start_epoch': 1, 'epoch_iter': 0}
    else:
        with open(iter_path, 'r') as f:
            iter_json = json.load(f)

    def store_resume_point(next_epoch, next_iter):
        # Persist where a later run should continue from.
        iter_json['start_epoch'] = next_epoch
        iter_json['epoch_iter'] = next_iter
        with open(iter_path, 'w') as f:
            json.dump(iter_json, f)

    start_epoch, epoch_iter = iter_json['start_epoch'], iter_json['epoch_iter']
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    # Offsets that keep the display / print / save schedule aligned with the
    # step count the run is resumed at.
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    model = create_model(opt).to(device)
    visualizer = Visualizer(opt)

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter %= dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            model_inputs = [
                Variable(data['label']),
                Variable(data['inst']),
                Variable(data['image']),
                Variable(data['feat']),
            ]
            if temporal_smoothing:
                model_inputs.append(Variable(data['previous_label']))
                model_inputs.append(Variable(data['previous_image']))
            losses, generated = model(*model_inputs, infer=save_fake)

            # sum per device losses
            losses = [
                x if isinstance(x, int) else torch.mean(x) for x in losses
            ]
            loss_dict = dict(zip(model.loss_names, losses))

            # calculate final loss scalar
            d_fake, d_real = loss_dict['D_fake'], loss_dict['D_real']
            loss_D = (d_fake + d_real) * 0.5
            g_gan = loss_dict['G_GAN']
            g_gan_feat = loss_dict.get('G_GAN_Feat', 0)
            g_vgg = loss_dict.get('G_VGG', 0)
            loss_G = g_gan + g_gan_feat + g_vgg

            ############### Backward Pass ####################
            # update generator weights
            gen_optimizer = model.optimizer_G
            gen_optimizer.zero_grad()
            loss_G.backward()
            gen_optimizer.step()

            # update discriminator weights
            disc_optimizer = model.optimizer_D
            disc_optimizer.zero_grad()
            loss_D.backward()
            disc_optimizer.step()

            ############## Display results and errors ##########

            print(f"Epoch {epoch} batch {i}:")
            print(f"loss_D: {loss_D}, loss_G: {loss_G}")
            print(
                f"loss_D_fake: {d_fake}, loss_D_real: {d_real}"
            )
            print(
                f"loss_G_GAN {g_gan}, loss_G_GAN_Feat: {g_gan_feat}, loss_G_VGG: {g_vgg}\n"
            )

            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {}
                for name, value in loss_dict.items():
                    errors[name] = value if isinstance(value, int) else value.item()
                per_sample_time = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors,
                                                per_sample_time)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                label_vis = util.tensor2label(data['label'][0], opt.label_nc)
                synth_vis = util.tensor2im(generated.data[0])
                real_vis = util.tensor2im(data['image'][0])
                visuals = OrderedDict([
                    ('input_label', label_vis),
                    ('synthesized_image', synth_vis),
                    ('real_image', real_vis),
                ])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')
                store_resume_point(epoch, epoch_iter)

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        elapsed = time.time() - epoch_start_time
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, elapsed))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)
            store_resume_point(epoch + 1, 0)

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if opt.niter_fix_global != 0 and epoch == opt.niter_fix_global:
            model.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.update_learning_rate()

    torch.cuda.empty_cache()
Beispiel #20
0
class Tester:
    def __init__(self, opt, dataset_key='test', visualizer=None):
        """Prepare a tester working on a private copy of ``opt``.

        Args:
            opt: Option namespace. It is deep-copied, so the flags forced
                below do not leak back into the caller's object.
            dataset_key: Name of the split to load. 'val' and 'train' put the
                tester in validation mode (see ``self.is_validation``).
            visualizer: Visualizer to reuse. When ``None`` a new
                ``Visualizer`` is created from the copied options.
        """
        self.opt = deepcopy(opt)

        # Flags forced on the copy for evaluation.
        self.opt.serial_batches = True
        self.opt.no_flip = True
        self.opt.isTrain = False
        self.opt.dataset_key = dataset_key

        if 'results_dir' not in self.opt:
            self.opt.results_dir = 'results/'

        self.dataloader = data.create_dataloader(self.opt)

        self.visualizer = Visualizer(
            self.opt) if visualizer is None else visualizer

        # Anchor the checkpoint directory at the current working directory,
        # dropping a leading "./" so the joined path stays clean.
        base_path = os.getcwd()
        checkpoints_dir = self.opt.checkpoints_dir
        if checkpoints_dir.startswith("./"):
            checkpoints_dir = checkpoints_dir[2:]
        self.opt.checkpoints_dir = os.path.join(base_path, checkpoints_dir)

        self.is_validation = self.opt.dataset_key in ["val", "train"]
        self.N = self.dataloader.dataset.N

        # NOTE(review): this uses the caller's ``opt.checkpoints_dir`` (possibly
        # relative), not the absolute path stored on ``self.opt`` above --
        # confirm this is intended.
        self.results_dir = os.path.join(opt.checkpoints_dir, self.opt.name,
                                        self.opt.results_dir,
                                        self.opt.dataset_key)
        # exist_ok avoids the check-then-create race when several processes
        # start with the same results directory.
        os.makedirs(self.results_dir, exist_ok=True)

    def forward(self, model, data_i):
        """Run ``model`` in inference mode on one batch.

        Returns:
            A tuple ``(fake, fake_resized)``: the model output detached and
            moved to the CPU, and that output passed through
            ``ImageProcessor.to_255resized_imagebatch`` with ``as_tensor=True``.
        """
        prediction = model.forward(data_i, mode="inference")
        prediction = prediction.detach().cpu()
        resized = ImageProcessor.to_255resized_imagebatch(prediction,
                                                          as_tensor=True)
        return prediction, resized

    def get_iterator(self, dataloader, indices=None):
        """

        Args:
            indices: a list of indices that should be loaded from dataloader. If it is none, the iterator iterates
                over the entire dataset.

        Returns: iterator

        """
        if indices is None:
            for data_i in dataloader:
                yield data_i
        else:
            for i_val in indices:
                data_i = dataloader.dataset.get_particular(i_val)
                yield data_i

    def _prepare_error_log(self):
        """Create (overwriting) the HDF5 error log for the current split.

        The file ``error_log_<dataset_key>.h5`` is opened for writing in
        ``self.results_dir`` and gets four datasets with ``self.N`` rows each:
        "error", "user", "filename" and "visualisation".

        Returns:
            The open ``h5py.File``; the caller is responsible for closing it.
        """
        log_path = os.path.join(self.results_dir,
                                f"error_log_{self.opt.dataset_key}.h5")
        error_log = h5py.File(log_path, "w")
        # ``np.float`` was only an alias of the builtin float (float64) and
        # no longer exists in current NumPy, so name the dtype explicitly.
        error_log.create_dataset("error", shape=(self.N, ), dtype=np.float64)
        error_log.create_dataset("user", shape=(self.N, ), dtype='S4')
        error_log.create_dataset("filename", shape=(self.N, ), dtype='S13')
        error_log.create_dataset("visualisation",
                                 shape=(self.N, 1, 380, 1000),
                                 dtype=np.uint8)
        return error_log

    def _write_error_log_batch(self, error_log, data_i, i, fake, errors):
        """Store one batch (metadata, errors, visualisations) in the error log.

        Args:
            error_log: Open log as returned by ``_prepare_error_log``.
            data_i: Batch dict; its "user" and "filename" entries are stored.
            i: Zero-based batch index, used to compute the row offset.
            fake: Model output for the batch, added to the visualised data.
            errors: Per-sample errors for the batch.

        Returns:
            ``error_log``, for call chaining.
        """
        visualisation_data = {**data_i, "fake": fake}
        visuals = visualize_sidebyside(visualisation_data, error_list=errors)

        # Rows covered by this batch. The end is derived from the actual
        # number of samples so a short batch does not target a slice that is
        # wider than the data being written.
        idx_from = i * self.opt.batchSize
        idx_to = idx_from + len(errors)
        error_log["user"][idx_from:idx_to] = np.array(data_i["user"],
                                                      dtype='S4')
        error_log["filename"][idx_from:idx_to] = np.array(data_i["filename"],
                                                          dtype='S13')
        error_log["error"][idx_from:idx_to] = errors
        vis = np.array([np.copy(v) for v in visuals.values()])
        # vis are all floats in [-1, 1]; rescaling maps 1.0 to 256, which does
        # not fit the uint8 dataset, so saturate explicitly instead of relying
        # on the storage layer's overflow handling.
        vis = np.clip((vis + 1) * 128, 0, 255)
        error_log["visualisation"][idx_from:idx_to] = vis
        return error_log

    def run_batch(self, data_i, model):
        """Run the model on one batch and score it against the target.

        Returns:
            A tuple ``(errors, fake, fake_resized, target_image)`` where
            ``errors`` is a NumPy array built from
            ``MSECalculator.calculate_mse_for_images``.
        """
        prediction, prediction_resized = self.forward(model, data_i)
        reference = ImageProcessor.as_batch(data_i["target_original"],
                                            as_tensor=True)
        mse_values = MSECalculator.calculate_mse_for_images(
            prediction_resized, reference)
        return np.array(mse_values), prediction, prediction_resized, reference

    def run_validation(self,
                       model,
                       generator,
                       limit=-1,
                       write_error_log=False):
        print(f"write error log: {write_error_log}")
        assert self.is_validation, "Must be in validation mode"
        if write_error_log:
            error_log = self._prepare_error_log()

        all_errors = list()

        counter = 0
        for i, data_i in enumerate(generator):
            counter += data_i['label'].shape[0]
            if counter > limit:
                break
            if i % 10 == 9:
                print(f"Processing batch {i}")
                print(
                    f"Error so far: {np.sum(all_errors) / len(all_errors) * 1471}"
                )
            errors, fake, fake_resized, target_image = self.run_batch(
                data_i, model)
            all_errors += list(errors)
            if write_error_log:
                error_log = self._write_error_log_batch(
                    error_log, data_i, i, fake, errors)

        if write_error_log:
            error_log.close()
        return all_errors

    def print_results(self,
                      all_errors,
                      errors_dict,
                      epoch='n.a.',
                      n_steps="n.a."):
        print("Validation Results")
        print("------------------")
        print(
            f"Error calculated on {len(all_errors)} / {self.dataloader.dataset.N} samples"
        )
        for k in sorted(errors_dict):
            print(f"  {k}, {errors_dict[k]:.2f}")
        print(
            f"  dataset_key: {self.opt.dataset_key}, model: {self.opt.name}, epoch: {epoch}, n_steps: {n_steps}"
        )

    def run_visual_validation(self, model, mode, epoch, n_steps, limit):
        """Run the model on selected samples and send side-by-side visuals
        to the visualizer.

        Args:
            model: Model handed to ``run_batch``.
            mode: Selection mode, resolved by ``_get_validation_indices``.
            epoch: Epoch forwarded to the visualizer.
            n_steps: Step count forwarded to the visualizer.
            limit: Number of samples requested from the index selection.
        """
        print(f"Visualizing images for mode '{mode}'...")
        indices = self._get_validation_indices(mode, limit)

        batches = []
        batch_errors = []
        for batch in self.get_iterator(self.dataloader, indices=indices):
            errors, fake, _, _ = self.run_batch(batch, model)
            # The prediction is stored on the batch dict itself so it is
            # merged together with the other entries below.
            batch['fake'] = fake
            batches.append(batch)
            batch_errors.append(errors)
        flat_errors = np.array(batch_errors).reshape(-1)

        # Turn the list of per-batch dicts into one dict of lists, keyed like
        # the first batch.
        merged = {
            key: [batch[key] for batch in batches]
            for key in batches[0].keys()
        }
        tensor_keys = ("style_image", "target", "target_original", "fake",
                       "label")
        for key in tensor_keys:
            merged[key] = torch.cat(merged[key], dim=0)

        visuals = visualize_sidebyside(
            merged,
            log_key=f"{self.opt.dataset_key}/{mode}",
            w=200,
            h=320,
            error_list=flat_errors)
        self.visualizer.display_current_results(visuals, epoch, n_steps)

    def _get_validation_indices(self, mode, limit):
        if 'rand' in mode:
            validation_indices = self.dataloader.dataset.get_random_indices(
                limit)
        elif 'fix' in mode:
            # Use fixed validation indices
            validation_indices = self.dataloader.dataset.get_validation_indices(
            )[:limit]
        elif 'full' in mode:
            validation_indices = None
        else:
            raise ValueError(f"Invalid mode: {mode}")
        return validation_indices

    def run(self,
            model,
            mode,
            epoch=None,
            n_steps=None,
            limit=-1,
            write_error_log=False,
            log=False):
        """Validate ``model`` for one mode, print the statistics and
        optionally log them.

        Args:
            model: Model to evaluate.
            mode: Selection mode, resolved by ``_get_validation_indices``.
            epoch: Epoch label used for printing/logging.
            n_steps: Step label used for printing/logging.
            limit: Maximum number of samples; non-positive values fall back
                to the dataset size.
            write_error_log: Forwarded to ``run_validation``.
            log: When true, the statistics are also sent to the visualizer.
        """
        print(f"Running validation for mode '{mode}'...")
        if not limit > 0:
            limit = self.dataloader.dataset.N

        selected = self._get_validation_indices(mode, limit)
        batches = self.get_iterator(self.dataloader, indices=selected)
        all_errors = self.run_validation(model,
                                         batches,
                                         limit=limit,
                                         write_error_log=write_error_log)

        errors_dict = MSECalculator.calculate_error_statistics(
            all_errors, mode=mode, dataset_key=self.opt.dataset_key)
        self.print_results(all_errors, errors_dict, epoch, n_steps)

        if not log:
            return
        self.log_visualizer(errors_dict, epoch, n_steps)

    def log_visualizer(self, errors_dict, epoch=0, total_steps_so_far=0):
        """

        Args:
            errors_dict: must contain
            epoch:
            total_steps_so_far:
            log_key:

        Returns:

        """
        self.visualizer.print_current_errors(epoch,
                                             total_steps_so_far,
                                             errors_dict,
                                             t=0)
        self.visualizer.plot_current_errors(errors_dict, total_steps_so_far)

    def run_test(self, model, limit=-1):
        filepaths = list()

        for i, data_i in enumerate(self.dataloader):
            if limit > 0 and i * self.opt.batchSize >= limit:
                break
            if i % 10 == 0:
                print(
                    f"Processing batch {i} (processed {self.opt.batchSize * i} images)"
                )

            # The test file names are only 12 characters long, so we have dot to remove
            img_filename = [re.sub(r'\.', '', f) for f in data_i['filename']]

            fake, fake_resized = self.forward(model, data_i)
            # We are testing
            for b in range(len(img_filename)):
                result_path = os.path.join(self.results_dir,
                                           img_filename[b] + ".npy")
                assert torch.min(fake_resized[b]) >= 0 and torch.max(
                    fake_resized[b]) <= 255
                np.save(result_path, np.copy(fake_resized[b]).astype(np.uint8))
                filepaths.append(result_path)

        # We are testing
        path_filepaths = os.path.join(self.results_dir, "pred_npy_list.txt")
        with open(path_filepaths, 'w') as f:
            for line in filepaths:
                f.write(line)
                f.write(os.linesep)
        print(f"Written {len(filepaths)} files. Filepath: {path_filepaths}")

    def run_partial_modes(self, model, epoch, n_steps, log, visualize_images,
                          limit):
        # for mode in ['fix', 'rand']:
        for mode in ['rand']:
            self.run(model=model,
                     mode=mode,
                     epoch=epoch,
                     n_steps=n_steps,
                     log=log,
                     limit=limit)
            if visualize_images:
                self.run_visual_validation(model,
                                           mode=mode,
                                           epoch=epoch,
                                           n_steps=n_steps,
                                           limit=4)
Beispiel #21
0
        epoch_start_time = time.time()
        epoch_iter = 0

        for i, data in enumerate(dataset):
            iter_start_time = time.time()
            visualizer.reset()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize
            model.set_input(data)
            model.optimize_parameters_STN()

            if total_steps % opt.display_freq == 0:
                save_result = total_steps % opt.update_html_freq == 0
                visualizer.display_current_results(
                    model.get_current_visuals_STN(),
                    total_steps,
                    save_result,
                    opt.update_html_freq,
                    n_latest=opt.n_latest)

            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors()
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(
                        epoch,
                        float(epoch_iter) / dataset_size, opt, errors)

            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
Beispiel #22
0
def train():
    """Run the training loop with a validation pass after every epoch.

    Options come from ``TrainOptions().parse()``. With ``opt.continue_train``
    the start epoch and in-epoch offset are read from ``iter.txt`` in the
    checkpoint directory; that file is rewritten whenever the latest model or
    an epoch checkpoint is saved.

    Raises:
        ValueError: when images are about to be visualised and
            ``opt.task_type`` is not one of 'specular', 'low' or 'high'.
    """
    opt = TrainOptions().parse()
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')

    # ``type`` code handed to util.tensor2im_exr for each supported task type.
    exr_type_by_task = {'specular': 1, 'low': 2, 'high': 3}

    def to_scalars(values):
        """Turn a dict of tensors (or plain ints) into plain Python numbers."""
        return {
            k: v if isinstance(v, int) else v.data.item()
            for k, v in values.items()
        }

    def build_visuals(batch, generated_batch):
        """Build the albedo / generated / GT images for the first sample."""
        try:
            exr_type = exr_type_by_task[opt.task_type]
        except KeyError:
            raise ValueError('Unsupported task_type: %r' %
                             (opt.task_type, )) from None
        return OrderedDict([
            ('albedo', util.tensor2im(batch['A'][0])),
            ('generated',
             util.tensor2im_exr(generated_batch.data[0], type=exr_type)),
            ('GT', util.tensor2im_exr(batch['B'][0], type=exr_type))
        ])

    start_epoch, epoch_iter = 1, 0
    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                 delimiter=',',
                                                 dtype=int)
        except (OSError, ValueError, TypeError):
            # Missing or malformed iter.txt: start from scratch. A bare
            # ``except`` here would also swallow KeyboardInterrupt.
            start_epoch, epoch_iter = 1, 0
        # compute resume lr
        if start_epoch > opt.niter:
            lrd_unit = opt.lr / opt.niter_decay
            opt.lr = opt.lr - (start_epoch - opt.niter) * lrd_unit
        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))

    opt.print_freq = lcm(opt.print_freq, opt.batchSize)
    if opt.debug:
        opt.display_freq = 2
        opt.print_freq = 2
        opt.niter = 3
        opt.niter_decay = 0
        opt.max_dataset_size = 1
        opt.valSize = 1

    ## Loading data
    # train data
    data_loader = CreateDataLoader(opt, isVal=False)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('# training images = %d' % dataset_size)
    # validation data
    data_loader = CreateDataLoader(opt, isVal=True)
    valset = data_loader.load_data()
    print('# validation images = %d' % len(data_loader))

    ## Loading model
    model = create_model(opt)
    visualizer = Visualizer(opt)
    if opt.fp16:
        from apex import amp
        model, [optimizer_G, optimizer_D
                ] = amp.initialize(model,
                                   [model.optimizer_G, model.optimizer_D],
                                   opt_level='O1')
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    else:
        optimizer_G, optimizer_D = model.module.optimizer_G, model.module.optimizer_D

    total_steps = (start_epoch - 1) * dataset_size + epoch_iter

    # Offsets that keep the display/print/save cadence aligned when resuming
    # from a step count that is not a multiple of the frequency.
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for data in dataset:
            if total_steps % opt.print_freq == print_delta:
                iter_start_time = time.time()

            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            model = model.train()
            losses, generated, metrics = model(data['A'],
                                               data['B'],
                                               data['geometry'],
                                               infer=False)

            # sum per device losses and metrics
            losses = [
                torch.mean(x) if not isinstance(x, int) else x for x in losses
            ]
            metric_dict = {k: torch.mean(v) for k, v in metrics.items()}
            loss_dict = dict(zip(model.module.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + opt.gan_feat_weight * loss_dict.get(
                'G_GAN_Feat', 0) + opt.vgg_weight * loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights
            optimizer_G.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_G, optimizer_G) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_G.backward()
            optimizer_G.step()

            # update discriminator weights
            optimizer_D.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_D, optimizer_D) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_D.backward()
            optimizer_D.step()

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = to_scalars(loss_dict)
                metrics_ = to_scalars(metric_dict)
                t = (time.time() - iter_start_time) / opt.print_freq
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)
                visualizer.print_current_metrics(epoch, epoch_iter, metrics_,
                                                 t)
                visualizer.plot_current_metrics(metrics_, total_steps)

            ### display output images
            if save_fake:
                visualizer.display_current_results(
                    build_visuals(data, generated), epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.module.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))

        ###########################################################################################
        # validation at the end of each epoch
        val_start_time = time.time()
        metrics_val = []
        model = model.eval()
        # No gradients are needed for validation; disabling autograd avoids
        # keeping the graph of every validation batch in memory.
        with torch.no_grad():
            for val_data in valset:
                generated, metrics = model(val_data['A'],
                                           val_data['B'],
                                           val_data['geometry'],
                                           infer=True)
                metric_dict = {k: torch.mean(v) for k, v in metrics.items()}
                metrics_val.append(to_scalars(metric_dict))
        # Print out losses
        metrics_val = visualizer.mean4dict(metrics_val)
        # NOTE(review): the elapsed validation time is divided by print_freq,
        # not by the number of validation samples -- confirm this is intended.
        t = (time.time() - val_start_time) / opt.print_freq
        visualizer.print_current_metrics(epoch,
                                         epoch_iter,
                                         metrics_val,
                                         t,
                                         isVal=True)
        visualizer.plot_current_metrics(metrics_val, total_steps, isVal=True)
        # visualization of the last validation batch
        visualizer.display_current_results(build_visuals(val_data, generated),
                                           epoch,
                                           epoch,
                                           isVal=True)
        ###########################################################################################

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.module.save('latest')
            model.module.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.module.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.module.update_learning_rate()
                step_num, 1, opt,
                OrderedDict([
                    ('photometric_cost', photometric_cost.data.cpu()[0]),
                    ('smoothness_cost', smoothness_cost.data.cpu()[0]),
                    ('cost', cost.data.cpu()[0])
                ]))

        if np.mod(step_num, opt.display_freq) == 0:
            # frame_vis = frames.data[:,1,:,:,:].permute(0,2,3,1).contiguous().view(-1,opt.imW, 3).cpu().numpy().astype(np.uint8)
            # depth_vis = vis_depthmap(inv_depths.data[:,1,:,:].contiguous().view(-1,opt.imW).cpu()).numpy().astype(np.uint8)
            frame_vis = frames.data.permute(
                1, 2, 0).contiguous().cpu().numpy().astype(np.uint8)
            depth_vis = vis_depthmap(inv_depths.data.cpu()).numpy().astype(
                np.uint8)
            visualizer.display_current_results(
                OrderedDict([('%s frame' % (opt.name), frame_vis),
                             ('%s inv_depth' % (opt.name), depth_vis)]), epoch)
            sio.savemat(
                os.path.join(opt.checkpoints_dir, 'depth_%s.mat' % (step_num)),
                {
                    'D': inv_depths.data.cpu().numpy(),
                    'I': frame_vis
                })

        if np.mod(step_num, opt.save_latest_freq) == 0:
            print("cache model....")
            lkvolearner.save_model(
                os.path.join(opt.checkpoints_dir, '%s_model.pth' % (epoch)))
            lkvolearner.cuda()
            print('..... saved')
Beispiel #24
0
def train_main(rank, world_size, opt):
    """Per-process entry point for distributed training.

    Args:
        rank: Index of this process; also selects the CUDA device
            ``cuda:<rank>``. Only rank 0 creates a visualizer, prints/plots
            losses and saves checkpoints.
        world_size: Total number of processes, passed to ``init_process``.
        opt: Training options. ``opt.max_dataset_size`` is replaced by the
            largest multiple of ``opt.batch_size`` not exceeding the dataset
            size when it is infinite.
    """
    init_process(rank, world_size)
    torch.cuda.set_device(torch.device('cuda:{}'.format(rank)))

    dataset = create_dataset(opt, rank)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)    # get the number of images in the dataset.
    print('The number of training images = %d' % dataset_size)

    if opt.max_dataset_size == float("inf"):
        opt.max_dataset_size = int(dataset_size / opt.batch_size) * opt.batch_size

    model = create_model(opt, rank)      # create a model given opt.model and other options

    model.setup(opt)               # regular setup: load and print networks; create schedulers
    total_iters = 0                # the total number of training iterations

    if rank == 0:
        visualizer = Visualizer(opt)  # create a visualizer that display/save images and plots

    for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):    # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()    # timer for data loading per iteration
        epoch_iter = 0                  # the number of training iterations in current epoch, reset to 0 every epoch

        if rank == 0:
            visualizer.reset()              # reset the visualizer: make sure it saves the results to HTML at least once every epoch

        for data in dataset:  # inner loop within one epoch
            iter_start_time = time.time()  # timer for computation per iteration
            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time

            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(data)         # unpack data from dataset and apply preprocessing
            model.optimize_parameters()   # calculate loss functions, get gradients, update network weights

            if rank == 0:
                # display images on visdom and save images to a HTML file
                if total_iters % opt.display_freq == 0:
                    save_result = total_iters % opt.update_html_freq == 0
                    model.compute_visuals()
                    visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

                # The losses used to be fetched here on every iteration to
                # fill a log dict that was never used; they are now only read
                # in the print_freq branch below.

                if total_iters % opt.save_latest_freq == 0:   # cache our latest model every <save_latest_freq> iterations
                    print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
                    save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
                    model.save_networks(save_suffix)

            if total_iters % opt.print_freq == 0:
                losses = model.get_current_losses()
                t_comp = (time.time() - iter_start_time) / opt.batch_size

                if rank == 0:
                    visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
                    if opt.display_id > 0:
                        visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)

            iter_data_time = time.time()

        if rank == 0:
            if epoch % opt.save_epoch_freq == 0:              # cache our model every <save_epoch_freq> epochs
                print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
                model.save_networks('latest')
                model.save_networks(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))
        model.update_learning_rate()                     # update learning rates at the end of every epoch.

    dist.destroy_process_group()
Beispiel #25
0
            iter_start_time = time.time()
            if total_steps % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
            visualizer.reset()
            total_steps += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(data)
            model.optimize_parameters()

            if total_steps % opt.display_freq == 0:
                print(total_steps)
                save_result = total_steps % opt.update_html_freq == 0

                # whether in pose estimation mode
                if opt.pose_mode:
                    visualizer.display_current_results(
                        model.get_current_visuals(), epoch, save_result)
                else:
                    visualizer.display_current_results(
                        model.get_current_visuals(), epoch, save_result,
                        model.grid_vis.detach().cpu())

            if total_steps % opt.print_freq == 0:

                losses = model.get_current_losses()
                if opt.use_val and opt.epoch_count != epoch:
                    losses['val_rmse'] = avg_dist
                else:
                    losses['val_rmse'] = 0

                t = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, losses, t,
def train_pix2pix(opt=train_opt):
    """Train a pix2pix model from scratch (epoch 1, iteration 0).

    Args:
        opt: Training options; defaults to the module-level ``train_opt``.

    The latest model and ``iter.txt`` (epoch, in-epoch offset) are written
    every ``opt.save_latest_freq`` steps, and an epoch checkpoint every
    ``opt.save_epoch_freq`` epochs.
    """
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    visualizer = Visualizer(opt)
    # dataset_size was used below without ever being defined in this function.
    dataset_size = len(data_loader)
    print(f'# training images = {dataset_size}')

    start_epoch, epoch_iter = 1, 0
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    model = create_model(opt)
    model = model.cuda()

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for data in dataset:
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            losses, generated = model(Variable(data['label']), Variable(data['inst']),
                                      Variable(data['image']), Variable(data['feat']), infer=save_fake)

            # sum per device losses
            losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]
            loss_dict = dict(zip(model.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights
            model.optimizer_G.zero_grad()
            loss_G.backward()
            model.optimizer_G.step()

            # update discriminator weights
            model.optimizer_D.zero_grad()
            loss_D.backward()
            model.optimizer_D.step()

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                # torch.mean yields 0-dim tensors; indexing them with [0]
                # raises in current PyTorch, so read the value via .item().
                errors = {k: v.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                       ('synthesized_image', util.tensor2im(generated.data[0])),
                                       ('real_image', util.tensor2im(data['image'][0]))])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print(f'saving the latest model (epoch {epoch}, total_steps {total_steps})')
                model.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print(f'saving the model at the end of epoch {epoch}, iters {total_steps}')
            model.save('latest')
            model.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.update_learning_rate()

    torch.cuda.empty_cache()
Beispiel #27
0
def train():
    """Train the generator together with the frame and temporal discriminators.

    Parses the training options, builds the data loader, the models and their
    optimizers, and then runs the epoch loop. Every loaded sequence is
    processed in chunks of ``n_frames_load`` frames: for each chunk the
    generator and the discriminators are run forward, the losses are collected
    and ``loss_backward`` is called once per optimizer. Losses and result
    images are reported through the ``Visualizer``; checkpoints are written by
    ``save_models``.

    Takes no arguments and returns nothing; all configuration comes from
    ``TrainOptions().parse()``.
    """
    opt = TrainOptions().parse()
    if opt.debug:
        # Debug mode: report on every step and use a thread count of 1.
        opt.display_freq = 1
        opt.print_freq = 1
        opt.nThreads = 1

    ### initialize dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training videos = %d' % dataset_size)

    ### initialize models
    models = create_model(opt)
    modelG, modelD, flowNet, optimizer_G, optimizer_D, optimizer_D_T = create_optimizer(opt, models)

    ### set parameters
    n_gpus, tG, tD, tDB, s_scales, t_scales, input_nc, output_nc, \
        start_epoch, epoch_iter, print_freq, total_steps, iter_path = init_params(opt, modelG, modelD, data_loader)
    visualizer = Visualizer(opt)

    ### real training starts here
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for data in dataset:
            # Restart the timer at the beginning of every logging window.
            if total_steps % print_freq == 0:
                iter_start_time = time.time()
            # Both counters advance by the batch size, not by one.
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == 0
            n_frames_total, n_frames_load, t_len = data_loader.dataset.init_data_params(data, n_gpus, tG)
            fake_B_last, frames_all = data_loader.dataset.init_data(t_scales)

            # Walk through the sequence in chunks of n_frames_load frames.
            for i in range(0, n_frames_total, n_frames_load):
                input_A, input_B, inst_A = data_loader.dataset.prepare_data(data, i, input_nc, output_nc)

                ###################################### Forward Pass ##########################
                ####### generator
                fake_B, fake_B_raw, flow, weight, real_A, real_Bp, fake_B_last = modelG(input_A, input_B, inst_A, fake_B_last)

                ####### discriminator
                ### individual frame discriminator
                # the collection of previous and current real frames
                real_B_prev, real_B = real_Bp[:, :-1], real_Bp[:, 1:]
                # reference flows and confidences
                flow_ref, conf_ref = flowNet(real_B, real_B_prev)
                fake_B_prev = modelG.module.compute_fake_B_prev(real_B_prev, fake_B_last, fake_B)

                losses = modelD(0, reshape([real_B, fake_B, fake_B_raw, real_A, real_B_prev, fake_B_prev, flow, weight, flow_ref, conf_ref]))
                # Average every loss; a missing (None) loss becomes the int 0,
                # which the logging code below passes through unchanged.
                losses = [torch.mean(x) if x is not None else 0 for x in losses]
                loss_dict = dict(zip(modelD.module.loss_names, losses))

                ### temporal discriminator
                # get skipped frames for each temporal scale
                frames_all, frames_skipped = modelD.module.get_all_skipped_frames(
                    frames_all, real_B, fake_B, flow_ref, conf_ref, t_scales, tD, n_frames_load, i, flowNet)

                # run discriminator for each temporal scale
                loss_dict_T = []
                for s in range(t_scales):
                    # Scales without skipped frames for this chunk are left out.
                    if frames_skipped[0][s] is not None:
                        losses = modelD(s + 1, [frame_skipped[s] for frame_skipped in frames_skipped])
                        losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]
                        loss_dict_T.append(dict(zip(modelD.module.loss_names_T, losses)))

                # collect losses
                loss_G, loss_D, loss_D_T, t_scales_act = modelD.module.get_losses(loss_dict, loss_dict_T, t_scales)

                ###################################### Backward Pass #################################
                # update generator weights
                loss_backward(opt, loss_G, optimizer_G)

                # update individual discriminator weights
                loss_backward(opt, loss_D, optimizer_D)

                # update temporal discriminator weights
                for s in range(t_scales_act):
                    loss_backward(opt, loss_D_T[s], optimizer_D_T[s])

                if i == 0:
                    # the first generated image in this sequence
                    fake_B_first = fake_B[0, 0]

            if opt.debug:
                call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % print_freq == 0:
                t = (time.time() - iter_start_time) / print_freq
                errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
                # Temporal losses get the index of their entry appended to the key.
                for s in range(len(loss_dict_T)):
                    errors.update({k + str(s): v.data.item() if not isinstance(v, int) else v for k, v in loss_dict_T[s].items()})
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                visuals = util.save_all_tensors(opt, real_A, fake_B, fake_B_first, fake_B_raw, real_B, flow_ref, conf_ref, flow, weight, modelD)
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            save_models(opt, epoch, epoch_iter, total_steps, visualizer, iter_path, modelG, modelD)
            # Leave the epoch once another full batch would exceed dataset_size.
            if epoch_iter > dataset_size - opt.batchSize:
                epoch_iter = 0
                break

        # end of epoch
        visualizer.vis_print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

        ### save model for this epoch and update model params
        save_models(opt, epoch, epoch_iter, total_steps, visualizer, iter_path, modelG, modelD, end_of_epoch=True)
        update_models(opt, epoch, modelG, modelD, data_loader)
for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()
    epoch_iter = 0

    # One pass over the dataset: optimise on every batch, then display, log
    # and checkpoint whenever the running step count hits the matching period.
    for i, data in enumerate(dataset):
        iter_start_time = time.time()  # reference point for the timing printed below
        visualizer.reset()
        # Counters advance by the batch size, not by one.
        total_steps += opt.batchSize
        epoch_iter += opt.batchSize
        model.set_input(data)
        model.optimize_parameters()

        # Show the current visuals every display_freq steps; save_result is
        # true only on steps that are also a multiple of update_html_freq.
        if total_steps % opt.display_freq == 0:
            save_result = total_steps % opt.update_html_freq == 0
            visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

        # Report the current errors every print_freq steps.
        if total_steps % opt.print_freq == 0:
            errors = model.get_current_errors()
            # Time since the start of this iteration, divided by the batch size.
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            if opt.display_id > 0:
                # The second argument is the fraction of the epoch completed so far.
                visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)

        # Overwrite the 'latest' checkpoint every save_latest_freq steps.
        if total_steps % opt.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, total_steps))
            model.save('latest')

    if epoch % opt.save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' %
Beispiel #29
0
class Trainer:
    """Bookkeeping for a training run: step counters, logging and checkpoints.

    The caller drives the loop and invokes ``start_of_epoch``,
    ``start_of_iter``, ``end_of_iter`` and ``end_of_epoch`` at the matching
    points. The instance keeps the running counters (``total_steps``,
    ``epoch_iter``), the logging period ``print_freq`` and the ``Visualizer``.
    """

    def __init__(self, opt, data_loader):
        """Set up the counters, optionally resuming from ``iter.txt``.

        Args:
            opt: Parsed options. Reads ``checkpoints_dir``, ``name``,
                ``continue_train``, ``print_freq`` and ``batchSize``.
            data_loader: Object supporting ``len()``; its length is stored as
                ``dataset_size``.
        """
        iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
        start_epoch, epoch_iter = 1, 0
        ### if continue training, recover previous states
        if opt.continue_train:
            if os.path.exists(iter_path):
                start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                     delimiter=',',
                                                     dtype=int)
            print('Resuming from epoch %d at iteration %d' %
                  (start_epoch, epoch_iter))

        # Use a logging period that is a multiple of the batch size, and round
        # the resumed step count down to a multiple of that period so the
        # first call to start_of_iter() starts the timer.
        print_freq = lcm(opt.print_freq, opt.batchSize)
        total_steps = (start_epoch - 1) * len(data_loader) + epoch_iter
        total_steps = total_steps // print_freq * print_freq

        self.opt = opt
        self.iter_path = iter_path
        self.print_freq = print_freq
        self.total_steps = total_steps
        self.start_epoch = start_epoch
        self.epoch_iter = epoch_iter
        # Defined up front so end_of_iter()/end_of_epoch() never hit a missing
        # attribute; start_of_epoch() overwrites it.
        self.epoch = start_epoch
        self.dataset_size = len(data_loader)
        self.visualizer = Visualizer(opt)

    def start_of_iter(self):
        """Advance the step counters and decide whether to save visuals.

        Restarts the iteration timer whenever ``total_steps`` (before the
        increment) is a multiple of ``print_freq``. Sets ``self.save`` to
        whether the new ``total_steps`` is a multiple of ``opt.display_freq``.
        """
        if self.total_steps % self.print_freq == 0:
            self.iter_start_time = time.time()
        self.total_steps += self.opt.batchSize
        self.epoch_iter += self.opt.batchSize
        self.save = self.total_steps % self.opt.display_freq == 0

    def end_of_iter(self, loss_dicts, output_list, model):
        """Log losses, show visuals, checkpoint, and report whether the epoch is over.

        Args:
            loss_dicts: Mapping from loss name to value; values are either
                ``int`` or objects exposing ``.data.item()``.
            output_list: Passed through to ``save_all_tensors``.
            model: Passed through to ``save_all_tensors`` and ``save_models``.

        Returns:
            ``True`` when ``epoch_iter`` exceeds ``dataset_size - batchSize``,
            otherwise ``False``.
        """
        opt = self.opt
        epoch, epoch_iter, print_freq, total_steps = self.epoch, self.epoch_iter, self.print_freq, self.total_steps
        ############## Display results and errors ##########
        ### print out errors
        if is_master() and total_steps % print_freq == 0:
            t = (time.time() - self.iter_start_time) / print_freq
            # Plain ints are logged as-is; everything else is unwrapped.
            errors = {
                k: v.data.item() if not isinstance(v, int) else v
                for k, v in loss_dicts.items()
            }
            self.visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            self.visualizer.plot_current_errors(errors, total_steps)

        ### display output images
        if is_master() and self.save:
            visuals = save_all_tensors(opt, output_list, model)
            self.visualizer.display_current_results(visuals, epoch,
                                                    total_steps)

        if is_master() and opt.print_mem:
            call([
                "nvidia-smi", "--format=csv",
                "--query-gpu=memory.used,memory.free"
            ])

        ### save latest model
        save_models(opt, epoch, epoch_iter, total_steps, self.visualizer,
                    self.iter_path, model)
        return epoch_iter > self.dataset_size - opt.batchSize

    def start_of_epoch(self, epoch, model, data_loader):
        """Record the epoch number and start time, then update the model.

        When ``opt.distributed`` is set, the data loader's sampler is given
        the epoch number through ``set_epoch`` before ``update_models`` runs.
        """
        self.epoch = epoch
        self.epoch_start_time = time.time()
        if self.opt.distributed:
            data_loader.dataloader.sampler.set_epoch(epoch)
        # update model params
        update_models(self.opt, epoch, model, data_loader)

    def end_of_epoch(self, model):
        """Print the epoch summary, save the model, and reset ``epoch_iter``."""
        opt = self.opt
        self.visualizer.vis_print(
            opt, 'End of epoch %d / %d \t Time Taken: %d sec' %
            (self.epoch, opt.niter + opt.niter_decay,
             time.time() - self.epoch_start_time))

        ### save model for this epoch
        save_models(opt,
                    self.epoch,
                    self.epoch_iter,
                    self.total_steps,
                    self.visualizer,
                    self.iter_path,
                    model,
                    end_of_epoch=True)
        self.epoch_iter = 0
Beispiel #30
0
#################################################
# Paired-sample training loop: iterate over epochs and batches, optimise the
# model on every batch, and display / log / checkpoint at the configured
# step periods.
print('step 3')
total_steps = 0
for epoch in range(1, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()
    # You can use paired and unpaired data to train. Here we only use paired samples to train.
    for i,(images_a, images_b) in enumerate(dataset_paired):
        iter_start_time = time.time()
        # total_steps advances by the batch size, not by one.
        total_steps += opt.batchSize
        # Steps taken within the current epoch, derived from the global count.
        epoch_iter = total_steps - paired_dataset_size * (epoch - 1)
        model.set_input(images_a, images_b)
        model.optimize_parameters()

        # Show the current visuals every display_freq steps.
        if total_steps % opt.display_freq == 0:
            visuals = model.get_current_visuals()
            visualizer.display_current_results(visuals, epoch)

        # Report the current errors every print_freq steps.
        if total_steps % opt.print_freq == 0:
            errors = model.get_current_errors()
            # NOTE(review): this passes the raw start timestamp, not an elapsed
            # time -- confirm that is what this Visualizer expects.
            visualizer.print_current_errors(epoch, epoch_iter, errors, iter_start_time)
            if opt.display_id > 0:
                # The second argument is the fraction of the epoch completed so far.
                visualizer.plot_current_errors(epoch, float(epoch_iter)/paired_dataset_size, opt, errors)

        # Overwrite the 'latest' checkpoint every save_latest_freq steps.
        if total_steps % opt.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, total_steps))
            model.save('latest')

    # Only the log message is visible in this excerpt; no save call follows here.
    if epoch % opt.save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' %
              (epoch, total_steps))
Beispiel #31
0
def train(cfg):
    """Train the model described by ``cfg``, logging and checkpointing along the way.

    Initialises distributed training, seeds NumPy and PyTorch from
    ``cfg.RNG_SEED``, builds the dataset, model and visualizer, and then runs
    epochs ``cfg.epoch_count`` through ``cfg.niter + cfg.niter_decay``
    inclusive. Display, console output and checkpoint writes happen only on
    the master process; the loss all-reduce and the learning-rate update run
    on every process.

    Args:
        cfg: Configuration object; see the attribute reads in the body for
            the fields that are required.
    """
    # Distributed setup first, then seed both RNGs from the config.
    du.init_distributed_training(cfg)
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)

    # Dataset and its size (used for the progress fraction in the loss plot).
    dataset = create_dataset(cfg)
    dataset_size = len(dataset)
    print('The number of training images = %d' % dataset_size)

    # Model, visualizer and run-wide state.
    model = create_model(cfg)
    model.setup(cfg)
    visualizer = Visualizer(cfg)
    is_master = du.is_master_proc(cfg.NUM_GPUS)
    total_iters = 0  # running count across all epochs, advanced by batch size

    for epoch in range(cfg.epoch_count, cfg.niter + cfg.niter_decay + 1):
        if is_master:
            epoch_began = time.time()  # timer for the whole epoch
            data_timer = time.time()   # timer for data loading
        epoch_iter = 0  # count within this epoch, reset every epoch
        shuffle_dataset(dataset, epoch)

        for data in dataset:
            if is_master:
                step_began = time.time()  # timer for this iteration's computation
                if total_iters % cfg.print_freq == 0:
                    data_wait = step_began - data_timer
                    data_timer = time.time()
            visualizer.reset()
            total_iters += cfg.batch_size
            epoch_iter += cfg.batch_size

            # One optimisation step on this batch.
            model.set_input(data)
            model.optimize_parameters()

            # Display the current visuals; the flag marks whether to save them too.
            display_due = total_iters % cfg.display_freq == 0
            if display_due and is_master:
                write_html = total_iters % cfg.update_html_freq == 0
                model.compute_visuals()
                current_visuals = model.get_current_visuals()
                visualizer.display_current_results(current_visuals, epoch, write_html)

            # Losses are fetched on every process and reduced across GPUs if
            # there is more than one; only the master prints and plots them.
            loss_values = model.get_current_losses()
            if cfg.NUM_GPUS > 1:
                loss_values = du.all_reduce(loss_values)

            print_due = total_iters % cfg.print_freq == 0
            if print_due and is_master:
                per_sample_time = (time.time() - step_began) / cfg.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, loss_values, per_sample_time, data_wait)
                if cfg.display_id > 0:
                    epoch_fraction = float(epoch_iter) / dataset_size
                    visualizer.plot_current_losses(epoch, epoch_fraction, loss_values)

            # Periodic checkpoint, tagged by iteration count or simply 'latest'.
            save_due = total_iters % cfg.save_latest_freq == 0
            if save_due and is_master:
                print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
                if cfg.save_by_iter:
                    tag = 'iter_%d' % total_iters
                else:
                    tag = 'latest'
                model.save_networks(tag)

        # End-of-epoch checkpoint; a per-epoch copy is kept only from epoch 30 on.
        epoch_save_due = epoch % cfg.save_epoch_freq == 0
        if epoch_save_due and is_master:
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
            model.save_networks('latest')
            if cfg.save_iter_model and epoch >= 30:
                model.save_networks(epoch)

        if is_master:
            print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, cfg.niter + cfg.niter_decay, time.time() - epoch_began))

        # Learning rates are updated at the end of every epoch, on every process.
        model.update_learning_rate()
Beispiel #32
0
        #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"]) 

        ############## Display results and errors ##########
        ### print out errors
        if total_steps % opt.print_freq == print_delta:
            errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()}
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            visualizer.plot_current_errors(errors, total_steps)

        ### display output images
        if save_fake:
            visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                   ('synthesized_image', util.tensor2im(generated.data[0])),
                                   ('real_image', util.tensor2im(data['image'][0]))])
            visualizer.display_current_results(visuals, epoch, total_steps)

        ### save latest model
        if total_steps % opt.save_latest_freq == save_delta:
            print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
            model.module.save('latest')            
            np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
       
    # end of epoch 
    iter_end_time = time.time()
    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

    ### save model for this epoch
    if epoch % opt.save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))        
Beispiel #33
0
                t_data = iter_start_time - iter_data_time
            visualizer.reset()

            model.set_input(
                data)  # unpack data from dataset and apply preprocessing
            model.optimize_parameters(
            )  # calculate loss functions, get gradients, update network weights

            epoch_iter += 1
            total_iters += 1

            if total_iters % opt.display_freq == 0:  # display images on visdom and save images to a HTML file
                save_result = total_iters % opt.update_html_freq == 0
                model.compute_visuals()
                visualizer.display_current_results(
                    model.get_current_visuals(init_folder, epoch),
                    int(total_iters / opt.display_freq), opt.display_freq,
                    save_result)

#             if total_iters % opt.print_freq == 0:    # print training losses and save logging information to the disk
#                 losses = model.get_current_losses()
#                 t_comp = (time.time() - iter_start_time) / opt.batch_size
#                 visualizer.print_current_losses(epoch, total_iters, losses, t_comp, t_data)
#                 if opt.display_id > 0:
#                     visualizer.plot_current_losses(epoch, float(total_iters) / dataset_size, losses)

            if total_iters % opt.score_freq == 0:  # print generation scores and save logging information to the disk
                scores = model.get_current_scores()
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_scores(epoch, total_iters, scores)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(