def get_ssim(gold, samples):
    """Score every generated sample against every gold image with SSIM.

    All ``len(samples) * gold.shape[0]`` pairs are scored with ``ssim``
    while a ``Progbar`` tracks progress. The mean and standard deviation over
    all pairwise scores are printed and also returned (the function used to
    return ``None``, so returning values is backward-compatible).

    Args:
        gold: Reference images indexed along axis 0; ``gold.shape[0]`` is the
            number of images (presumably a numpy array — TODO confirm).
        samples: Generated images; converted with ``np.array`` and indexed
            along axis 0.

    Returns:
        tuple: ``(mean, std)`` of the pairwise SSIM score matrix.
    """
    samples = np.array(samples)
    n_samples, n_gold = samples.shape[0], gold.shape[0]
    ssims = np.zeros((n_samples, n_gold))
    pbar = Progbar(target=n_samples * n_gold)
    count = 0
    for i in range(n_samples):
        for j in range(n_gold):
            # NOTE(review): `multichannel=True` is deprecated in newer
            # scikit-image in favour of `channel_axis=-1`; kept as-is for the
            # currently installed version — confirm before upgrading.
            ssims[i, j] = ssim(samples[i], gold[j], multichannel=True)
            count += 1
            pbar.update(count)
    # All rows have equal length, so the mean of column means equals the mean
    # over the whole matrix; compute it directly.
    ssim_mean = np.mean(ssims)
    ssim_std = np.std(ssims)
    print("SSIM Mean:", ssim_mean)
    print("SSIM Std:", ssim_std)
    return ssim_mean, ssim_std
# Example #2
    def train(self,
              train_dataset,
              valid_dataset,
              num_epochs=10,
              batch_size=1000,
              eta=1e-3,
              augmentation=False,
              saving=False,
              verbose=False):
        """Fit the model, halving the learning rate before every epoch after the first.

        At the end of each epoch the loss on ``valid_dataset`` is measured with
        ``self.get_loss``. Whenever it is less than or equal to the best value
        seen so far, the model is written with ``self.save`` to
        ``saved_models/<ClassName>/<ClassName>`` and the printed validation
        line is tagged with ``*``.

        Args:
            train_dataset: Provides ``size()`` and
                ``get_batches(batch_size, augmentation)`` yielding ``(X, y)`` pairs.
            valid_dataset: Passed to ``self.get_loss`` and ``self.get_stats_table``.
            num_epochs: Number of epochs to run.
            batch_size: Minibatch size for training and for evaluation.
            eta: Learning rate fed during epoch 1; halved at the start of each
                later epoch.
            augmentation: Forwarded to ``train_dataset.get_batches``.
            saving: Accepted for API compatibility but never read; a checkpoint
                is written whenever validation loss does not get worse.
            verbose: When true, print ``self.get_stats_table`` after each epoch.
        """
        print('Training Model...')
        checkpoint = "saved_models/{0}/{0}".format(type(self).__name__)
        lowest_loss = float('inf')
        lr = eta
        for epoch in range(1, num_epochs + 1):
            if epoch != 1:
                lr /= 2
            steps = int(np.ceil(train_dataset.size() / batch_size))
            bar = Progbar(target=steps)

            print('Epoch #{0} out of {1}: '.format(epoch, num_epochs))
            stream = train_dataset.get_batches(batch_size, augmentation)
            for step, (features, targets) in enumerate(stream, start=1):
                fetched_loss, _ = self.session.run(
                    (self.model_loss, self.train_op),
                    {
                        self.X_placeholder: features,
                        self.y_placeholder: targets,
                        self.learning_rate: lr,
                        self.is_training: True,
                    })
                bar.update(step, [('Train Loss', fetched_loss)])

            epoch_valid_loss = self.get_loss(valid_dataset, batch_size)
            is_best = epoch_valid_loss <= lowest_loss
            if is_best:
                lowest_loss = epoch_valid_loss
                self.save(checkpoint)
            print('Validation Loss: {0:.4f} {1}'.format(
                epoch_valid_loss, "*" if is_best else ""))
            if verbose:
                print(self.get_stats_table(valid_dataset, batch_size))
        print('Done Training.')
# Example #3
def main_tl(api_key, params, FLAGS, start_page=1):
    """Download poster images for one genre, page by page.

    ``params`` is updated in place with ``with_genres`` and a
    ``primary_release_date.gte`` bound. Pages ``start_page`` through
    ``min(total_pages, FLAGS.max_pages)`` are then requested with
    ``make_request``. For every movie whose poster ``get_poster(..., 'w92')``
    returns, the image is saved to
    ``FLAGS.image_dir + "/" + <genre name> + movie['poster_path']``.

    Args:
        api_key: Key passed to ``get_genres`` and ``make_request``.
        params: Request parameter dict; mutated in place.
        FLAGS: Object exposing ``genre``, ``max_pages`` and ``image_dir``.
        start_page: First results page to fetch.

    Returns:
        int: Number of poster images written.
    """
    params['with_genres'] = FLAGS.genre
    params['primary_release_date.gte'] = '1980'
    # Use the caller's key; the previous code read a module-level API_KEY.
    genre_list, genre_map = get_genres(api_key)
    total_pages = min(make_request(api_key, params)['total_pages'],
                      FLAGS.max_pages)
    # The genre name depends only on FLAGS.genre, so resolve it once.
    genre = genre_list[genre_map[FLAGS.genre]]
    pbar = Progbar(target=total_pages)
    num_images = 0
    for page in range(start_page, total_pages + 1):
        results = make_request(api_key, params, page)
        for movie in results['results']:
            poster_path = movie.get('poster_path')
            if not poster_path:
                # No poster path: nothing to download and no file name to use.
                continue
            poster = get_poster(poster_path, 'w92')
            if poster is None:
                continue
            filename = FLAGS.image_dir + "/" + genre + poster_path
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            Image.fromarray(poster).save(filename)
            num_images += 1
        pbar.update(page)
    return num_images
# Example #4
 def output(self, sess, inputs_raw):
     """
     Reports the output of the model on examples (uses helper to featurize each example).

     Runs ``self.predict_on_batch`` over unshuffled minibatches of
     ``inputs_raw``. Each batch summary is written to ``self.val_writer``,
     using the batch index as the step, and the mean batch loss is logged.
     Returns ``self.consolidate_predictions(inputs_raw, inputs, preds)``.
     """
     batch_size = self.config.batch_size
     inputs = []
     preds = []
     val_loss = []
     # One progress step per minibatch, i.e. ceil(len / batch_size) (assumes
     # `minibatches` yields that many batches — TODO confirm). The former
     # `1 + int(len / batch_size)` overshot by one when len was a multiple of
     # the batch size, so the bar never reached its target.
     prog = Progbar(target=int(np.ceil(len(inputs_raw) / batch_size)))
     for i, batch in enumerate(
             minibatches(inputs_raw, batch_size, shuffle=False)):
         inputs_, labels = self.preprocess_speech_data(batch)
         preds_, _loss, _summary = self.predict_on_batch(
             sess, inputs_, labels)
         preds.extend(preds_)
         inputs.extend(inputs_)
         prog.update(i + 1, [])
         self.val_writer.add_summary(_summary, i)
         val_loss.append(_loss)
     logger.info("Mean Val loss  = %.2f ", np.mean(val_loss))
     return self.consolidate_predictions(inputs_raw, inputs, preds)
# Example #5
    def fit(self, sess, saver, train_examples_raw, dev_set_raw):
        """Train for ``self.config.n_epochs`` epochs, evaluating after each one.

        Each epoch iterates ``minibatches(train_examples_raw, batch_size)``,
        calls ``self.train_on_batch`` (which returns ``(loss, summary)``) and
        writes each summary to ``self.train_writer`` with a global step counter.
        It then logs the mean epoch loss and evaluates on ``dev_set_raw`` via
        ``self.evaluate``, logging the returned confusion matrix at DEBUG level.

        Args:
            sess: Session passed to ``train_on_batch`` and ``evaluate``.
            saver: Currently unused (see TODO below).
            train_examples_raw: Raw training examples.
            dev_set_raw: Raw development examples.

        Returns:
            float: ``best_score``, which is never updated and is therefore
            always ``0.``.
        """
        best_score = 0.
        step = 0
        batch_size = self.config.batch_size
        n_epochs = self.config.n_epochs

        for epoch in range(n_epochs):
            logger.info("Epoch %d out of %d", epoch + 1, n_epochs)
            # ceil(len / batch_size): the former `1 + int(len / batch_size)`
            # overshot by one when len was a multiple of the batch size.
            n_minibatches = int(np.ceil(len(train_examples_raw) / batch_size))
            prog = Progbar(target=n_minibatches)
            epoch_loss = []
            for i, batch in enumerate(minibatches(train_examples_raw, batch_size)):
                inputs, labels = self.preprocess_speech_data(batch)
                loss, _summary = self.train_on_batch(sess, inputs, labels)
                prog.update(i + 1, [("loss", loss)])
                self.train_writer.add_summary(_summary, step)
                epoch_loss.append(loss)
                step += 1

            # 1-based, matching the "Epoch %d out of %d" line above.
            logger.info("Epoch loss after epoch %d = %.2f", epoch + 1,
                        np.mean(epoch_loss))

            logger.info("Evaluating on development data")
            language_cm = self.evaluate(sess, dev_set_raw)
            logger.debug("Lang-level confusion matrix:\n%s",
                         language_cm.as_table())
            logger.debug("Lang-level scores:\n%s", language_cm.summary())
            # TODO: model selection is not implemented. Derive a score from
            # `language_cm`, update `best_score`, and save via `saver` to
            # self.config.model_output when it improves.
            print("")
        return best_score
# Example #6
	def train(self, X, num_epochs=5):
		"""Train on the batches built from X, then print a generated sample.

		Args:
			X: Raw data passed to ``self.get_batches``. The result is iterated
				and ``len()``-ed every epoch, so it must be a sized, re-iterable
				collection of ``(x, y)`` pairs.
			num_epochs: Number of passes over the batches.
		"""
		train_batches = self.get_batches(X)

		for epoch in range(1, num_epochs + 1):
			progbar = Progbar(target=len(train_batches))

			print('Epoch #{0} out of {1}: '.format(epoch, num_epochs))
			train_loss = None
			for batch, (x, y) in enumerate(train_batches):
				train_loss, _ = self.session.run([self.loss, self.train_op], {self.x: x, self.y: y})

				progbar.update(batch + 1, [('Train Loss', train_loss)])
			# Report the last batch's loss. With no batches there is nothing to
			# report; previously this raised UnboundLocalError.
			if train_loss is not None:
				print('Training Loss: {0:.4f} {1}'.format(train_loss, '*'))

		print("All Epochs complete. Sampling.")
		joke = self.sample(500)
		print(joke)

		print("Done training.")
# Example #7
    def fit(self,
            sess,
            name,
            num_epochs=5,
            checkpoint_directory=None,
            show_every=1,
            print_every=1,
            load=False,
            save_every=10):
        """Train the GAN, or restore it from a checkpoint instead.

        If ``checkpoint_directory`` is given, the model is restored with
        ``self.load_model`` and no training happens. With ``load=True`` the
        method returns right after restoring. Otherwise it also generates one
        batch from random noise, shows the first three images and saves the
        batch with ``self.save_images``.

        Without a checkpoint, variables are initialised and summaries are
        written to ``./logs/<name>``. Each epoch runs one discriminator step and
        one generator step per minibatch of ``GenreDataset(self.genre,
        self.batch_size)``, skipping the final batch as before. Each step uses
        freshly sampled noise.

        Args:
            sess: Session used for every ``run`` call; stored as ``self.sess``.
            name: Run name used in the log directory and checkpoint path.
            num_epochs: Number of training epochs.
            checkpoint_directory: If not ``None``, restore instead of training.
            show_every: Show generated and real images every N epochs.
            print_every: Print the last D/G losses every N epochs.
            load: With ``checkpoint_directory``, only restore (no sampling).
            save_every: Save a checkpoint every N epochs (default 10, the
                previously hard-coded value).
        """
        self.name = name
        self.total_g_sum = tf.summary.merge(
            [self.z_summary, self.G_summary, self.G_loss_summary])
        self.total_d_sum = tf.summary.merge(
            [self.z_summary, self.D_loss_summary])
        self.sess = sess

        def sample_noise():
            # Uniform noise in [-1, 1), shape [batch_size, z_dims].
            return np.random.uniform(
                -1, 1, [self.batch_size, self.z_dims]).astype(np.float32)

        if checkpoint_directory is not None:
            self.load_model(checkpoint_directory)
            if load:
                return
            samples = sess.run(self.G_sample, {self.z: sample_noise()})
            self.show_images(samples[0:3])
            plt.show()
            print()
            self.save_images(samples)
            return

        log_file = "./logs/" + name
        os.makedirs(os.path.dirname(log_file), exist_ok=True)
        self.writer = tf.summary.FileWriter(log_file, sess.graph)
        sess.run(tf.global_variables_initializer())

        counter = 1
        # Fixed interpolated noise, used only for visualisation so that images
        # from different epochs are comparable.
        z_fixed = self.linear_interpolation(sample_noise(), sample_noise())
        for epoch in range(0, num_epochs):
            batches = GenreDataset(self.genre, self.batch_size)
            # The loop deliberately skips the last (possibly partial) batch.
            # Size the progress bar to the steps that actually run; the old
            # target of num_batches() + 1 was never reached.
            train_steps = max(batches.num_batches() - 1, 0)
            progbar = Progbar(target=train_steps)
            print('Epoch #{0} out of {1}: '.format(epoch, num_epochs))
            if epoch % show_every == 0:
                samples = sess.run(self.G_sample, {self.z: z_fixed})
                self.show_images(samples[0:3], True)
                plt.show()
                print()

                ex = batches.get_batch(0)
                idx = np.random.randint(0, ex.shape[0])
                self.show_images(ex[idx:idx + 3], True)
                plt.show()
                print()

            # Defined up front so the epoch report below works with zero steps.
            D_loss_curr = G_loss_curr = float('nan')
            for batch_idx in range(train_steps):
                minibatch = batches.get_batch(batch_idx)
                # Fresh noise each step. Training on one fixed z (as before)
                # only teaches the generator those particular latent points.
                z_batch = sample_noise()

                _, D_loss_curr, d_summary = sess.run(
                    [self.D_train_op, self.D_loss, self.total_d_sum], {
                        self.images: minibatch,
                        self.z: z_batch
                    })

                _, G_loss_curr, g_summary = sess.run(
                    [self.G_train_op, self.G_loss, self.total_g_sum],
                    {self.z: z_batch})

                self.writer.add_summary(d_summary, counter)
                self.writer.add_summary(g_summary, counter)
                counter += 1
                progbar.update(batch_idx + 1, [('D Loss', D_loss_curr),
                                               ('G Loss', G_loss_curr)])

            if epoch % print_every == 0:
                print('Epoch: {}, D: {:.4}, G:{:.4}'.format(
                    epoch, D_loss_curr, G_loss_curr))

            if epoch % save_every == 0:
                save_file = 'checkpoint/' + self.genre + '/' + name + '/'
                os.makedirs(os.path.dirname(save_file), exist_ok=True)
                self.saver.save(sess, save_file, global_step=epoch)

        print('Final images')
        samples = sess.run(self.G_sample, {self.z: z_fixed})

        self.show_images(samples[:5])
        plt.show()
        print()
        self.save_images(samples)
# Example #8
    # The main train loop
    def fit(self, train_dataset, train_label_dataset, num_epochs=10):
        print('Training Model...')
        best_valid_loss = float('inf')
        best_valid_accuracy = float('-inf')
        batches, X_raw = batch_data_nn('combined_data.pickle', 'labels.pickle', self.batch_size)
        train_batches, (X_valid, X_char_valid,  X_mask_valid, y_valid, features_valid), (X_test, X_char_test, X_mask_test, y_test, features_test), X_raw = split_batches(batches,X_raw, 0.7, 0.2)
=======
# >>>>>>> 432e3591d40b7c1b5ad82ece875213b9bdb9de0c
#
#         train_batches, (X_valid, X_char_valid,  X_mask_valid, y_valid, features_valid), (X_test, X_char_test, X_mask_test, y_test, features_test), X_raw_test = split_batches(batches, X_raw, 0.7, 0.2)
        print(X_raw_test)
        print("Starting epochs.")
        counter = 0
        for epoch in range(1, num_epochs+1):
            progbar = Progbar(target = len(train_batches)) 

            print('Epoch #{0} out of {1}: '.format(epoch, num_epochs))
            for batch, (X_batch, X_char_batch, mask_batch, y_batch, features) in enumerate(train_batches):
                train_loss, _, train_metrics, loss_summary, metrics_summary = self.session.run((self.loss, self.train_op, self.metrics, self.loss_summary, self.metrics_summary), {
                    self.X_placeholder : X_batch, 
                    self.y_placeholder: y_batch, 
                    self.X_char_placeholder : X_char_batch, 
                    self.features: features,
                    self.X_mask_placeholder: mask_batch, 
                    self.is_training : True,
                })
                self.writer.add_summary(loss_summary, counter)
                self.writer.add_summary(metrics_summary, counter)
                progbar.update(batch+1, [('Train Loss', train_loss), ('Accuracy', train_metrics)])
                counter += 1