示例#1
0
def main(_):
    """Stylize every content image with every style image and save the results.

    Validates the required flags, builds the style-transfer graph from the
    model config, restores the model variables from a checkpoint, then runs
    one inference per (style, content) pair and writes each result to
    ``FLAGS.eval_dir/<style_label>/<content file name>``.

    Args:
        _: unused positional argument.

    Raises:
        ValueError: if the content dataset directory, the style dataset
            directory or the checkpoint directory flag is empty.
    """
    if not FLAGS.content_dataset_dir:
        raise ValueError('You must supply the content dataset directory '
                         'with --content_dataset_dir')
    if not FLAGS.style_dataset_dir:
        raise ValueError('You must supply the style dataset directory '
                         'with --style_dataset_dir')

    if not FLAGS.checkpoint_dir:
        raise ValueError('You must supply the checkpoints directory with '
                         '--checkpoint_dir')

    # --checkpoint_dir may be a directory (use its latest checkpoint) or a
    # path to a specific checkpoint (use it as given).
    # NOTE(review): presumably latest_checkpoint returns None when the
    # directory holds no checkpoint; that case is not handled here.
    if tf.gfile.IsDirectory(FLAGS.checkpoint_dir):
        checkpoint_dir = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    else:
        checkpoint_dir = FLAGS.checkpoint_dir

    if not tf.gfile.Exists(FLAGS.eval_dir):
        tf.gfile.MakeDirs(FLAGS.eval_dir)

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # define the model
        style_model, options = models_factory.get_model(FLAGS.model_config_path)

        # placeholders for a single content image and a single style image
        inp_content_image = tf.placeholder(tf.float32, shape=(None, None, 3))
        inp_style_image = tf.placeholder(tf.float32, shape=(None, None, 3))

        # preprocess the content and style images, then add a batch axis
        content_image = preprocessing.preprocessing_image(
            inp_content_image, 448, 448, resize_side=512, is_training=False)
        content_image = tf.expand_dims(content_image, axis=0)
        style_image = preprocessing.preprocessing_image(
            inp_style_image, 448, 448, resize_side=512, is_training=False)
        style_image = tf.expand_dims(style_image, axis=0)

        # style transfer; drop the batch axis again for saving
        stylized_image = style_model.style_transfer(content_image, style_image)
        stylized_image = tf.squeeze(stylized_image, axis=0)

        # gather the test image filenames and style image filenames
        style_image_filenames = utils.get_image_filenames(FLAGS.style_dataset_dir)
        content_image_filenames = utils.get_image_filenames(FLAGS.content_dataset_dir)

        # starting inference of the images
        init_fn = slim.assign_from_checkpoint_fn(
          checkpoint_dir, slim.get_model_variables(), ignore_missing_vars=True)
        with tf.Session() as sess:
            # restore the model variables from the checkpoint
            init_fn(sess)

            # style transfer for each image based on one style image
            for i in range(len(style_image_filenames)):
                # output folder is named after the style file: the last path
                # component, cut at its first '.'
                style_label = style_image_filenames[i].split('/')[-1]
                style_label = style_label.split('.')[0]
                style_dir = os.path.join(FLAGS.eval_dir, style_label)

                if not tf.gfile.Exists(style_dir):
                    tf.gfile.MakeDirs(style_dir)

                # get the style image
                np_style_image = utils.image_reader(style_image_filenames[i])
                print('Starting transferring the style of [%s]' % style_label)

                for j in range(len(content_image_filenames)):
                    # gather the content image
                    np_content_image = utils.image_reader(content_image_filenames[j])
                    np_stylized_image = sess.run(stylized_image,
                                                 feed_dict={inp_content_image: np_content_image,
                                                            inp_style_image: np_style_image})

                    # the output keeps the content image's file name
                    output_filename = os.path.join(
                        style_dir, content_image_filenames[j].split('/')[-1])
                    utils.imsave(output_filename, np_stylized_image)
                    print('Style [%s]: Finish transfer the image [%s]' % (
                        style_label, content_image_filenames[j]))
    # NOTE(review): `pil_img` is never assigned anywhere in this function, so
    # the lines below raise NameError unless it exists at module level. This
    # tail looks like it was pasted in from a web handler; confirm the
    # intended source image (and that `io` / `send_file` are imported) or
    # remove it.
    img_io = io.BytesIO()
    pil_img.save(img_io, 'JPEG', quality=70)
    img_io.seek(0)
    return send_file(img_io, mimetype='image/jpeg')


def imsave(filename, img):
    """Convert ``img`` into an 8-bit image object and return it.

    Values are clamped to the range [0, 255] and cast to ``uint8`` before
    being wrapped with ``Image.fromarray``. ``filename`` is accepted but not
    used: nothing is written to disk by this function.
    """
    clamped = np.clip(img, 0, 255)
    pixels = clamped.astype(np.uint8)
    return Image.fromarray(pixels)


# Module-level setup: resolve the newest checkpoint under "AvatarNet" and
# build the style-transfer inference graph.
checkpoint_dir = tf.train.latest_checkpoint("AvatarNet")

tf.logging.set_verbosity(tf.logging.INFO)
with tf.Graph().as_default():
    # build the model described by the YAML config
    style_model, options = models_factory.get_model("AvatarNet_config.yml")

    # placeholders for a single content image and a single style image
    inp_content_image = tf.placeholder(tf.float32, shape=(None, None, 3))
    inp_style_image = tf.placeholder(tf.float32, shape=(None, None, 3))

    # the content image only gets mean subtraction (no resize/crop), then a
    # batch axis is added
    content_image = preprocessing.mean_image_subtraction(inp_content_image)
    content_image = tf.expand_dims(content_image, axis=0)
    # style resizing and cropping
    style_image = preprocessing.preprocessing_image(inp_style_image, 448, 448,
                                                    style_model.style_size)
    style_image = tf.expand_dims(style_image, axis=0)

    # style transfer
    # NOTE(review): the call below is cut off in this file (its arguments and
    # closing parenthesis are missing), so this fragment does not parse as-is.
    stylized_image = style_model.transfer_styles(
示例#3
0
def main(_):
    """Build the training graph for the configured model and train it.

    Loads the dataset split named by the flags, feeds preprocessed image
    batches through a prefetch queue into ``model.build_train_graph``, wires
    up the optimizer, the optional moving averages and the summaries, and
    hands everything to ``slim.learning.train``.

    Args:
        _: unused positional argument.

    Raises:
        ValueError: if ``--dataset_dir`` is empty.
    """
    if not FLAGS.dataset_dir:
        raise ValueError('You must supply the dataset directory with'
                         ' --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        global_step = slim.create_global_step()  # create the global step

        ######################
        # select the dataset #
        ######################
        dataset = dataset_utils.get_split(FLAGS.dataset_name,
                                          FLAGS.dataset_split_name,
                                          FLAGS.dataset_dir)

        ######################
        # create the network #
        ######################
        # parse the options from a yaml file
        model, options = models_factory.get_model(FLAGS.model_config)

        ####################################################
        # create a dataset provider that loads the dataset #
        ####################################################
        # dataset provider
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            num_readers=FLAGS.num_readers,
            common_queue_capacity=20 * FLAGS.batch_size,
            common_queue_min=10 * FLAGS.batch_size)
        [image] = provider.get(['image'])
        # preprocess a single image using the sizes defined by the model
        image_clip = preprocessing_image(image,
                                         model.training_image_size,
                                         model.training_image_size,
                                         model.content_size,
                                         is_training=True)
        image_clip_batch = tf.train.batch(
            [image_clip],
            batch_size=FLAGS.batch_size,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size)

        # prefetch queue for the input batches
        batch_queue = slim.prefetch_queue.prefetch_queue([image_clip_batch])

        ###########################################
        # build the models based on the given data #
        ###########################################
        images = batch_queue.dequeue()
        total_loss = model.build_train_graph(images)

        ####################################################
        # gather the operations for training and summaries #
        ####################################################
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        # configure the moving averages (only when a decay is given)
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        # gather the optimizer operations
        learning_rate = _configure_learning_rate(dataset.num_samples,
                                                 global_step)
        optimizer = _configure_optimizer(learning_rate)
        summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.moving_average_decay:
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # training operations
        train_op = model.get_training_operations(
            optimizer, global_step, _get_variables_to_train(options))
        update_ops.append(train_op)

        # gather the training summaries
        summaries |= set(model.summaries)

        # group every update op and make the loss tensor depend on the group,
        # so that evaluating the loss also runs one training step
        update_op = tf.group(*update_ops)
        watched_loss = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        # merge the summaries (the collection is re-read to pick up any
        # summaries added since the first read above)
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ##############################
        # start the training process #
        ##############################
        slim.learning.train(watched_loss,
                            logdir=FLAGS.train_dir,
                            init_fn=_get_init_fn(options),
                            summary_op=summary_op,
                            number_of_steps=FLAGS.max_number_of_steps,
                            log_every_n_steps=FLAGS.log_every_n_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs)
def main(_):
    """Stylize every content image with every style image and time inference.

    Validates the required flags, builds the style-transfer graph from the
    model config, restores the model variables from the checkpoint given by
    ``--checkpoint_dir``, then runs one inference per (style, content) pair.
    Each result is written to
    ``FLAGS.eval_dir/<content_label>_<style_label>.bmp`` and the running
    average of the ``sess.run`` time is printed after every image.

    Args:
        _: unused positional argument.

    Raises:
        ValueError: if the content dataset directory, the style dataset
            directory or the checkpoint directory flag is empty, or if
            ``--checkpoint_dir`` is a directory without any checkpoint.
    """
    if not FLAGS.content_dataset_dir:
        raise ValueError('You must supply the content dataset directory '
                         'with --content_dataset_dir')
    if not FLAGS.style_dataset_dir:
        raise ValueError('You must supply the style dataset directory '
                         'with --style_dataset_dir')

    if not FLAGS.checkpoint_dir:
        raise ValueError('You must supply the checkpoints directory with '
                         '--checkpoint_dir')

    # --checkpoint_dir may be a directory (use its latest checkpoint) or a
    # path to a specific checkpoint (use it as given).
    if tf.gfile.IsDirectory(FLAGS.checkpoint_dir):
        checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if checkpoint_path is None:
            raise ValueError('No checkpoint found in the directory %s' %
                             FLAGS.checkpoint_dir)
    else:
        checkpoint_path = FLAGS.checkpoint_dir

    # All outputs go directly into eval_dir, so creating it once is enough.
    if not tf.gfile.Exists(FLAGS.eval_dir):
        tf.gfile.MakeDirs(FLAGS.eval_dir)

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # define the model
        print(FLAGS.model_config_path)
        style_model, options = models_factory.get_model(
            FLAGS.model_config_path)

        # placeholders for a single content image and a single style image
        inp_content_image = tf.placeholder(tf.float32, shape=(None, None, 3))
        inp_style_image = tf.placeholder(tf.float32, shape=(None, None, 3))

        # the content image only gets mean subtraction (no resize/crop)
        content_image = preprocessing.mean_image_subtraction(inp_content_image)
        content_image = tf.expand_dims(content_image, axis=0)
        # style resizing and cropping
        style_image = preprocessing.preprocessing_image(
            inp_style_image, 448, 448, style_model.style_size)
        style_image = tf.expand_dims(style_image, axis=0)

        # style transfer; drop the batch axis again for saving
        stylized_image = style_model.transfer_styles(
            content_image, style_image, inter_weight=FLAGS.inter_weight)
        stylized_image = tf.squeeze(stylized_image, axis=0)

        # gather the test image filenames and style image filenames
        style_image_filenames = get_image_filenames(FLAGS.style_dataset_dir)
        content_image_filenames = get_image_filenames(
            FLAGS.content_dataset_dir)

        # Restore from the checkpoint the user asked for instead of a
        # machine-specific hard-coded path.
        init_fn = slim.assign_from_checkpoint_fn(
            checkpoint_path,
            slim.get_model_variables(),
            ignore_missing_vars=True)
        with tf.Session() as sess:
            # restore the model variables from the checkpoint
            init_fn(sess)

            num_runs = 0
            total_time = 0.0
            # style transfer for each image based on one style image
            for style_filename in style_image_filenames:
                # label = last path component, cut at its first '.'
                style_label = style_filename.split('/')[-1].split('.')[0]

                # get the style image
                np_style_image = image_reader(style_filename)
                print('Starting transferring the style of [%s]' % style_label)

                for content_filename in content_image_filenames:
                    # gather the content image
                    np_content_image = image_reader(content_filename)
                    content_label = (
                        content_filename.split('/')[-1].split('.')[0])

                    # time only the session run, not the file I/O
                    start_time = time.time()
                    np_stylized_image = sess.run(
                        stylized_image,
                        feed_dict={
                            inp_content_image: np_content_image,
                            inp_style_image: np_style_image,
                        })
                    total_time += time.time() - start_time
                    num_runs += 1
                    # running average over all images processed so far
                    print("---%s seconds ---" % (total_time / num_runs))

                    output_filename = os.path.join(
                        FLAGS.eval_dir, f'{content_label}_{style_label}.bmp')
                    imsave(output_filename, np_stylized_image)
                    print('Style [%s]: Finish transfer the image [%s]' %
                          (style_label, content_filename))
示例#5
0
文件: main.py 项目: devkook/ChangeGAN
def run(target,
        is_chief,
        eval,
        eval_output_dir,
        eval_output_bucket,
        train_steps,
        eval_steps,
        job_dir,
        learning_rate,
        eval_frequency,
        dataset_name,
        model_name,
        domain_a,
        domain_b,
        train_dir,
        eval_dir,
        train_batch_size,
        eval_batch_size):
    """Train the model, or run a standalone evaluation pass when ``eval`` is set.

    In evaluation mode the latest checkpoint in ``job_dir`` is restored, the
    evaluation graph is run for up to ``eval_steps`` steps, and the input and
    translated images of every batch are uploaded to the storage bucket; the
    function then returns without training. Otherwise the training graph is
    built and run in a ``MonitoredTrainingSession``.

    Args:
        target: passed as ``master`` to the monitored training session.
        is_chief: whether this task is the chief; the chief additionally
            gets an ``EvaluationRunHook`` while training.
        eval: if truthy, only evaluate and return.
        eval_output_dir: path prefix of the uploaded blobs.
        eval_output_bucket: name of the bucket the images are uploaded to.
        train_steps: stop training once the global step reaches this value;
            ``None`` trains until the session asks to stop.
        eval_steps: maximum number of evaluation steps; ``None`` means no
            limit.
        job_dir: checkpoint directory.
        learning_rate: forwarded to ``model_fn``.
        eval_frequency: forwarded to ``EvaluationRunHook``.
        dataset_name: forwarded to ``dataset_factory.get_dataset``.
        model_name: forwarded to ``models_factory.get_model``.
        domain_a: first dataset domain.
        domain_b: second dataset domain.
        train_dir: location of the training data.
        eval_dir: location of the evaluation data.
        train_batch_size: batch size of the training inputs.
        eval_batch_size: batch size of the evaluation inputs; also the number
            of images saved per evaluation step.
    """
    ######################
    # Select the dataset #
    ######################
    train_dataset_a = dataset_factory.get_dataset(
        dataset_name, domain_a, train_dir)
    train_dataset_b = dataset_factory.get_dataset(
        dataset_name, domain_b, train_dir)
    eval_dataset_a = dataset_factory.get_dataset(
        dataset_name, domain_a, eval_dir)
    eval_dataset_b = dataset_factory.get_dataset(
        dataset_name, domain_b, eval_dir)

    # If the server is chief which is `master`
    # In between graph replication Chief is one node in
    # the cluster with extra responsibility and by default
    # is worker task zero. We have assigned master as the chief.
    #
    # See https://youtu.be/la_M6bCV91M?t=1203 for details on
    # distributed TensorFlow and motivation about chief.
    if not eval and is_chief:
        # Build the graph used by the periodic evaluation hook
        evaluation_graph = tf.Graph()
        with evaluation_graph.as_default():
            # Inputs
            images_a, images_b, bboxes_a, bboxes_b = change_gan.input_fn(
                eval_dataset_a, eval_dataset_b,
                batch_size=eval_batch_size, is_training=False)

            # Model
            outputs = change_gan.model_fn(
                images_a, images_b, learning_rate, is_training=False)

        hooks = [EvaluationRunHook(
            job_dir,
            evaluation_graph,
            eval_frequency,
            eval_steps=eval_steps,
        )]
    else:
        hooks = []

    if eval:
        # Do evaluation job
        evaluation_graph = tf.Graph()
        with evaluation_graph.as_default():
            # Inputs
            images_a, images_b, bboxes_a, bboxes_b = change_gan.input_fn(
                eval_dataset_a, eval_dataset_b,
                batch_size=eval_batch_size, is_training=False)

            # Model
            outputs = change_gan.model_fn(
                images_a, images_b, learning_rate, is_training=False)

            # Saver class add ops to save and restore
            # variables to and from checkpoint
            saver = tf.train.Saver()
            # Creates a global step to contain a counter for
            # the global training step
            gs = tf.train.get_or_create_global_step()

        # Run model evaluation and upload the resulting images.
        coord = tf.train.Coordinator(clean_stop_exception_types=(
            tf.errors.CancelledError, tf.errors.OutOfRangeError))

        with tf.Session(graph=evaluation_graph) as session:
            # Restores previously saved variables from latest checkpoint
            saver.restore(session, tf.train.latest_checkpoint(job_dir))

            session.run([
                tf.tables_initializer(),
                tf.local_variables_initializer()
            ])
            threads = tf.train.start_queue_runners(coord=coord, sess=session)
            train_step = session.run(gs)

            tf.logging.info('Starting Evaluation For Step: {}'.format(train_step))

            client = storage.Client()
            bucket = client.get_bucket(eval_output_bucket)

            def save_to_gs(image, filepath):
                """Write ``image`` to a scratch file and upload it as ``filepath``."""
                blob = bucket.blob(filepath)
                outfile = '/tmp/img.jpg'
                image.save(outfile)
                # Close the handle after the upload; this runs for every
                # image, so a leaked handle per call adds up quickly.
                with open(outfile, 'rb') as image_file:
                    blob.upload_from_file(image_file)

            # All blobs of this run live under <eval_output_dir>/<train_step>/
            step_output_dir = os.path.join(eval_output_dir, str(train_step))

            with coord.stop_on_exception():
                eval_step = 0
                while not coord.should_stop() and (eval_steps is None or
                                                   eval_step < eval_steps):
                    # outputs_aba / outputs_bab are produced but not saved
                    (inputs_a, inputs_b, outputs_ba, outputs_ab,
                     _outputs_aba, _outputs_bab) = session.run(outputs)
                    if eval_step % 100 == 0:
                        tf.logging.info("On Evaluation Step: {}".format(eval_step))
                    # Incremented before saving, so file names are 1-based.
                    eval_step += 1

                    named_batches = (
                        ('inputs_a', inputs_a),
                        ('inputs_b', inputs_b),
                        ('outputs_ba', outputs_ba),
                        ('outputs_ab', outputs_ab),
                    )
                    for i in range(eval_batch_size):
                        for prefix, batch in named_batches:
                            save_to_gs(
                                Image.fromarray(batch[i]),
                                os.path.join(
                                    step_output_dir,
                                    '{}_{}_{}.jpg'.format(prefix, eval_step, i)))

            coord.request_stop()
            coord.join(threads)
        return

    with tf.Graph().as_default():
        # Placement of ops on devices using replica device setter
        # which automatically places the parameters on the `ps` server
        # and the `ops` on the workers
        #
        # See:
        # https://www.tensorflow.org/api_docs/python/tf/train/replica_device_setter
        with tf.device(tf.train.replica_device_setter()):
            model = models_factory.get_model(model_name)
            # Inputs
            images_a, images_b, bboxes_a, bboxes_b = model.input_fn(
                train_dataset_a, train_dataset_b,
                batch_size=train_batch_size, is_training=True)

            # Model
            train_op, global_step, outputs = model.model_fn(
                images_a, images_b, learning_rate, is_training=True)

        # Creates a MonitoredSession for training
        # MonitoredSession is a Session-like object that handles
        # initialization, recovery and hooks
        # https://www.tensorflow.org/api_docs/python/tf/train/MonitoredTrainingSession
        with tf.train.MonitoredTrainingSession(master=target,
                                               is_chief=is_chief,
                                               checkpoint_dir=job_dir,
                                               hooks=hooks,
                                               save_checkpoint_secs=60,
                                               save_summaries_steps=50) as session:
            # Global step to keep track of global number of steps particularly in
            # distributed setting
            step = global_step.eval(session=session)

            # Run the training graph which returns the step number as tracked by
            # the global step tensor.
            # When train epochs is reached, session.should_stop() will be true.
            while (train_steps is None or step < train_steps) \
                    and not session.should_stop():
                step, _ = session.run([global_step, train_op])
示例#6
0
def main(datasetname, n_classes, batch_size, model_name, kernel, cp, dp, gamma,
         pooling_method, epochs, lr, keep_prob, weight_decay, base_log_dir):
    """Train ``model_name`` on ``datasetname`` and evaluate it after every epoch.

    Builds the train/test datasets and the model, trains with SGD under a
    piecewise-constant learning-rate schedule, writes summaries and
    checkpoints under ``<base_log_dir>/<datasetname>``, and finally pickles
    the list of per-epoch test accuracies to ``./statistics.txt`` (any
    existing file of that name is removed at start-up).

    Args:
        datasetname: dataset identifier passed to ``datasets.get_dataset``;
            also the name of the log sub-directory.
        n_classes: number of classes passed to the model factory.
        batch_size: batch size for both datasets.
        model_name: model identifier passed to ``models_factory.get_model``.
        kernel: kernel identifier passed to ``get_kernel``.
        cp: forwarded to ``get_kernel``.
        dp: forwarded to ``get_kernel``.
        gamma: forwarded to ``get_kernel``.
        pooling_method: forwarded to the model factory as ``pooling``.
        epochs: number of epochs to run.
        lr: initial learning rate; scaled by 0.1 after 75 epochs' worth of
            steps and by 0.01 after 125.
        keep_prob: forwarded to the model factory.
        weight_decay: passed to the optimizer as ``decay`` (see note below).
        base_log_dir: root directory for logs; ``~`` is expanded.

    Raises:
        ValueError: from ``max()`` if no epoch was evaluated (for example
            ``epochs == 0``), because the accuracy list is then empty.
    """

    # Let TensorFlow allocate GPU memory on demand instead of all at once.
    config = ConfigProto()
    config.gpu_options.allow_growth = True
    session = InteractiveSession(config=config)

    #Save all statistics for both training/validation
    STAT_FILE = './statistics.txt'

    if os.path.exists(STAT_FILE):
        os.remove(STAT_FILE)

    #Fix TF random seed
    tf.random.set_seed(1777)
    log_dir = os.path.join(os.path.expanduser(base_log_dir),
                           "{}".format(datasetname))
    os.makedirs(log_dir, exist_ok=True)

    # dataset
    train_dataset, train_samples = datasets.get_dataset(
        datasetname, batch_size)
    test_dataset, _ = datasets.get_dataset(datasetname,
                                           batch_size,
                                           subset="test",
                                           shuffle=False)

    #Network
    kernel_fn = get_kernel(kernel, cp=cp, dp=dp, gamma=gamma)

    model = models_factory.get_model(model_name,
                                     num_classes=n_classes,
                                     keep_prob=keep_prob,
                                     kernel_fn=kernel_fn,
                                     pooling=pooling_method)

    #Train optimizer, loss
    # Learning rate drops by 10x at epoch 75 and again at epoch 125,
    # expressed in optimizer steps.
    nrof_steps_per_epoch = (train_samples // batch_size)
    boundries = [nrof_steps_per_epoch * 75, nrof_steps_per_epoch * 125]
    values = [lr, lr * 0.1, lr * 0.01]
    lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(\
                    boundries,
                    values)
    # NOTE(review): Keras SGD's `decay` argument is, as far as I know, a
    # learning-rate decay and not an L2 weight decay - confirm that passing
    # `weight_decay` here is intended.
    optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule,
                                        momentum=0.9,
                                        decay=weight_decay)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

    #metrics
    train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
    test_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()

    #Train step
    @tf.function
    def train_step(x, labels):
        # One optimizer update on a single batch; returns the batch loss and
        # the model outputs so the caller can update the accuracy metric.
        with tf.GradientTape() as t:
            logits = model(x, training=True)
            loss = loss_fn(labels, logits)

        gradients = t.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return loss, logits

    #Run

    # NOTE(review): ep_cnt is not part of `ckpt` below, so it starts at 0 on
    # every run and the "skip finished epochs" check in the loop presumably
    # never triggers after a restore - confirm.
    ep_cnt = tf.Variable(initial_value=0, trainable=False, dtype=tf.int64)

    #Summary writers
    train_summary_writer = tf.summary.create_file_writer(
        os.path.join(log_dir, 'summaries', 'train'))
    test_summary_writer = tf.summary.create_file_writer(
        os.path.join(log_dir, 'summaries', 'test'))

    # Checkpoint the step counter, optimizer state and model; keep the three
    # most recent checkpoints and resume from the latest one if present.
    ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                               optimizer=optimizer,
                               net=model)
    ckpt_path = os.path.join(log_dir, 'checkpoints')
    manager = tf.train.CheckpointManager(ckpt, ckpt_path, max_to_keep=3)
    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    test_accuracy_collector = []

    for ep in tqdm.trange(epochs, desc='Epoch Loop'):
        # skip epochs that are already counted in ep_cnt
        if ep < ep_cnt:
            continue

        # update epoch counter
        ep_cnt.assign_add(1)
        with train_summary_writer.as_default():
            # train for an epoch
            for step, (x, y) in enumerate(train_dataset):
                # batches of rank 3 get a trailing axis appended
                if len(x.shape) == 3:
                    x = tf.expand_dims(x, 3)
                tf.summary.image("input_image", x, step=optimizer.iterations)
                loss, logits = train_step(x, y)
                train_acc_metric(y, logits)
                ckpt.step.assign_add(1)
                tf.summary.scalar("loss", loss, step=optimizer.iterations)

                # save a checkpoint every 1000 global steps
                if int(ckpt.step) % 1000 == 0:
                    save_path = manager.save()
                    print("Saved checkpoint for step {}: {}".format(
                        int(ckpt.step), save_path))
                # Log every 200 steps of the current epoch
                if step % 200 == 0:
                    '''
                    for x_batch, y_batch in test_dataset:
                        if len(x_batch.shape)==3:
                            x_batch = tf.expand_dims(x_batch, 3)
                        test_logits = model(x_batch, training=False)
                        # Update test metrics
                        test_acc_metric(y_batch, test_logits)

                    test_acc = test_acc_metric.result()
                    test_accuracy_collector.append(float(test_acc))
                    test_acc_metric.reset_states()
                    '''
                    train_acc = train_acc_metric.result()
                    print("Training loss {:1.2f}, accuracu {} at step {}".format(\
                            loss.numpy(),
                            float(train_acc),
                            step))

            # Display metrics at the end of each epoch.
            train_acc = train_acc_metric.result()
            tf.summary.scalar("accuracy", train_acc, step=ep)
            print('Training acc over epoch: %s' % (float(train_acc), ))
            # Reset training metrics at the end of each epoch
            train_acc_metric.reset_states()

    ############################## Test the model #############################
        # Runs once per epoch (this block is inside the epoch loop).
        with test_summary_writer.as_default():
            for x_batch, y_batch in test_dataset:
                if len(x_batch.shape) == 3:
                    x_batch = tf.expand_dims(x_batch, 3)
                test_logits = model(x_batch, training=False)
                # Update test metrics
                test_acc_metric(y_batch, test_logits)

            test_acc = test_acc_metric.result()
            tf.summary.scalar("accuracy", test_acc, step=ep)
            test_acc_metric.reset_states()
            test_accuracy_collector.append(float(test_acc))
            print('[Epoch {}] Test acc: {}'.format(ep, float(test_acc)))

    print("Best test accuracy is: {}".format(max(test_accuracy_collector)))
    #Save stats:
    # The file is a pickle of the accuracy list despite its .txt extension.
    with open(STAT_FILE, 'wb') as fp:
        pickle.dump(test_accuracy_collector, fp)
示例#7
0
def main(_):
    """Generate a texture for every style image, grouped by style category.

    Every entry returned for ``FLAGS.style_dataset_dir`` is treated as a
    category: its images are read one by one, pushed through the model's
    texture generation, and the result is saved under
    ``FLAGS.eval_dir/<category>/`` using the source image's file name.

    Args:
        _: unused positional argument.

    Raises:
        ValueError: if the style dataset directory or the checkpoint
            directory flag is empty.
    """
    if not FLAGS.style_dataset_dir:
        raise ValueError(
            'You must supply the style dataset directory with '
            '--style_dataset_dir')

    if not FLAGS.checkpoint_dir:
        raise ValueError(
            'You must supply the checkpoints directory with --checkpoint_dir')

    # A directory resolves to its latest checkpoint; anything else is taken
    # as the checkpoint path itself.
    checkpoint_dir = (
        tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if tf.gfile.IsDirectory(FLAGS.checkpoint_dir)
        else FLAGS.checkpoint_dir)

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # build the model from its config file
        style_model, options = models_factory.get_model(
            FLAGS.model_config_path)

        # input placeholder -> preprocessing -> batch axis
        inp_style_image = tf.placeholder(tf.float32, shape=(None, None, 3))
        preprocessed_style = preprocessing.preprocessing_image(
            inp_style_image, 224, 224, resize_side=256, is_training=False)
        style_image = tf.expand_dims(preprocessed_style, axis=0)

        # texture generation from the style alone; drop the batch axis
        generated, _ = style_model.texture_generation(style_image)
        stylized_image = tf.squeeze(generated, axis=0)

        # one entry per style category
        styles_categories = utils.get_image_filenames(FLAGS.style_dataset_dir)

        # function that restores the model variables from the checkpoint
        init_fn = slim.assign_from_checkpoint_fn(
            checkpoint_dir,
            slim.get_model_variables(),
            ignore_missing_vars=True)

        with tf.Session() as sess:
            init_fn(sess)

            for style_category in styles_categories:
                # results of a category go to eval_dir/<category>
                category = style_category.split('/')[-1]
                style_dir = os.path.join(FLAGS.eval_dir, category)
                if not tf.gfile.Exists(style_dir):
                    tf.gfile.MakeDirs(style_dir)
                sys.stdout.write(
                    '>> Transferring the texture [%s] ' % category)

                style_image_filenames = utils.get_image_filenames(
                    style_category)
                num_images = len(style_image_filenames)
                for position, image_path in enumerate(
                        style_image_filenames, start=1):
                    np_style_image = utils.image_reader(image_path)
                    feed = {inp_style_image: np_style_image}
                    np_stylized_image = sess.run(stylized_image,
                                                 feed_dict=feed)

                    # keep the source image's file name for the output
                    output_filename = os.path.join(
                        style_dir, image_path.split('/')[-1])
                    utils.imsave(output_filename, np_stylized_image)

                    # '\r' rewrites the same console line as a progress counter
                    sys.stdout.write('\r>> Transferring image %d/%d' %
                                     (position, num_images))
                    sys.stdout.flush()
def run(args,
        vis=False,
        save_vis=False,
        save_model=False,
        save_cm=False,
        phase='train'):
    """Run one experiment: load data, fit the model and score it.

    Args:
        args: experiment configuration with the keys ``'data_params'``,
            ``'model_params'`` (which must contain ``'type'`` when
            ``save_model`` or ``save_cm`` is set) and ``'pre_params'``.
        vis: if true, call ``visualize`` with the labels, predictions and
            scores.
        save_vis: forwarded to ``visualize``.
        save_model: if true, pickle the fitted model to
            ``../results/bestModels/<type>.pkl``.
        save_cm: if true, save the test confusion matrix and the unique test
            labels as ``.npy`` files under ``../results/confusionMatrix/``.
        phase: ``'train'`` loads one dataset and holds out 30% of it for
            testing; any other value loads the separate ``'train'`` and
            ``'test'`` datasets.

    Returns:
        The Hamming loss on the test set.
    """
    logging.info("Running new experiment\n========================\n")

    data_params = args['data_params']
    model_params = args['model_params']
    pre_params = args['pre_params']
    logging.info(args)

    # get data
    if phase == 'train':
        X, y = get_data(data_params, phase)
        # split data
        X_train, X_test, y_train, y_test = train_test_split(X,
                                                            y,
                                                            test_size=0.3)
    else:
        X_train, y_train = get_data(data_params, 'train')
        X_test, y_test = get_data(data_params, 'test')

    logging.info('y shape: %s', y_train.shape)
    logging.info('x shape: %s', X_train.shape)

    # get model
    model = get_model(model_params)

    # preprocessing: fit on the training data only, apply to both sets
    proc = pre.get_processor(pre_params)

    if proc:
        proc.fit(X_train)
        X_train = proc.transform(X_train)
        X_test = proc.transform(X_test)
    else:
        logging.info('no preprocessing applied')

    logging.info('fitting model started ....')
    model.fit(X_train, y_train)
    logging.info('model fitting finished')

    # Persist the model only after fitting; dumping it before `fit` stores
    # an untrained estimator.
    # NOTE(review): the fitted preprocessor `proc` is not saved alongside it.
    if save_model:
        model_type = model_params['type']
        with open('../results/bestModels/' + model_type + '.pkl',
                  'wb') as output:
            pickle.dump(model, output, pickle.HIGHEST_PROTOCOL)

    pred_test = model.predict(X_test)
    pred_train = model.predict(X_train)

    # predictions are cast to int32 before scoring
    score_train = hamming_loss(y_train, np.int32(pred_train))
    score_test = hamming_loss(y_test, np.int32(pred_test))

    logging.info('score_test: %s', score_test)
    logging.info('score_train: %s', score_train)

    if save_cm:
        model_type = model_params['type']
        # Compute confusion matrix
        # NOTE(review): uses the raw predictions, not the int32-cast ones
        # used for the scores above - confirm this is intended.
        cm = confusion_matrix(y_test, pred_test)
        target_names = np.unique(y_test)
        np.save('../results/confusionMatrix/' + model_type + '_cm_5.npy', cm)
        np.save('../results/confusionMatrix/' + model_type + '_names_5.npy',
                target_names)

    if vis:
        scores = [score_train, score_test]
        data = [y_train, y_test, pred_train, pred_test]
        visualize(data, scores, args, save_vis=save_vis)

    return score_test