Example #1
def train(configuration):
    """
    the main training loop
    :param configuration: the config file
    :return:
    """
    image_path = configuration['image_path']
    print('using images from {}'.format(image_path))

    image_saving_path_mean = configuration['image_saving_path_mean']
    print('saving result images with mean loss to {}'.format(image_saving_path_mean))

    image_saving_path_mean_std = configuration['image_saving_path_mean_std']
    print('saving result images with mean + std loss to {}'.format(image_saving_path_mean_std))

    image_saving_path_mean_std_skew = configuration['image_saving_path_mean_std_skew']
    print('saving result images with mean + std + skew loss to {}'.format(image_saving_path_mean_std_skew))

    image_saving_path_mean_std_skew_kurtosis = configuration['image_saving_path_mean_std_skew_kurtosis']
    print('saving result images with mean + std + skew + kurtosis loss to {}'.format(image_saving_path_mean_std_skew_kurtosis))

    model_dir = configuration['model_dir']
    print('vgg-19 model dir is {}'.format(model_dir))

    vgg_model = get_vgg_model(configuration)
    print('got vgg model')

    number_style_images, style_image_file_paths = get_images(configuration)
    print('got {} style images'.format(number_style_images))

    all_images = []

    for i in range(1, 6):
        print('using style loss module: {}'.format(i))
        style_loss_module = get_loss_module(i)
        for j in range(number_style_images):
            print('computing image {} with style loss module {}'.format(j, i))
            style_image = load_image(style_image_file_paths[j])
            layer_images = [style_image.squeeze(0)]
            model, style_losses = get_full_style_model(configuration, vgg_model, style_image, style_loss_module)
            for k in range(4):
                print('computing loss at layer {}, image {}, loss module {}'.format(k, j, i))
                torch.manual_seed(13)
                image_noise = torch.randn(style_image.data.size()).to(device)
                steps = i * (k+1) * 20000
                layer_images += [train_full_style_model(model, style_losses, image_noise, k, steps).squeeze(0)]
            print('saving images at the different layers')
            save_layer_images(configuration, layer_images, i, j)
            all_images += [layer_images]
    print('saving the side-by-side comparisons')
    save_all_images(configuration, all_images, number_style_images)
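
The four saving paths above indicate that loss modules 1-4 match an increasing number of statistical moments (mean, std, skewness, kurtosis) of the VGG feature maps. As an illustration only, a moment-matching style loss module along those lines could look like the sketch below; the class name, the normalization details, and the transparent-layer pattern (storing self.loss and passing the activations through) are assumptions, not the repository's actual implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F


class MomentStyleLoss(nn.Module):
    """Hypothetical moment-matching style loss (not the repository's code).

    Matches the first `order` per-channel statistics of the feature maps:
    order 1 = mean, 2 = +std, 3 = +skewness, 4 = +kurtosis, mirroring the
    four saving paths in train() above.
    """

    def __init__(self, target_feature, order=4):
        super().__init__()
        self.order = order
        self.target = self._moments(target_feature.detach())

    def _moments(self, feature):
        x = feature.flatten(2)                       # (batch, channels, h*w)
        mean = x.mean(dim=2)
        std = x.std(dim=2)
        stats = [mean]
        if self.order >= 2:
            stats.append(std)
        centred = x - mean.unsqueeze(2)
        if self.order >= 3:                          # skewness
            stats.append((centred ** 3).mean(dim=2) / (std ** 3 + 1e-8))
        if self.order >= 4:                          # kurtosis
            stats.append((centred ** 4).mean(dim=2) / (std ** 4 + 1e-8))
        return torch.cat(stats, dim=1)

    def forward(self, input):
        # transparent layer: record the loss and pass the activations through
        self.loss = F.mse_loss(self._moments(input), self.target)
        return input

A module of this shape would presumably be inserted after the chosen VGG layers by get_full_style_model, in the usual transparent-loss-layer fashion.
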
Example #2
def train_mmd(configuration):
    """
    training loop utilizing the MMD loss
    :param configuration: the config file
    :return:
    """
    image_path = configuration['image_path']
    print('using images from {}'.format(image_path))

    image_saving_path_mmd = configuration['image_saving_path_mmd']
    print('saving result images with mmd loss to {}'.format(image_saving_path_mmd))

    model_dir = configuration['model_dir']
    print('vgg-19 model dir is {}'.format(model_dir))

    vgg_model = get_vgg_model(configuration)
    print('got vgg model')

    number_style_images, style_image_file_paths = get_images(configuration)
    print('got {} style images'.format(number_style_images))

    i = 6

    print('using mmd loss module: {}'.format(i))
    style_loss_module = get_loss_module(i)
    for j in range(number_style_images):
        print('computing image {} with style loss module {}'.format(j, i))
        style_image = load_image(style_image_file_paths[j])
        layer_images = [style_image.squeeze(0)]
        model, style_losses = get_full_style_model(configuration, vgg_model, style_image, style_loss_module)
        for k in range(4):
            print('computing loss at layer {}, image {}, loss module {}'.format(k, j, i))
            torch.manual_seed(13)
            image_noise = torch.randn(style_image.data.size()).to(device)
            steps = (k+2) * 50000
            style_weight = 1000
            layer_images += [train_full_style_model(model, style_losses, image_noise, k,
                                                    steps, style_weight, early_stopping=True).squeeze(0)]
            # save the intermediate results after each layer, for debugging
            save_layer_images(configuration, layer_images, i, j)
        print('saving images at the different layers')
        save_layer_images(configuration, layer_images, i, j)
    print('finished')
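
For context, an MMD style loss with a quadratic polynomial kernel on the flattened VGG features is one common formulation, and minimizing it is (up to a constant) equivalent to matching Gram matrices. The sketch below is only an illustration of what get_loss_module(6) might return; the class name and the kernel choice are assumptions.

import torch.nn as nn


class MMDStyleLoss(nn.Module):
    """Hypothetical MMD style loss with a quadratic polynomial kernel."""

    def __init__(self, target_feature):
        super().__init__()
        # (channels, positions) view of the detached target features
        self.target = target_feature.detach().flatten(2).squeeze(0)

    @staticmethod
    def _kernel_mean(x, y):
        # mean of k(x_i, y_j) = (x_i . y_j)^2 over all position pairs,
        # computed via the (C x C) Gram matrices to avoid an N x N matrix
        n, m = x.shape[1], y.shape[1]
        return ((x @ x.t()) * (y @ y.t())).sum() / (n * m)

    def forward(self, input):
        f = input.flatten(2).squeeze(0)              # (channels, positions)
        t = self.target
        # biased estimate of MMD^2 between the two feature distributions
        self.loss = (self._kernel_mean(f, f)
                     - 2 * self._kernel_mean(f, t)
                     + self._kernel_mean(t, t))
        return input
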
Example #3
def train_gram(configuration):
    """
    training loop utilizing the Gram matrix loss
    :param configuration: the config file
    :return:
    """
    image_path = configuration['image_path']
    print('using images from {}'.format(image_path))

    image_saving_path_gram = configuration['image_saving_path_gram']
    print('saving result images with gram loss to {}'.format(image_saving_path_gram))

    model_dir = configuration['model_dir']
    print('vgg-19 model dir is {}'.format(model_dir))

    vgg_model = get_vgg_model(configuration)
    print('got vgg model')

    number_style_images, style_image_file_paths = get_images(configuration)
    print('got {} style images'.format(number_style_images))

    all_images = []
    i = 5

    print('using gram loss module: {}'.format(i))
    style_loss_module = get_loss_module(i)
    for j in range(number_style_images):
        print('computing image {} with style loss module {}'.format(j, i))
        style_image = load_image(style_image_file_paths[j])
        layer_images = [style_image.squeeze(0)]
        model, style_losses = get_full_style_model(configuration, vgg_model, style_image, style_loss_module)
        for k in range(4):
            print('computing loss at layer {}, image {}, loss module {}'.format(k, j, i))
            torch.manual_seed(13)
            image_noise = torch.randn(style_image.data.size()).to(device)
            steps = i * (k+1) * 20000
            layer_images += [train_full_style_model(model, style_losses, image_noise, k, steps).squeeze(0)]
        print('saving images at the different layers')
        save_layer_images(configuration, layer_images, i, j)
        all_images += [layer_images]
    print('finished')
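
Module 5 is labelled the Gram matrix loss, which presumably follows the standard Gatys-style recipe used in the PyTorch neural style transfer tutorial. A minimal sketch, with hypothetical names, is:

import torch.nn as nn
import torch.nn.functional as F


def gram_matrix(feature):
    # feature: (batch, channels, height, width)
    b, c, h, w = feature.size()
    flat = feature.view(b * c, h * w)
    # normalize so that deeper (larger) layers do not dominate the loss
    return flat @ flat.t() / (b * c * h * w)


class GramStyleLoss(nn.Module):
    def __init__(self, target_feature):
        super().__init__()
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        # transparent layer: record the loss and pass the activations through
        self.loss = F.mse_loss(gram_matrix(input), self.target)
        return input
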
Example #4
def train(configuration):
    """
    this is the main training loop
    :param configuration: the config
    :return:
    """
    image_saving_path = configuration['image_saving_path']
    print('saving result images to {}'.format(image_saving_path))

    model_dir = configuration['model_dir']
    print('vgg-19 model dir is {}'.format(model_dir))

    vgg_model = get_vgg_model(configuration)
    print('got vgg model')

    style_image_path = configuration['style_image_path']
    number_style_images, style_image_file_paths = get_images(style_image_path)
    print('got {} style images'.format(number_style_images))
    print('using the style images from path: {}'.format(style_image_path))

    content_image_path = configuration['content_image_path']
    number_content_images, content_image_file_paths = get_images(
        content_image_path)
    print('got {} content images'.format(number_content_images))
    print('using the content images from path: {}'.format(content_image_path))

    steps = configuration['steps']
    print('training for {} steps'.format(steps))

    content_weight = configuration['content_weight']
    style_weight = configuration['style_weight']
    print('content weight: {}, style weight: {}'.format(
        content_weight, style_weight))

    lr = configuration['lr']
    print('using a learning rate of {}'.format(lr))

    for i in range(number_style_images):
        print('style image {}'.format(i))
        for j in range(number_content_images):
            images = []
            print('content image {}'.format(j))
            style_image = load_image(style_image_file_paths[i])
            content_image = load_image(content_image_file_paths[j])

            images += [style_image.squeeze(0).cpu()]
            print('got style image')
            images += [content_image.squeeze(0).cpu()]
            print('got content image')

            for k in range(1, 6):
                print('training transfer image with loss {}'.format(k))

                loss_writer = LossWriter(
                    os.path.join(
                        configuration['folder_structure'].get_parent_folder(),
                        './loss/loss'))
                loss_writer.write_header(columns=[
                    'iteration', f'style_loss_{k}', f'content_loss_{k}',
                    f'loss_{k}'
                ])

                torch.manual_seed(1)
                image_noise = torch.randn(style_image.data.size()).to(device)
                model, style_losses, content_losses = get_full_style_model(
                    configuration, vgg_model, style_image, content_image,
                    get_style_loss_module(k), get_content_loss_module())

                # align the loss magnitude of the Gram matrix loss (k == 1) with
                # that of the moment losses; use a local copy so style_weight is
                # not scaled again for the remaining modules and image pairs
                current_style_weight = style_weight * 100 if k == 1 else style_weight

                img = train_neural_style_transfer(
                    model, lr, style_losses, content_losses, image_noise,
                    steps, current_style_weight, content_weight,
                    loss_writer).squeeze(0).cpu()

                images += [img.clone()]

                save_single_image(configuration, img, -k, -k)

                print('got transfer image')

            save_image(configuration, images, i, j)
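
For completeness, the configuration keys this loop reads can be gathered into one example dictionary. Only the key names come from the code above; the values are placeholders, and FolderStructure is a hypothetical stand-in for whatever object provides get_parent_folder().

class FolderStructure:
    """Hypothetical stand-in for the folder_structure helper used above."""

    def __init__(self, parent_folder):
        self.parent_folder = parent_folder

    def get_parent_folder(self):
        return self.parent_folder


example_configuration = {
    'image_saving_path': 'results/transfer',    # placeholder values; only the
    'model_dir': 'models/vgg19',                # key names come from train()
    'style_image_path': 'data/style',
    'content_image_path': 'data/content',
    'steps': 20000,
    'content_weight': 1,
    'style_weight': 1000,
    'lr': 0.01,
    'folder_structure': FolderStructure('experiments/run_01'),
}
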
Example #5
def train_mmd(configuration):
    """
    this is the MMD training loop
    :param configuration: the config
    :return:
    """
    image_saving_path = configuration['image_saving_path']
    print('saving result images to {}'.format(image_saving_path))

    model_dir = configuration['model_dir']
    print('vgg-19 model dir is {}'.format(model_dir))

    vgg_model = get_vgg_model(configuration)
    print('got vgg model')

    style_image_path = configuration['style_image_path']
    number_style_images, style_image_file_paths = get_images(style_image_path)
    print('got {} style images'.format(number_style_images))
    print('using the style images from path: {}'.format(style_image_path))

    content_image_path = configuration['content_image_path']
    number_content_images, content_image_file_paths = get_images(
        content_image_path)
    print('got {} content images'.format(number_content_images))
    print('using the content images from path: {}'.format(content_image_path))

    loss_writer = LossWriter(
        os.path.join(configuration['folder_structure'].get_parent_folder(),
                     './loss/loss'))
    loss_writer.write_header(
        columns=['iteration', 'style_loss', 'content_loss', 'loss'])

    # debug output: list the discovered style and content image paths
    print(style_image_file_paths)
    print(content_image_file_paths)

    images = []

    for i in range(number_style_images):
        print('style image {}'.format(i))
        for j in range(number_content_images):
            style_image = load_image(style_image_file_paths[i])
            content_image = load_image(content_image_file_paths[j])

            images += [style_image.squeeze(0).cpu()]
            print('got style image')
            images += [content_image.squeeze(0).cpu()]
            print('got content image')

            print('training transfer image with loss {} (MMD loss)'.format(2))
            torch.manual_seed(1)
            image_noise = torch.randn(style_image.data.size()).to(device)
            model, style_losses, content_losses = get_full_style_model(
                configuration, vgg_model, style_image, content_image,
                get_style_loss_module(2), get_content_loss_module())

            steps = configuration['steps']
            print('training for {} steps'.format(steps))

            content_weight = configuration['content_weight']
            style_weight = configuration['style_weight']
            print('content weight: {}, style weight: {}'.format(
                content_weight, style_weight))

            lr = configuration['lr']
            print('learning rate: {}'.format(lr))

            img = train_neural_style_transfer(model, lr, style_losses,
                                              content_losses, image_noise,
                                              steps, style_weight,
                                              content_weight,
                                              loss_writer).squeeze(0).cpu()

            save_image(configuration, img, j, i)

            print('got transfer image')
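
The only LossWriter calls visible in these loops are the constructor and write_header. A minimal CSV-backed implementation consistent with that usage might look like this; the write_row method is an assumption about how train_neural_style_transfer appends entries.

import csv
import os


class LossWriter:
    """Hypothetical minimal CSV-backed loss logger matching the calls above."""

    def __init__(self, path_prefix):
        # the loops above pass something like <parent_folder>/./loss/loss
        directory = os.path.dirname(path_prefix)
        if directory:
            os.makedirs(directory, exist_ok=True)
        self.path = path_prefix + '.csv'

    def write_header(self, columns):
        with open(self.path, 'w', newline='') as f:
            csv.writer(f).writerow(columns)

    def write_row(self, values):
        # assumed to be called from inside train_neural_style_transfer
        with open(self.path, 'a', newline='') as f:
            csv.writer(f).writerow(values)
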