Ejemplo n.º 1
0
    def run_experiment(self):
        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            sess.run(self.init)
            self.train_writer = tf.summary.FileWriter("{}/train_logs/".format(self.log_path),
                                                      graph=tf.get_default_graph())
            self.validation_writer = tf.summary.FileWriter("{}/validation_logs/".format(self.log_path),
                                                           graph=tf.get_default_graph())
            self.train_saver = tf.train.Saver()
            self.val_saver = tf.train.Saver()

            start_from_epoch = 0
            if self.continue_from_epoch!=-1:
                start_from_epoch = self.continue_from_epoch
                checkpoint = "{}/{}_{}.ckpt".format(self.saved_models_filepath, self.experiment_name, self.continue_from_epoch)
                variables_to_restore = []
                for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
                    print(var)
                    variables_to_restore.append(var)

                tf.logging.info('Fine-tuning from %s' % checkpoint)

                fine_tune = slim.assign_from_checkpoint_fn(
                    checkpoint,
                    variables_to_restore,
                    ignore_missing_vars=True)
                fine_tune(sess)

            self.iter_done = 0
            self.disc_iter = 5
            self.gen_iter = 1
            best_d_val_loss = np.inf

            if self.spherical_interpolation:
                dim = int(np.sqrt(self.num_generations)*2)
                self.z_2d_vectors = interpolations.create_mine_grid(rows=dim,
                                                                    cols=dim,
                                                                    dim=self.z_dim, space=3, anchors=None,
                                                                    spherical=True, gaussian=True)
                self.z_vectors = interpolations.create_mine_grid(rows=1, cols=self.num_generations, dim=self.z_dim,
                                                                 space=3, anchors=None, spherical=True, gaussian=True)
            else:
                self.z_vectors = np.random.normal(size=(self.num_generations, self.z_dim))
                self.z_2d_vectors = np.random.normal(size=(self.num_generations, self.z_dim))

            with tqdm.tqdm(total=self.total_epochs-start_from_epoch) as pbar_e:
                for e in range(start_from_epoch, self.total_epochs):

                    train_g_loss = []
                    val_g_loss = []
                    train_d_loss = []
                    val_d_loss = []

                    with tqdm.tqdm(total=self.total_train_batches) as pbar_train:
                        for iter in range(self.total_train_batches):

                            cur_sample = 0

                            for n in range(self.disc_iter):
                                x_train_i, x_train_j = self.data.get_train_batch()
                                x_val_i, x_val_j = self.data.get_val_batch()

                                _, d_train_loss_value = sess.run(
                                    [self.graph_ops["d_opt_op"], self.losses["d_losses"]],
                                    feed_dict={self.input_x_i: x_train_i,
                                               self.input_x_j: x_train_j,
                                               self.dropout_rate: self.dropout_rate_value,
                                               self.training_phase: True, self.random_rotate: True})

                                d_val_loss_value = sess.run(
                                    self.losses["d_losses"],
                                    feed_dict={self.input_x_i: x_val_i,
                                               self.input_x_j: x_val_j,
                                               self.dropout_rate: self.dropout_rate_value,
                                               self.training_phase: False, self.random_rotate: False})

                                cur_sample += 1
                                train_d_loss.append(d_train_loss_value)
                                val_d_loss.append(d_val_loss_value)

                            for n in range(self.gen_iter):
                                x_train_i, x_train_j = self.data.get_train_batch()
                                x_val_i, x_val_j = self.data.get_val_batch()
                                _, g_train_loss_value, train_summaries = sess.run(
                                    [self.graph_ops["g_opt_op"], self.losses["g_losses"],
                                     self.summary],
                                    feed_dict={self.input_x_i: x_train_i,
                                               self.input_x_j: x_train_j,
                                               self.dropout_rate: self.dropout_rate_value,
                                               self.training_phase: True, self.random_rotate: True})

                                g_val_loss_value, val_summaries = sess.run(
                                    [self.losses["g_losses"], self.summary],
                                    feed_dict={self.input_x_i: x_val_i,
                                               self.input_x_j: x_val_j,
                                               self.dropout_rate: self.dropout_rate_value,
                                               self.training_phase: False, self.random_rotate: False})

                                cur_sample += 1
                                train_g_loss.append(g_train_loss_value)
                                val_g_loss.append(g_val_loss_value)

                                if iter % (self.tensorboard_update_interval) == 0:
                                    self.train_writer.add_summary(train_summaries, global_step=self.iter_done)
                                    self.validation_writer.add_summary(val_summaries, global_step=self.iter_done)


                            self.iter_done = self.iter_done + 1
                            iter_out = "{}_train_d_loss: {}, train_g_loss: {}, " \
                                       "val_d_loss: {}, val_g_loss: {}".format(self.iter_done,
                                                                               d_train_loss_value, g_train_loss_value,
                                                                               d_val_loss_value,
                                                                               g_val_loss_value)
                            pbar_train.set_description(iter_out)
                            pbar_train.update(1)

                    total_d_train_loss_mean = np.mean(train_d_loss)
                    total_d_train_loss_std = np.std(train_d_loss)
                    total_g_train_loss_mean = np.mean(train_g_loss)
                    total_g_train_loss_std = np.std(train_g_loss)

                    print(
                        "Epoch {}: d_train_loss_mean: {}, d_train_loss_std: {},"
                                  "g_train_loss_mean: {}, g_train_loss_std: {}"
                        .format(e, total_d_train_loss_mean,
                                total_d_train_loss_std,
                                total_g_train_loss_mean,
                                total_g_train_loss_std))

                    total_d_val_loss_mean = np.mean(val_d_loss)
                    total_d_val_loss_std = np.std(val_d_loss)
                    total_g_val_loss_mean = np.mean(val_g_loss)
                    total_g_val_loss_std = np.std(val_g_loss)

                    print(
                        "Epoch {}: d_val_loss_mean: {}, d_val_loss_std: {},"
                        "g_val_loss_mean: {}, g_val_loss_std: {}, "
                            .format(e, total_d_val_loss_mean,
                                    total_d_val_loss_std,
                                    total_g_val_loss_mean,
                                    total_g_val_loss_std))



                    sample_generator(num_generations=self.num_generations, sess=sess, same_images=self.same_images,
                                     inputs=x_train_i,
                                     data=self.data, batch_size=self.batch_size, z_input=self.z_input,
                                     file_name="{}/train_z_variations_{}_{}.png".format(self.save_image_path,
                                                                                        self.experiment_name,
                                                                                        e),
                                     input_a=self.input_x_i, training_phase=self.training_phase,
                                     z_vectors=self.z_vectors, dropout_rate=self.dropout_rate,
                                     dropout_rate_value=self.dropout_rate_value)

                    sample_two_dimensions_generator(sess=sess,
                                                    same_images=self.same_images,
                                                    inputs=x_train_i,
                                                    data=self.data, batch_size=self.batch_size, z_input=self.z_input,
                                                    file_name="{}/train_z_spherical_{}_{}".format(self.save_image_path,
                                                                                                  self.experiment_name,
                                                                                                  e),
                                                    input_a=self.input_x_i, training_phase=self.training_phase,
                                                    dropout_rate=self.dropout_rate,
                                                    dropout_rate_value=self.dropout_rate_value,
                                                    z_vectors=self.z_2d_vectors)

                    with tqdm.tqdm(total=self.total_gen_batches) as pbar_samp:
                        for i in range(self.total_gen_batches):
                            x_gen_a = self.data.get_gen_batch()
                            sample_generator(num_generations=self.num_generations, sess=sess,
                                             same_images=self.same_images,
                                             inputs=x_gen_a,
                                             data=self.data, batch_size=self.batch_size, z_input=self.z_input,
                                             file_name="{}/test_z_variations_{}_{}_{}.png".format(self.save_image_path,
                                                                                                  self.experiment_name,
                                                                                                  e, i),
                                             input_a=self.input_x_i, training_phase=self.training_phase,
                                             z_vectors=self.z_vectors, dropout_rate=self.dropout_rate,
                                             dropout_rate_value=self.dropout_rate_value)

                            sample_two_dimensions_generator(sess=sess,
                                                            same_images=self.same_images,
                                                            inputs=x_gen_a,
                                                            data=self.data, batch_size=self.batch_size,
                                                            z_input=self.z_input,
                                                            file_name="{}/val_z_spherical_{}_{}_{}".format(
                                                                self.save_image_path,
                                                                self.experiment_name,
                                                                e, i),
                                                            input_a=self.input_x_i,
                                                            training_phase=self.training_phase,
                                                            dropout_rate=self.dropout_rate,
                                                            dropout_rate_value=self.dropout_rate_value,
                                                            z_vectors=self.z_2d_vectors)

                            pbar_samp.update(1)

                    train_save_path = self.train_saver.save(sess, "{}/train_saved_model_{}_{}.ckpt".format(
                        self.saved_models_filepath,
                        self.experiment_name, e))

                    if total_d_val_loss_mean<best_d_val_loss:
                        best_d_val_loss = total_d_val_loss_mean
                        val_save_path = self.train_saver.save(sess, "{}/val_saved_model_{}_{}.ckpt".format(
                            self.saved_models_filepath,
                            self.experiment_name, e))
                        print("Saved current best val model at", val_save_path)

                    save_statistics(self.log_path, [e, total_d_train_loss_mean, total_d_val_loss_mean,
                                                total_d_train_loss_std, total_d_val_loss_std,
                                                total_g_train_loss_mean, total_g_val_loss_mean,
                                                total_g_train_loss_std, total_g_val_loss_std])

                    pbar_e.update(1)
Ejemplo n.º 2
0
    def run_experiment(self):
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            sess.run(self.init)
            self.writer = tf.summary.FileWriter(self.log_path,
                                                graph=tf.get_default_graph())
            self.saver = tf.train.Saver()
            start_from_epoch = 0
            if self.continue_from_epoch != -1:
                start_from_epoch = self.continue_from_epoch
                checkpoint = "{}/{}_{}.ckpt".format(self.saved_models_filepath,
                                                    self.experiment_name,
                                                    self.continue_from_epoch)
                variables_to_restore = []
                for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
                    print(var)
                    variables_to_restore.append(var)

                tf.logging.info('Fine-tuning from %s' % checkpoint)

                fine_tune = slim.assign_from_checkpoint_fn(
                    checkpoint, variables_to_restore, ignore_missing_vars=True)
                fine_tune(sess)

            self.iter_done = 0
            self.disc_iter = 5
            self.gen_iter = 1

            if self.spherical_interpolation:
                dim = int(np.sqrt(self.num_generations) * 2)
                self.z_2d_vectors = interpolations.create_mine_grid(
                    rows=dim,
                    cols=dim,
                    dim=self.z_dim,
                    space=3,
                    anchors=None,
                    spherical=True,
                    gaussian=True)
                self.z_vectors = interpolations.create_mine_grid(
                    rows=1,
                    cols=self.num_generations,
                    dim=self.z_dim,
                    space=3,
                    anchors=None,
                    spherical=True,
                    gaussian=True)
            else:
                self.z_vectors = np.random.normal(size=(self.num_generations,
                                                        self.z_dim))
                self.z_2d_vectors = np.random.normal(
                    size=(self.num_generations, self.z_dim))

            with tqdm.tqdm(total=self.total_epochs -
                           start_from_epoch) as pbar_e:
                for e in range(start_from_epoch, self.total_epochs):

                    total_g_loss = 0.
                    total_d_loss = 0.
                    save_path = self.saver.save(
                        sess,
                        "{}/{}_{}.ckpt".format(self.saved_models_filepath,
                                               self.experiment_name, e))
                    print("Model saved at", save_path)
                    with tqdm.tqdm(
                            total=self.total_train_batches) as pbar_train:
                        x_train_a_gan_list, x_train_b_gan_same_class_list = self.data.get_train_batch(
                        )

                        sample_generator(
                            num_generations=self.num_generations,
                            sess=sess,
                            same_images=self.same_images,
                            inputs=x_train_a_gan_list,
                            data=self.data,
                            batch_size=self.batch_size,
                            z_input=self.z_input,
                            file_name="{}/train_z_variations_{}_{}.png".format(
                                self.save_image_path, self.experiment_name, e),
                            input_a=self.input_x_i,
                            training_phase=self.training_phase,
                            z_vectors=self.z_vectors,
                            dropout_rate=self.dropout_rate,
                            dropout_rate_value=self.dropout_rate_value)

                        sample_two_dimensions_generator(
                            sess=sess,
                            same_images=self.same_images,
                            inputs=x_train_a_gan_list,
                            data=self.data,
                            batch_size=self.batch_size,
                            z_input=self.z_input,
                            file_name="{}/train_z_spherical_{}_{}".format(
                                self.save_image_path, self.experiment_name, e),
                            input_a=self.input_x_i,
                            training_phase=self.training_phase,
                            dropout_rate=self.dropout_rate,
                            dropout_rate_value=self.dropout_rate_value,
                            z_vectors=self.z_2d_vectors)

                        with tqdm.tqdm(
                                total=self.total_gen_batches) as pbar_samp:
                            for i in range(self.total_gen_batches):
                                x_gen_a = self.data.get_gen_batch()
                                sample_generator(
                                    num_generations=self.num_generations,
                                    sess=sess,
                                    same_images=self.same_images,
                                    inputs=x_gen_a,
                                    data=self.data,
                                    batch_size=self.batch_size,
                                    z_input=self.z_input,
                                    file_name=
                                    "{}/test_z_variations_{}_{}_{}.png".format(
                                        self.save_image_path,
                                        self.experiment_name, e, i),
                                    input_a=self.input_x_i,
                                    training_phase=self.training_phase,
                                    z_vectors=self.z_vectors,
                                    dropout_rate=self.dropout_rate,
                                    dropout_rate_value=self.dropout_rate_value)

                                sample_two_dimensions_generator(
                                    sess=sess,
                                    same_images=self.same_images,
                                    inputs=x_gen_a,
                                    data=self.data,
                                    batch_size=self.batch_size,
                                    z_input=self.z_input,
                                    file_name="{}/val_z_spherical_{}_{}_{}".
                                    format(self.save_image_path,
                                           self.experiment_name, e, i),
                                    input_a=self.input_x_i,
                                    training_phase=self.training_phase,
                                    dropout_rate=self.dropout_rate,
                                    dropout_rate_value=self.dropout_rate_value,
                                    z_vectors=self.z_2d_vectors)

                                pbar_samp.update(1)

                        for i in range(self.total_train_batches):

                            for j in range(self.disc_iter):
                                x_train_a_gan_list, x_train_b_gan_same_class_list = self.data.get_train_batch(
                                )
                                _, d_loss_value = sess.run(
                                    [
                                        self.graph_ops["d_opt_op"],
                                        self.losses["d_losses"]
                                    ],
                                    feed_dict={
                                        self.input_x_i: x_train_a_gan_list,
                                        self.input_x_j:
                                        x_train_b_gan_same_class_list,
                                        self.dropout_rate:
                                        self.dropout_rate_value,
                                        self.training_phase: True,
                                        self.random_rotate: True
                                    })
                                total_d_loss += d_loss_value

                            for j in range(self.gen_iter):
                                x_train_a_gan_list, x_train_b_gan_same_class_list = \
                                    self.data.get_train_batch()
                                _, g_loss_value, summaries, = sess.run(
                                    [
                                        self.graph_ops["g_opt_op"],
                                        self.losses["g_losses"], self.summary
                                    ],
                                    feed_dict={
                                        self.input_x_i: x_train_a_gan_list,
                                        self.input_x_j:
                                        x_train_b_gan_same_class_list,
                                        self.dropout_rate:
                                        self.dropout_rate_value,
                                        self.training_phase: True,
                                        self.random_rotate: True
                                    })

                                total_g_loss += g_loss_value

                            if i % (self.tensorboard_update_interval) == 0:
                                self.writer.add_summary(summaries)
                            self.iter_done = self.iter_done + 1
                            iter_out = "d_loss: {}, g_loss: {}".format(
                                d_loss_value, g_loss_value)
                            pbar_train.set_description(iter_out)
                            pbar_train.update(1)

                    total_g_loss /= (self.total_train_batches * self.gen_iter)

                    total_d_loss /= (self.total_train_batches * self.disc_iter)

                    print("Epoch {}: d_loss: {}, wg_loss: {}".format(
                        e, total_d_loss, total_g_loss))

                    total_g_val_loss = 0.
                    total_d_val_loss = 0.

                    with tqdm.tqdm(total=self.total_test_batches) as pbar_val:
                        for i in range(self.total_test_batches):

                            for j in range(self.disc_iter):
                                x_test_a, x_test_b = self.data.get_test_batch()
                                d_loss_value = sess.run(
                                    self.losses["d_losses"],
                                    feed_dict={
                                        self.input_x_i: x_test_a,
                                        self.input_x_j: x_test_b,
                                        self.training_phase: False,
                                        self.random_rotate: False,
                                        self.dropout_rate:
                                        self.dropout_rate_value
                                    })

                                total_d_val_loss += d_loss_value

                            for j in range(self.gen_iter):
                                x_test_a, x_test_b = self.data.get_test_batch()
                                g_loss_value = sess.run(
                                    self.losses["g_losses"],
                                    feed_dict={
                                        self.input_x_i: x_test_a,
                                        self.input_x_j: x_test_b,
                                        self.training_phase: False,
                                        self.random_rotate: False,
                                        self.dropout_rate:
                                        self.dropout_rate_value
                                    })

                                total_g_val_loss += (g_loss_value)

                            self.iter_done = self.iter_done + 1
                            iter_out = "d_loss: {}, g_loss: {}".format(
                                d_loss_value, g_loss_value)
                            pbar_val.set_description(iter_out)
                            pbar_val.update(1)

                    total_g_val_loss /= (self.total_test_batches *
                                         self.gen_iter)
                    total_d_val_loss /= (self.total_test_batches *
                                         self.disc_iter)

                    print("Epoch {}: d_val_loss: {}, wg_val_loss: {}".format(
                        e, total_d_val_loss, total_g_val_loss))

                    save_statistics(self.log_path, [
                        e, total_d_loss, total_g_loss, total_d_val_loss,
                        total_g_val_loss
                    ])

                    pbar_e.update(1)