Example #1
# `prefix`, `suffix`, `local`, `net`, `torch`, and `device` are assumed to be
# defined earlier in the script
model_path_list = [
    prefix + 'moment_alignment_model_0_1_2_0' + suffix,
    prefix + 'moment_alignment_model_0_1_3_0' + suffix,
    prefix + 'moment_alignment_model_0_1_4_0' + suffix,
    prefix + 'moment_alignment_model_0_1_5_0' + suffix,
]

model_configuration = {
    'do_incremental_training': False,
    'use_very_shallow_model_list': [False] * len(model_path_list),
    'load_model': False
}

model_list = []
for i, model_path in enumerate(model_path_list):
    moment_alignment_model = net.get_moment_alignment_model(
        model_configuration, moment_mode=i + 2, use_list=True, list_index=i)
    checkpoint = torch.load(model_path, map_location='cpu')
    moment_alignment_model.load_state_dict(checkpoint)
    model_list.append(moment_alignment_model.to(device))

prefix = '/Users/Johannes/Desktop/encoder_decoder_exp/' if local else prefix
suffix = '.pth'
encoder_decoder_model_paths = {
    'encoder_model_path': prefix + 'encoder_1_25_6_state_dict' + suffix,
    'decoder_model_path': prefix + 'decoder_1_25_6_state_dict' + suffix,
}

encoder = net.get_trained_encoder(encoder_decoder_model_paths)
decoder = net.get_trained_decoder(encoder_decoder_model_paths)

encoder.to(device)
decoder.to(device)
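A minimal sketch of what a loader like net.get_trained_encoder plausibly does, assuming it instantiates the encoder network and restores the saved state dict (the constructor name net.Encoder is an assumption; the real function lives in the project's net module):

def get_trained_encoder_sketch(paths):
    """Instantiate the encoder and restore its saved weights.

    Hypothetical stand-in for net.get_trained_encoder.
    """
    encoder = net.Encoder()  # assumed constructor name
    state_dict = torch.load(paths['encoder_model_path'], map_location='cpu')
    encoder.load_state_dict(state_dict)
    encoder.eval()  # the encoder is only used for inference here
    return encoder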
def test(configuration):
    """
    Test the learned moment alignment model (two moments) against the
    analytical AdaIN solution, which can be computed in closed form for
    mean and std.
    :param configuration: the config dictionary
    :return:
    """
    encoder = net.get_trained_encoder(configuration)
    decoder = net.get_trained_decoder(configuration)

    analytical_ada_in_module = net.AdaptiveInstanceNormalization()

    pretrained_model_path = configuration['pretrained_model_path']
    content_images_path = configuration['content_images_path']
    style_images_path = configuration['style_images_path']
    loader = configuration['loader']
    unloader = configuration['unloader']
    image_test_saving_path = configuration['image_saving_path']
    moment_mode = configuration['moment_mode']

    print('loading the moment alignment model from {}'.format(
        pretrained_model_path))
    moment_alignment_model = net.get_moment_alignment_model(
        configuration, moment_mode)
    print(moment_alignment_model)

    checkpoint = torch.load(pretrained_model_path, map_location=device)
    moment_alignment_model.load_state_dict(checkpoint)

    # list each directory only once; repeated os.listdir calls are not
    # guaranteed to return the same ordering
    content_image_files = [
        '{}/{}'.format(content_images_path, file_name)
        for file_name in os.listdir(content_images_path)
    ]
    style_image_files = [
        '{}/{}'.format(style_images_path, file_name)
        for file_name in os.listdir(style_images_path)
    ]
    number_content_images = len(content_image_files)
    number_style_images = len(style_image_files)

    for i in range(number_style_images):
        print("style image {} at {}".format(i + 1, style_image_files[i]))

    for i in range(number_content_images):
        print("content image {} at {}".format(i + 1, content_image_files[i]))

    iterations = 0
    # one accumulator per distance value returned by get_distance
    mean_percentages = [0.0] * 7

    for j in range(number_content_images):
        for i in range(number_style_images):
            style_image = data_loader.image_loader(style_image_files[i],
                                                   loader)
            content_image = data_loader.image_loader(content_image_files[j],
                                                     loader)
            with torch.no_grad():
                content_feature_maps = encoder(content_image)['r41']
                style_feature_maps = encoder(style_image)['r41']

                content_feature_map_batch_loader = data_loader.get_batch(
                    content_feature_maps, 512)
                style_feature_map_batch_loader = data_loader.get_batch(
                    style_feature_maps, 512)

                content_feature_map_batch = next(
                    content_feature_map_batch_loader).to(device)
                style_feature_map_batch = next(
                    style_feature_map_batch_loader).to(device)

                style_feature_map_batch_moments = u.compute_moments_batches(
                    style_feature_map_batch)
                content_feature_map_batch_moments = u.compute_moments_batches(
                    content_feature_map_batch)

                out = moment_alignment_model(
                    content_feature_map_batch,
                    content_feature_map_batch_moments,
                    style_feature_map_batch_moments)

                result_feature_maps = out

                analytical_feature_maps = analytical_ada_in_module(
                    content_feature_map_batch, style_feature_map_batch)

                distances = get_distance(analytical_feature_maps,
                                         result_feature_maps)
                iterations += 1

                # accumulate all seven distance values returned by get_distance
                for k, distance in enumerate(distances):
                    mean_percentages[k] += distance

                # u.imshow(decoder(analytical_feature_maps.view(1, 512, 32, 32)), transforms.ToPILImage())

                utils.save_image([
                    data_loader.imnorm(content_image, unloader),
                    data_loader.imnorm(style_image, unloader),
                    data_loader.imnorm(
                        decoder(result_feature_maps.view(1, 512, 32, 32)),
                        None),
                    data_loader.imnorm(
                        decoder(analytical_feature_maps.view(1, 512, 32, 32)),
                        None)
                ],
                                 '{}/A_image_{}_{}.jpeg'.format(
                                     image_test_saving_path, i, j),
                                 normalize=False,
                                 pad_value=1)

                utils.save_image([
                    data_loader.imnorm(content_image, unloader),
                    data_loader.imnorm(
                        decoder(result_feature_maps.view(1, 512, 32, 32)),
                        None),
                    data_loader.imnorm(
                        decoder(analytical_feature_maps.view(1, 512, 32, 32)),
                        None)
                ],
                                 '{}/B_image_{}_{}.jpeg'.format(
                                     image_test_saving_path, i, j),
                                 normalize=False,
                                 pad_value=1)

                utils.save_image([
                    data_loader.imnorm(style_image, unloader),
                    data_loader.imnorm(
                        decoder(result_feature_maps.view(1, 512, 32, 32)),
                        None),
                    data_loader.imnorm(
                        decoder(analytical_feature_maps.view(1, 512, 32, 32)),
                        None)
                ],
                                 '{}/C_image_{}_{}.jpeg'.format(
                                     image_test_saving_path, i, j),
                                 normalize=False,
                                 pad_value=1)

                utils.save_image([
                    data_loader.imnorm(
                        decoder(result_feature_maps.view(1, 512, 32, 32)),
                        None),
                    data_loader.imnorm(
                        decoder(analytical_feature_maps.view(1, 512, 32, 32)),
                        None)
                ],
                                 '{}/D_image_{}_{}.jpeg'.format(
                                     image_test_saving_path, i, j),
                                 normalize=False,
                                 pad_value=1)

    print('averaging percentages')
    mean_percentages = [p / iterations for p in mean_percentages]
    print(mean_percentages)
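The analytical baseline used above is standard adaptive instance normalization (Huang & Belongie, 2017): content features are whitened with their own per-channel mean and std, then recolored with the style statistics. A minimal sketch of what a module like net.AdaptiveInstanceNormalization computes, assuming per-channel statistics over the spatial dimensions (the real module may differ in shape handling and epsilon):

import torch
import torch.nn as nn


class AdaINSketch(nn.Module):
    """Hypothetical stand-in for net.AdaptiveInstanceNormalization."""

    def forward(self, content, style, eps=1e-5):
        # statistics over the spatial dimensions, one pair per channel
        c_mean = content.mean(dim=(-2, -1), keepdim=True)
        c_std = content.std(dim=(-2, -1), keepdim=True) + eps
        s_mean = style.mean(dim=(-2, -1), keepdim=True)
        s_std = style.std(dim=(-2, -1), keepdim=True)
        # whiten with content statistics, recolor with style statistics
        return s_std * (content - c_mean) / c_std + s_mean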
Example #3
def test(configuration):
    """
    Test loop that produces comparison images with several moment
    alignment models at once.
    :param configuration: the config dictionary
    :return:
    """
    encoder = net.get_trained_encoder(configuration)
    decoder = net.get_trained_decoder(configuration)

    model_path_list = configuration['model_path_list']
    content_images_path = configuration['content_images_path']
    style_images_path = configuration['style_images_path']
    loader = configuration['loader']
    image_saving_path = configuration['image_saving_path']
    mode_list = configuration['mode_list']

    model_list = []
    for i, model_path in enumerate(model_path_list):
        moment_alignment_model = net.get_moment_alignment_model(
            configuration,
            moment_mode=mode_list[i],
            use_list=True,
            list_index=i)
        checkpoint = torch.load(model_path, map_location='cpu')
        moment_alignment_model.load_state_dict(checkpoint)
        model_list.append(moment_alignment_model.to(device))

    # list each directory only once so the file ordering stays consistent
    content_image_files = [
        '{}/{}'.format(content_images_path, file_name)
        for file_name in sorted(os.listdir(content_images_path))
    ]
    style_image_files = [
        '{}/{}'.format(style_images_path, file_name)
        for file_name in sorted(os.listdir(style_images_path))
    ]
    number_content_images = len(content_image_files)
    number_style_images = len(style_image_files)

    for i in range(number_content_images):
        for j in range(number_style_images):
            print('at content image {}, style image {}'.format(i, j))
            with torch.no_grad():
                content_image = data_loader.image_loader(
                    content_image_files[i], loader)
                style_image = data_loader.image_loader(style_image_files[j],
                                                       loader)

                # encode once; the encoder output is the same for every model
                content_feature_maps = encoder(content_image)['r41']
                style_feature_maps = encoder(style_image)['r41']

                result_images = []
                for model in model_list:
                    content_feature_map_batch = next(
                        data_loader.get_batch(content_feature_maps,
                                              512)).to(device)
                    style_feature_map_batch = next(
                        data_loader.get_batch(style_feature_maps,
                                              512)).to(device)

                    style_feature_map_batch_moments = utils.compute_moments_batches(
                        style_feature_map_batch, last_moment=7)
                    content_feature_map_batch_moments = utils.compute_moments_batches(
                        content_feature_map_batch, last_moment=7)

                    result_image = decoder(model(
                        content_feature_map_batch,
                        content_feature_map_batch_moments,
                        style_feature_map_batch_moments).view(1, 512, 32, 32))
                    result_images.append(result_image.squeeze(0))

                # save all images in one row
                u.save_image(
                    [
                        data_loader.imnorm(content_image, None),
                        data_loader.imnorm(style_image, None)
                    ] + result_images,
                    '{}/moment_alignment_test_image_A_{}_{}.jpeg'.format(
                        image_saving_path, i, j),
                    normalize=False,
                    scale_each=False,
                    pad_value=1)

                # save all images in two rows
                u.save_image(
                    [
                        data_loader.imnorm(content_image, None),
                        data_loader.imnorm(style_image, None)
                    ] + [torch.ones(3, 256, 256)
                         for _ in range(2)] + result_images,
                    '{}/moment_alignment_test_image_B_{}_{}.jpeg'.format(
                        image_saving_path, i, j),
                    normalize=False,
                    scale_each=False,
                    pad_value=1,
                    nrow=4)

                # save all result images in one row
                u.save_image(
                    result_images,
                    '{}/moment_alignment_test_image_C_{}_{}.jpeg'.format(
                        image_saving_path, i, j),
                    normalize=False,
                    scale_each=False,
                    pad_value=1)

                # save all result images in one row + content image
                u.save_image(
                    [data_loader.imnorm(content_image, None)] + result_images,
                    '{}/moment_alignment_test_image_D_{}_{}.jpeg'.format(
                        image_saving_path, i, j),
                    normalize=False,
                    scale_each=False,
                    pad_value=1)

                # save all result images in one row + style image
                u.save_image(
                    [data_loader.imnorm(style_image, None)] + result_images,
                    '{}/moment_alignment_test_image_E_{}_{}.jpeg'.format(
                        image_saving_path, i, j),
                    normalize=False,
                    scale_each=False,
                    pad_value=1)
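The loops above rely on data_loader.get_batch to turn a (1, C, H, W) feature-map tensor into batches the alignment model can consume. A minimal sketch under the assumption that it yields successive channel chunks (hypothetical; the project's real helper may batch differently):

def get_batch_sketch(feature_maps, batch_size):
    """Yield successive channel chunks of a (1, C, H, W) tensor.

    Hypothetical stand-in for data_loader.get_batch.
    """
    channels = feature_maps.squeeze(0)  # (C, H, W)
    for start in range(0, channels.size(0), batch_size):
        yield channels[start:start + batch_size]

With batch_size=512 and VGG r41 features (512 channels of 32x32 for 256x256 inputs), a single next() then returns every channel at once, which is why the model output can be reshaped with .view(1, 512, 32, 32) before the decoder.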
Example #4
def test(configuration):
    """
    Test loop: run the learned moment alignment model (or the analytical
    AdaIN baseline) and save the stylized outputs.
    :param configuration: the config dictionary
    :return:
    """
    analytical_ada_in_module = net.AdaptiveInstanceNormalization()
    encoder = net.get_trained_encoder(configuration)
    decoder = net.get_trained_decoder(configuration)

    pretrained_model_path = configuration['pretrained_model_path']
    print('loading the moment alignment model from {}'.format(pretrained_model_path))

    content_images_path = configuration['content_images_path']
    style_images_path = configuration['style_images_path']

    loader = configuration['loader']
    unloader = configuration['unloader']

    image_saving_path = configuration['image_saving_path']
    moment_mode = configuration['moment_mode']
    use_MA_module = configuration['use_MA_module']

    moment_alignment_model = net.get_moment_alignment_model(configuration, moment_mode)
    print(moment_alignment_model)

    checkpoint = torch.load(pretrained_model_path, map_location=device)
    moment_alignment_model.load_state_dict(checkpoint)

    aligned_moment_loss = net.get_loss(configuration, moment_mode=moment_mode, lambda_1=0, lambda_2=10)

    content_image_files = ['{}/{}'.format(content_images_path, file_name)
                           for file_name in os.listdir(content_images_path)]
    style_image_files = ['{}/{}'.format(style_images_path, file_name)
                         for file_name in os.listdir(style_images_path)]

    number_content_images = len(content_image_files)
    number_style_images = len(style_image_files)

    for i in range(number_style_images):
        print("style image {} at {}".format(i + 1, style_image_files[i]))

    for i in range(number_content_images):
        print("content image {} at {}".format(i + 1, content_image_files[i]))

    for j in range(number_content_images):
        for i in range(number_style_images):
            style_image = data_loader.image_loader(style_image_files[i], loader)
            content_image = data_loader.image_loader(content_image_files[j], loader)
            with torch.no_grad():
                content_feature_maps = encoder(content_image)['r41']
                style_feature_maps = encoder(style_image)['r41']

                content_feature_map_batch_loader = data_loader.get_batch(content_feature_maps, 512)
                style_feature_map_batch_loader = data_loader.get_batch(style_feature_maps, 512)

                content_feature_map_batch = next(content_feature_map_batch_loader).to(device)
                style_feature_map_batch = next(style_feature_map_batch_loader).to(device)

                if use_MA_module:
                    style_feature_map_batch_moments = u.compute_moments_batches(
                        style_feature_map_batch, last_moment=7)
                    content_feature_map_batch_moments = u.compute_moments_batches(
                        content_feature_map_batch, last_moment=7)

                    out = moment_alignment_model(content_feature_map_batch,
                                                 content_feature_map_batch_moments,
                                                 style_feature_map_batch_moments,
                                                 is_test=True)

                    out_feature_map_batch_moments = u.compute_moments_batches(out, last_moment=7)

                    print_some_moments(style_feature_map_batch_moments,
                                       content_feature_map_batch_moments,
                                       out_feature_map_batch_moments)

                    loss, moment_loss, reconstruction_loss = aligned_moment_loss(content_feature_map_batch,
                                                                                 style_feature_map_batch,
                                                                                 content_feature_map_batch_moments,
                                                                                 style_feature_map_batch_moments,
                                                                                 out,
                                                                                 is_test=True)

                    print('loss: {}, moment_loss: {}, reconstruction_loss:{}'.format(
                        loss.item(), moment_loss.item(), reconstruction_loss.item()))
                else:
                    analytical_feature_maps = analytical_ada_in_module(content_feature_map_batch, style_feature_map_batch)
                    out = analytical_feature_maps

                utils.save_image([data_loader.imnorm(content_image, unloader),
                                  data_loader.imnorm(style_image, unloader),
                                  data_loader.imnorm(decoder(out.view(1, 512, 32, 32)), None)],
                                 '{}/A_image_{}_{}.jpeg'.format(image_saving_path, i, j),
                                 normalize=False)

                utils.save_image([data_loader.imnorm(decoder(out.view(1, 512, 32, 32)), None)],
                                 '{}/B_image_{}_{}.jpeg'.format(image_saving_path, i, j),
                                 normalize=False)
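u.compute_moments_batches(..., last_moment=7) supplies the per-channel moment vectors that the alignment model and the loss condition on. A minimal sketch under the assumption that it returns mean, std, and standardized central moments up to last_moment (the project's actual helper may use a different layout or normalization):

import torch


def compute_moments_batches_sketch(batch, last_moment=7, eps=1e-5):
    """Per-channel moments of a (C, H, W) batch: mean, std, then
    standardized central moments 3..last_moment (skewness, kurtosis, ...).

    Hypothetical stand-in for the utils helper used above.
    """
    flat = batch.view(batch.size(0), -1)  # (C, H*W)
    mean = flat.mean(dim=1)
    std = flat.std(dim=1) + eps
    moments = [mean, std]
    standardized = (flat - mean.unsqueeze(1)) / std.unsqueeze(1)
    for k in range(3, last_moment + 1):
        moments.append((standardized ** k).mean(dim=1))
    return torch.stack(moments, dim=1)  # (C, last_moment)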
Example #5
def train(configuration):
    """
    The main training loop for the moment alignment model.
    :param configuration: the config dictionary
    :return:
    """
    epochs = configuration['epochs']
    print('going to train for {} epochs'.format(epochs))

    step_printing_interval = configuration['step_printing_interval']
    print('writing to console every {} steps'.format(step_printing_interval))

    image_saving_interval = configuration['image_saving_interval']
    print('saving images every {} steps'.format(image_saving_interval))

    epoch_saving_interval = configuration['epoch_saving_interval']
    print('saving the model every {} epochs'.format(epoch_saving_interval))

    validation_interval = configuration['validation_interval']
    print('validating the model every {} iterations'.format(validation_interval))

    image_saving_path = configuration['image_saving_path']
    print('saving images to {}'.format(image_saving_path))

    loader = configuration['loader']

    model_saving_path = configuration['model_saving_path']
    print('saving models to {}'.format(model_saving_path))

    # tensorboardX_path = configuration['tensorboardX_path']
    # writer = SummaryWriter(logdir='{}/runs'.format(tensorboardX_path))
    # print('saving tensorboardX logs to {}'.format(tensorboardX_path))

    loss_writer = LossWriter(os.path.join(
        configuration['folder_structure'].get_parent_folder(), 'loss/loss'),
                             buffer_size=100)
    loss_writer.write_header(columns=[
        'epoch', 'all_training_iteration', 'loss', 'moment_loss',
        'reconstruction_loss'
    ])

    # validation losses go to a separate file so they do not overwrite the
    # training losses
    validation_loss_writer = LossWriter(
        os.path.join(configuration['folder_structure'].get_parent_folder(),
                     'loss/validation_loss'))
    validation_loss_writer.write_header(columns=[
        'validation_iteration', 'loss', 'moment_loss', 'reconstruction_loss'
    ])

    # batch_size is the number of images to sample
    batch_size = 1
    feature_map_batch_size = int(configuration['feature_map_batch_size'])
    print('training in batches of {} feature maps'.format(
        feature_map_batch_size))

    coco_data_path_train = configuration['coco_data_path_train']
    painter_by_numbers_data_path_train = configuration[
        'painter_by_numbers_data_path_train']
    print('using {} and {} for training'.format(
        coco_data_path_train, painter_by_numbers_data_path_train))

    coco_data_path_val = configuration['coco_data_path_val']
    painter_by_numbers_data_path_val = configuration[
        'painter_by_numbers_data_path_val']
    print('using {} and {} for validation'.format(
        coco_data_path_val, painter_by_numbers_data_path_val))

    train_dataloader = data_loader.get_concat_dataloader(
        coco_data_path_train,
        painter_by_numbers_data_path_train,
        batch_size,
        loader=loader)
    print('got train dataloader')

    val_dataloader = data_loader.get_concat_dataloader(
        coco_data_path_val,
        painter_by_numbers_data_path_val,
        batch_size,
        loader=loader)
    print('got val dataloader')

    lambda_1 = configuration['lambda_1']
    print('lambda 1: {}'.format(lambda_1))

    lambda_2 = configuration['lambda_2']
    print('lambda 2: {}'.format(lambda_2))

    loss_moment_mode = configuration['moment_mode']
    net_moment_mode = configuration['moment_mode']
    print('loss is sum of the first {} moments'.format(loss_moment_mode))
    print('net accepts {} in-channels'.format(net_moment_mode))

    unloader = configuration['unloader']
    print('got the unloader')

    do_validation = configuration['do_validation']
    print('doing validation: {}'.format(do_validation))

    moment_alignment_model = net.get_moment_alignment_model(
        configuration, moment_mode=loss_moment_mode)
    print('got model')
    print(moment_alignment_model)

    decoder = net.get_trained_decoder(configuration)
    print('got decoder')
    print(decoder)

    print('params that require grad')
    for name, param in moment_alignment_model.named_parameters():
        if param.requires_grad:
            print(name)

    criterion = net.get_loss(configuration,
                             moment_mode=loss_moment_mode,
                             lambda_1=lambda_1,
                             lambda_2=lambda_2)

    print('got moment loss module')
    print(criterion)

    encoder = net.get_trained_encoder(configuration)
    print('got encoder')
    print(encoder)

    try:
        optimizer = optim.Adam(moment_alignment_model.parameters(),
                               lr=configuration['lr'])
    except AttributeError:
        # the model may be wrapped (e.g. nn.DataParallel); reach into .module
        optimizer = optim.Adam(moment_alignment_model.module.parameters(),
                               lr=configuration['lr'])
    print('got optimizer')
    schedule = QuadraticSchedule(timesteps=10000000,
                                 initial=configuration['lr'],
                                 final=configuration['lr'] / 10.)
    print('got schedule')

    print('making iterable from train dataloader')
    train_data_loader = iter(train_dataloader)
    print('train data loader iterable')
    outer_training_iteration = -1
    all_training_iteration = 0

    number_of_validation = 0
    current_validation_loss = float('inf')

    for epoch in range(1, epochs + 1):
        print('epoch: {}'.format(epoch))

        # this is the outer training loop (sampling images)
        print('training model ...')
        while True:
            try:
                data = next(train_data_loader)
                outer_training_iteration += 1
            except StopIteration:
                print('got to the end of the dataloader (StopIteration)')
                train_data_loader = iter(train_dataloader)
                break
            except Exception:
                print('something went wrong with the dataloader, continuing')
                continue

            if do_validation:
                # validate the model every validation_interval iterations
                if outer_training_iteration % validation_interval == 0:
                    print('making iterable from val dataloader')
                    val_data_loader = iter(val_dataloader)
                    print('val data loader iterable')
                    validation_loss = validate(number_of_validation, criterion,
                                               encoder, moment_alignment_model,
                                               val_data_loader,
                                               feature_map_batch_size,
                                               validation_loss_writer)
                    number_of_validation += 1
                    if validation_loss < current_validation_loss:
                        utils.save_current_best_model(
                            epoch, moment_alignment_model,
                            configuration['model_saving_path'])
                        print('got a better model')
                        current_validation_loss = validation_loss
                        print('set the new validation loss to the current one')
                    else:
                        print('this model is actually worse than the best one')

            # get the content_image batch
            content_image = data.get('coco').get('image')
            content_image = content_image.to(device)

            # get the style_image batch
            style_image = data.get('painter_by_numbers').get('image')
            style_image = style_image.to(device)

            style_feature_maps = encoder(style_image)['r41'].to(device)
            content_feature_maps = encoder(content_image)['r41'].to(device)

            # seed with a single dummy channel; processed feature maps are
            # concatenated onto it and the dummy is sliced off below
            result_feature_maps = torch.zeros(1, 1, 32, 32)

            content_feature_map_batch_loader = data_loader.get_batch(
                content_feature_maps, feature_map_batch_size)
            style_feature_map_batch_loader = data_loader.get_batch(
                style_feature_maps, feature_map_batch_size)

            # this is the inner training loop (feature maps)
            while True:
                try:
                    content_feature_map_batch = next(
                        content_feature_map_batch_loader).to(device)
                    style_feature_map_batch = next(
                        style_feature_map_batch_loader).to(device)
                    all_training_iteration += 1
                except StopIteration:
                    break
                except Exception:
                    continue

                do_print = all_training_iteration % step_printing_interval == 0

                optimizer.zero_grad()

                style_feature_map_batch_moments = utils.compute_moments_batches(
                    style_feature_map_batch, last_moment=net_moment_mode)
                content_feature_map_batch_moments = utils.compute_moments_batches(
                    content_feature_map_batch, last_moment=net_moment_mode)

                out = moment_alignment_model(
                    content_feature_map_batch,
                    content_feature_map_batch_moments,
                    style_feature_map_batch_moments)

                loss, moment_loss, reconstruction_loss = criterion(
                    content_feature_map_batch,
                    style_feature_map_batch,
                    content_feature_map_batch_moments,
                    style_feature_map_batch_moments,
                    out,
                    last_moment=loss_moment_mode)

                if do_print:
                    # the lr schedule update shares the printing cadence
                    set_lr(optimizer,
                           lr=schedule.get(all_training_iteration /
                                           step_printing_interval))
                    loss_writer.write_row([
                        epoch, all_training_iteration,
                        loss.item(),
                        moment_loss.item(),
                        reconstruction_loss.item()
                    ])

                # backprop
                loss.backward()
                optimizer.step()

                result_feature_maps = torch.cat(
                    [result_feature_maps,
                     out.cpu().view(1, -1, 32, 32)], 1)

                # if do_print:
                #     print('loss: {:4f}'.format(loss.item()))
                #
                #     writer.add_scalar('data/training_loss', loss.item(), all_training_iteration)
                #     writer.add_scalar('data/training_moment_loss', moment_loss.item(), all_training_iteration)
                #     writer.add_scalar('data/training_reconstruction_loss', reconstruction_loss.item(),
                #                       all_training_iteration)

            # drop the dummy seed channel, keeping the 512 processed channels
            result_feature_maps = result_feature_maps[:, 1:513, :, :]
            result_img = decoder(result_feature_maps.to(device))

            if outer_training_iteration % image_saving_interval == 0:
                u.save_image([
                    data_loader.imnorm(content_image, unloader),
                    data_loader.imnorm(style_image, unloader),
                    data_loader.imnorm(result_img, None)
                ],
                             '{}/image_{}_{}__{}_{}.jpeg'.format(
                                 image_saving_path, epoch,
                                 outer_training_iteration // epoch, lambda_1,
                                 lambda_2),
                             normalize=False)

            # save the current model every epoch_saving_interval iterations
            if outer_training_iteration % epoch_saving_interval == 0:
                utils.save_current_model(lambda_1, lambda_2,
                                         moment_alignment_model.state_dict(),
                                         optimizer.state_dict(),
                                         configuration['model_saving_path'])
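Two helpers used above are not shown in this file. First, the learning rate is annealed through QuadraticSchedule and applied with set_lr; a minimal sketch consistent with the constructor call above (timesteps, initial, final), assuming a quadratic decay from initial to final:

class QuadraticSchedule:
    """Quadratic interpolation from `initial` down to `final` over `timesteps`.

    Hypothetical stand-in matching the constructor used in train().
    """

    def __init__(self, timesteps, initial, final):
        self.timesteps = timesteps
        self.initial = initial
        self.final = final

    def get(self, t):
        # clamp progress to [0, 1] so the lr never leaves [final, initial]
        progress = min(max(t / self.timesteps, 0.0), 1.0)
        return self.final + (self.initial - self.final) * (1.0 - progress) ** 2


def set_lr(optimizer, lr):
    """Set the learning rate on every parameter group of the optimizer."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

Second, LossWriter is a project-local logger; a minimal CSV-style sketch consistent with the write_header/write_row calls and the buffer_size argument used in train() (hypothetical; the real class may buffer and format differently):

class LossWriterSketch:
    """Buffered CSV logger; hypothetical stand-in for LossWriter."""

    def __init__(self, path, buffer_size=1):
        self.path = path + '.csv'
        self.buffer_size = buffer_size
        self.rows = []

    def write_header(self, columns):
        with open(self.path, 'w') as f:
            f.write(','.join(columns) + '\n')

    def write_row(self, values):
        self.rows.append(','.join(str(v) for v in values))
        if len(self.rows) >= self.buffer_size:
            with open(self.path, 'a') as f:
                f.write('\n'.join(self.rows) + '\n')
            self.rows = []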