    def train(self, counter=1, gen_dirs=()):
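        """Train the steganalysis classifier on `self.data`.

        `counter` allows resuming from a checkpointed step; `gen_dirs` is an
        optional collection of extra image folders (e.g. generator output)
        whose accuracy is reported every 300 steps alongside the test set.
        """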
        if self.conf.need_to_load:
            self.load(self.conf.checkpoint_dir, step=counter)

        data = self.data
        logger.info('Total number of images: %s' % len(data))

        tf.initialize_all_variables().run()
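        # NOTE: tf.initialize_all_variables() is the legacy pre-1.0 API;
        # TF >= 1.0 renamed it to tf.global_variables_initializer()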

        start_time = time.time()
        batch_idxs = min(len(data), self.conf.train_size) // self.conf.batch_size

        logger.debug('Starting updating')
        for epoch in range(self.conf.epoch):
            losses = []

            np.random.shuffle(data)

            logger.info('Starting epoch %s' % epoch)

            for idx in range(batch_idxs):
                batch_files = data[idx * self.conf.batch_size:(idx + 1) * self.conf.batch_size]
                batch = [get_image(batch_file, self.conf.image_size)
                         for batch_file in batch_files]
                batch_images = np.array(batch).astype(np.float32)

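                # labels are presumably derived from the file names (stego vs. cover)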
                batch_targets = self.get_targets(batch_files)

                # run the update and fetch the loss in a single call instead of
                # a separate loss.eval() forward pass (loss is the pre-update value)
                _, loss = self.sess.run([self.optimize, self.loss],
                                        feed_dict={self.images: batch_images, self.target: batch_targets})

                losses.append(loss)

                # logger.debug("[ITERATION] Epoch [%2d], iteration [%4d/%4d] time: %4.4f, Loss: %8f, accuracy: %8f" %
                #              (epoch, idx, batch_idxs, time.time() - start_time, loss, stego_accuracy))

                counter += 1

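                # evaluate on the held-out set every 300 optimizer steps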
                if counter % 300 == 0:
                    logger.info('------')

                    stego_accuracy = self.accuracy(n_files=-1, test_dir=self.test_dir)
                    logger.info('[TEST] Epoch {:2d} accuracy: {:3.1f}%'.format(epoch, 100 * stego_accuracy))

                    for gen_dir in gen_dirs:
                        gen_accuracy = self.accuracy(n_files=-1, test_dir=gen_dir)
                        logger.info('[GEN_TEST] Folder {}, accuracy: {:3.1f}%'.format(gen_dir, 100 * gen_accuracy))
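
            # `losses` is accumulated above but was never reported; log the
            # epoch mean so the list is actually used
            logger.info('[EPOCH] %2d mean loss: %.6f' % (epoch, np.mean(losses)))
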
    def accuracy(self, test_dir='test', abs=False, n_files=2 ** 12):
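        """Mean batch accuracy over images matching `test_dir/*.img_format`.

        Pass n_files <= 0 to evaluate on every file in the folder.
        """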
        X_test = self.get_images_names('%s/*.%s' % (test_dir, self.conf.img_format), abs=abs)
        if n_files > 0:  # n_files <= 0 means "evaluate on every file in the folder"
            X_test = X_test[:n_files]
        logger.info('[TEST] test data folder: %s, n_files: %s' % (test_dir, len(X_test)))

        accuracies = []

        batch_idxs = min(len(X_test), self.conf.train_size) // self.conf.batch_size

        for idx in range(batch_idxs):
            batch_files_stego = X_test[idx * self.conf.batch_size:(idx + 1) * self.conf.batch_size]
            batch = [get_image(batch_file, self.conf.image_size) for batch_file in batch_files_stego]
            batch_images = np.array(batch).astype(np.float32)

            batch_targets = self.get_targets(batch_files_stego)

            accuracies.append(self.get_accuracy(batch_images, batch_targets))

        # note: returns NaN if the folder holds fewer images than one batch
        return np.mean(accuracies)

    def train(self):
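        """Adversarial training loop for the SGAN.

        Alternates updates of the real-vs-fake and stego-vs-non-stego
        discriminators with the generator; saves a sample grid every 300
        steps and a checkpoint every 1000.
        """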
        if self.conf.need_to_load:
            self.load(self.conf.checkpoint_dir)

        data = glob(os.path.join(self.conf.data, "*.%s" % self.conf.img_format))
        logger.info('Total number of images: %s' % len(data))
        # np.random.shuffle(data)

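        # one Adam optimizer per sub-network; var_list restricts each
        # update to that sub-network's own variables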
        d_fr_optim = tf.train.AdamOptimizer(self.conf.learning_rate, beta1=self.conf.beta1)
        d_fr_optim = d_fr_optim.minimize(self.d_fr_loss, var_list=self.d_fr_vars)

        d_s_n_optim = tf.train.AdamOptimizer(self.conf.learning_rate, beta1=self.conf.beta1)
        d_s_n_optim = d_s_n_optim.minimize(self.d_stego_loss_total, var_list=self.d_s_n_vars)

        g_optim_fake = tf.train.AdamOptimizer(self.conf.learning_rate, beta1=self.conf.beta1)
        g_optim_fake = g_optim_fake.minimize(self.g_loss, var_list=self.g_vars)

        # g_optim_stego = tf.train.AdamOptimizer(0.000005, beta1=0.9)
        # g_optim_stego = g_optim_stego.minimize(self.g_loss_stego, var_list=self.g_vars)

        merged = tf.merge_all_summaries()
        train_writer = tf.train.SummaryWriter('./logs_sgan', self.sess.graph)
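        # NOTE: legacy pre-1.0 summary API (tf.summary.merge_all /
        # tf.summary.FileWriter in TF >= 1.0); `merged` is never fetched
        # below, so no summaries are actually written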

        tf.initialize_all_variables().run()

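        # fixed noise and a fixed image batch, so successive sample grids
        # are directly comparable across training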
        sample_z = np.random.uniform(-1, 1, size=(self.sample_size, self.z_dim))
        sample_files = data[0:self.sample_size]
        sample = [get_image(sample_file, self.image_size, need_transform=True) for sample_file in sample_files]
        sample_images = np.array(sample).astype(np.float32)

        counter = 1
        start_time = time.time()
        batch_idxs = min(len(data), self.conf.train_size) // self.conf.batch_size

        logger.debug('Starting updating')
        for epoch in range(self.conf.epoch):
            stego_losses, fake_real_losses, generator_losses = [], [], []

            logger.info('Starting epoch %s' % epoch)

            for idx in range(batch_idxs):
                batch_files = data[idx * self.conf.batch_size:(idx + 1) * self.conf.batch_size]
                batch = [get_image(batch_file, self.image_size, need_transform=True) for batch_file in batch_files]
                batch_images = np.array(batch).astype(np.float32)

                batch_z = np.random.uniform(-1, 1, [self.conf.batch_size, self.z_dim]).astype(np.float32)

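                # one step for each discriminator head: real-vs-fake,
                # then stego-vs-non-stego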
                self.sess.run(d_fr_optim, feed_dict={self.images: batch_images, self.z: batch_z})
                self.sess.run(d_s_n_optim, feed_dict={self.images: batch_images, self.z: batch_z})

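                # the generator is stepped twice per discriminator step, a
                # common heuristic to keep the discriminators from winning early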
                self.sess.run(g_optim_fake, feed_dict={self.z: batch_z})
                self.sess.run(g_optim_fake, feed_dict={self.z: batch_z})

                # a second generator update driven by the stego discriminator
                # (g_optim_stego, commented out above) and per-loss diagnostics
                # were disabled here

                logger.debug("[ITERATION] Epoch [%2d], iteration [%4d/%4d] time: %4.4f" %
                             (epoch, idx, batch_idxs, time.time() - start_time))

                counter += 1

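                # checkpoint every 1000 steps, sample grid every 300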
                if counter % 1000 == 0:
                    self.save(self.conf.checkpoint_dir, counter)

                if counter % 300 == 0:
                    logger.info('Save samples')
                    samples, d_loss, g_loss = self.sess.run(
                        [self.sampler, self.d_fr_loss, self.g_loss_fake],
                        feed_dict={self.z: sample_z, self.images: sample_images}
                    )
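                    # d_loss / g_loss were computed but never used; report them
                    # alongside the samples (the [8, 8] grid below assumes
                    # sample_size == 64)
                    logger.info('[SAMPLE] d_fr_loss: %.8f, g_loss_fake: %.8f' % (d_loss, g_loss))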
                    save_images_to_one(samples, [8, 8], './samples/train_%s_%s.png' % (epoch, idx))