Exemplo n.º 1
def prepare_dataset(input_dir, output_dir, size, images_per_file, file_prefix):
    """Convert the images found under ``input_dir`` into sharded TFRecord files.

    Image paths are collected with ``extract_image_names_recursive``; each
    image is loaded with ``load_image``, passed through ``prepare_image``
    (``normalize=False``), wrapped by ``build_example`` and appended to the
    current shard. A new shard named ``<file_prefix>-NNNN.tfrecords`` is
    opened in ``output_dir`` every ``images_per_file`` images.

    Args:
        input_dir: Existing directory that is searched for images.
        output_dir: Existing directory that receives the ``.tfrecords`` files.
        size: Forwarded unchanged to ``load_image``.
        images_per_file: Number of input images assigned to each shard.
        file_prefix: Prefix used for every shard file name.

    Raises:
        AssertionError: If ``input_dir`` or ``output_dir`` does not exist or
            is not a directory.

    Images that raise ``OSError``, ``OverflowError`` or ``ValueError`` while
    being processed are skipped and counted; the count is reported on stdout.
    """
    assert os.path.exists(input_dir), 'Input directory does not exist'
    assert os.path.isdir(input_dir), '%s is not a directory' % input_dir
    assert os.path.exists(output_dir), 'Output directory does not exist'
    assert os.path.isdir(output_dir), '%s is not a directory' % output_dir

    filenames = extract_image_names_recursive(input_dir)
    print("%s images found in %s" % (len(filenames), input_dir))

    start = time.time()
    errors = 0
    rate = 0
    update_stat_every = 100
    writer = None
    try:
        for i, filename in enumerate(filenames):
            if i % images_per_file == 0:  # roll to a new file
                if writer:
                    # Finish the previous shard before opening the next one.
                    writer.close()
                output_file = '%s-%04i.tfrecords' % (file_prefix,
                                                     i // images_per_file)
                output_path = os.path.join(output_dir, output_file)
                writer = tf.python_io.TFRecordWriter(output_path)
            try:
                image = load_image(filename, size)
                image = prepare_image(image, normalize=False)
                example = build_example(image)
                # NOTE(review): assumes build_example returns a protobuf
                # message (e.g. tf.train.Example) -- confirm against its
                # definition.
                writer.write(example.SerializeToString())
            except (OSError, OverflowError, ValueError):
                # Best effort: skip unreadable/invalid images but count them.
                errors += 1
            print(i, '\t', '%0.4f image/sec, %s errors' % (rate, errors), end='\r')
            if i % update_stat_every == 0:
                elapsed = time.time() - start
                # A coarse clock can report zero elapsed time on the first
                # iteration; keep the previous rate instead of dividing by 0.
                if elapsed > 0:
                    rate = i / elapsed
    finally:
        # Always release the last shard, even if an unexpected error escapes.
        if writer:
            writer.close()

    print('%s images processed at %0.4f image/sec. %s errors occurred.' %
          (len(filenames), rate, errors))
Exemplo n.º 2
def style_transfer(
    assert bool(content) != bool(content_dir), 'Either content or content_dir should be given'
    assert bool(style) != bool(style_dir), 'Either style or style_dir should be given'

    if not os.path.exists(output_dir):
        print('Creating output dir at', output_dir)

    # Assume that it is either an h5 file or a name of a TensorFlow checkpoint
    # NOTE: For now, artificially switching off pretrained h5 weights
    decoder_in_h5 = False
    # decoder_weights.endswith('.h5')

    if content:
        content_batch = [content]
        assert mask is None, 'For spatial control use the --content option'
        content_batch = extract_image_names_recursive(content_dir)

    if style:
        style = style.split(',')
        if mask:
            assert len(style) == 2, 'For spatial control provide two style images'
            style_batch = [style]
        elif len(style) > 1: # Style blending
            if not style_interp_weights:
                # by default, all styles get equal weights
                style_interp_weights = np.array([1.0/len(style)] * len(style))
                # normalize weights so that their sum equals to one
                style_interp_weights = [float(w) for w in style_interp_weights.split(',')]
                style_interp_weights = np.array(style_interp_weights)
                style_interp_weights /= np.sum(style_interp_weights)
                assert len(style) == len(style_interp_weights), """--style and --style_interp_weights must have the same number of elements"""
            style_batch = [style]
            style_batch = style
        assert mask is None, 'For spatial control use the --style option'
        style_batch = extract_image_names_recursive(style_dir)

    print('Number of content images:', len(content_batch))
    print('Number of style images:', len(style_batch))

    if gpu >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
        data_format = 'channels_first'
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
        data_format = 'channels_last'

    image, content, style, target, encoder, decoder, otherLayers = _build_graph(vgg_weights,
        decoder_weights if decoder_in_h5 else None, alpha, data_format=data_format)

    with tf.Session() as sess:
        if decoder_in_h5:
        elif tf_checkpoint_dir is not None: # Some checkpoint was provided
            saver = tf.train.Saver()
            saver.restore(sess, os.path.join(tf_checkpoint_dir, 'adain-final'))

        for content_path, style_path in product(content_batch, style_batch):
            content_name = get_filename(content_path)
            content_image = load_image(content_path, content_size, crop=True)

            if isinstance(style_path, list): # Style blending/Spatial control
                style_paths = style_path
                style_name = '_'.join(map(get_filename, style_paths))

                # Gather all style images in one numpy array in order to get
                # their activations in one pass
                style_images = None
                for i, style_path in enumerate(style_paths):
                    style_image = load_image(style_path, style_size, crop)
                    if preserve_color:
                        style_image = coral(style_image, content_image)
                    style_image = prepare_image(style_image)
                    if style_images is None:
                        shape = tuple([len(style_paths)]) + style_image.shape
                        style_images = np.empty(shape)
                    assert style_images.shape[1:] == style_image.shape, """Style images must have the same shape"""
                    style_images[i] = style_image
                style_features = sess.run(encoder, feed_dict={
                    image: style_images

                content_image = prepare_image(content_image)
                content_feature = sess.run(encoder, feed_dict={
                    image: content_image[np.newaxis,:]

                if mask:
                    # For spatial control, extract foreground and background
                    # parts of the content using the corresponding masks,
                    # run them individually through AdaIN then combine
                    if data_format == 'channels_first':
                        _, c, h, w = content_feature.shape
                        content_view_shape = (c, -1)
                        mask_shape = lambda mask: (c, len(mask), 1)
                        mask_slice = lambda mask: (slice(None),mask)
                        _, h, w, c = content_feature.shape
                        content_view_shape = (-1, c)
                        mask_shape = lambda mask: (1, len(mask), c)
                        mask_slice = lambda mask: (mask,slice(None))

                    mask = load_mask(mask, h, w).reshape(-1)
                    fg_mask = np.flatnonzero(mask == 1)
                    bg_mask = np.flatnonzero(mask == 0)

                    content_feat_view = content_feature.reshape(content_view_shape)
                    content_feat_fg = content_feat_view[mask_slice(fg_mask)].reshape(mask_shape(fg_mask))
                    content_feat_bg = content_feat_view[mask_slice(bg_mask)].reshape(mask_shape(bg_mask))

                    style_feature_fg = style_features[0]
                    style_feature_bg = style_features[1]

                    target_feature_fg = sess.run(target, feed_dict={
                        content: content_feat_fg[np.newaxis,:],
                        style: style_feature_fg[np.newaxis,:]
                    target_feature_fg = np.squeeze(target_feature_fg)

                    target_feature_bg = sess.run(target, feed_dict={
                        content: content_feat_bg[np.newaxis,:],
                        style: style_feature_bg[np.newaxis,:]
                    target_feature_bg = np.squeeze(target_feature_bg)

                    target_feature = np.zeros_like(content_feat_view)
                    target_feature[mask_slice(fg_mask)] = target_feature_fg
                    target_feature[mask_slice(bg_mask)] = target_feature_bg
                    target_feature = target_feature.reshape(content_feature.shape)
                    # For style blending, get activations for each style then
                    # take a weighted sum.
                    target_feature = np.zeros(content_feature.shape)
                    for style_feature, weight in zip(style_features, style_interp_weights):
                        target_feature += sess.run(target, feed_dict={
                            content: content_feature,
                            style: style_feature[np.newaxis,:]
                        }) * weight
                # NOTE: This is the part we care about, if only 1 style image is provided.
                style_name = get_filename(style_path)
                style_image = load_image(style_path, style_size, crop=True) # This only gives us square crop
                style_image = center_crop_np(style_image) # Actually crop the center out

                if preserve_color:
                    style_image = coral(style_image, content_image)
                style_image = prepare_image(style_image, True, data_format)
                content_image = prepare_image(content_image, True, data_format)
                # Extract other layers
                conv3_1_layer, conv4_1_layer = otherLayers
                style_feature, conv3_1_out_style, conv4_1_out_style = sess.run([encoder, conv3_1_layer, conv4_1_layer], feed_dict={
                    image: style_image[np.newaxis,:]

                content_feature = sess.run(encoder, feed_dict={
                    image: content_image[np.newaxis,:]

                target_feature = sess.run(target, feed_dict={
                    content: content_feature,
                    style: style_feature

            output = sess.run(decoder, feed_dict={
                content: content_feature,
                target: target_feature,
                style: style_feature

            # Grab the relevant layer outputs to see what's being minimized.
            conv3_1_out_output, conv4_1_out_output = sess.run([conv3_1_layer, conv4_1_layer], feed_dict={
                image: output

            filename = '%s_stylized_%s.%s' % (content_name, style_name, save_ext)
            filename = os.path.join(output_dir, filename)
            save_image(filename, output[0], data_format=data_format)
            print('Output image saved at', filename)

            # TODO: Change these layers.
            layersToViz = [conv3_1_out_style, conv4_1_out_style, conv3_1_out_output, conv4_1_out_output]