Example #1
def deprocess_coco(processed_dir, processed_name, targets_name, lim):
    Xf = os.path.join(processed_dir, processed_name)
    X = np.load(Xf)
    yf = os.path.join(processed_dir, targets_name)
    y = np.load(yf)
    print('X.shape:%s' % str(X.shape))
    print('y.shape:%s' % str(y.shape))
    H_W = int(math.sqrt(y.shape[1]))
    y = y.reshape(
        (-1, H_W, H_W)
    )  # note this assumes that targets (small, greyscaled images) are square in shape
    print('y.shape:%s' % str(y.shape))
    if lim is not None and lim > 0:  # `lim` caps how many images are displayed
        X = X[0:lim]
    for i, proc_img in enumerate(X):
        fig, ax = plt.subplots(1, 3)
        ax[0].imshow((channel_first2last(proc_img)))
        deproc_img = deprocess_image(proc_img)
        ax[1].imshow(deproc_img)
        target = y[i]
        ax[2].imshow(target, cmap='Greys_r')
        plt.show()
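
A minimal usage sketch for deprocess_coco; the directory and file names are hypothetical, pointing at the kind of .npy files written by preprocess_coco in Example #4:

# Hypothetical paths: point these at .npy arrays produced by preprocess_coco.
deprocess_coco(processed_dir='data/coco/processed',
               processed_name='X_train_0.npy',
               targets_name='y_train_0.npy',
               lim=5)  # display only the first 5 images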
Example #2
    def predict(self, **kwargs):
        """Use a trained model for prediction."""
        # unpack args:
        X = kwargs.get('X', None)
        RESTORE_PATH = kwargs.get('restore_path', None)
        if RESTORE_PATH is None:
            input(
                'Error: no RESTORE_PATH has been specified. A randomly initialized model should not be used for prediction.'
                '\nPress enter to continue anyway:')

        # Tell TensorFlow that the model will be built into the default Graph.
        with tf.Graph().as_default():
            # Generate placeholders for the input images and the gold-standard class labels
            inputs_pl = tf.placeholder(tf.float32, [
                None, self.dim_img_int, self.dim_img_int, self.nb_channels_int
            ])

            # Build a Graph that computes encodings
            preds = self._encode(inputs_pl)

            # Create a saver for restoring the model from the latest/best training checkpoints.
            saver = tf.train.Saver()

            # Create a session for running Ops on the Graph.
            sess = tf.Session()

            # Restore the model from the checkpoint file if RESTORE_PATH was
            # given; otherwise fall back to randomly initialized weights.
            if RESTORE_PATH is not None:
                saver.restore(sess, RESTORE_PATH)
                print('model restored from checkpoint file: %s' %
                      str(RESTORE_PATH))
            else:
                print(
                    'No RESTORE_PATH specified, so initializing the model with random weights'
                )
                # Add the variable initializer Op.
                init = tf.initialize_all_variables()
                # Run the Op to initialize the variables.
                sess.run(init)

            # Do Prediction:
            start_time = time.time()

            # Fill a feed dictionary with the actual set of images and labels
            # for this particular training step.
            feed_dict = {inputs_pl: X}
            nb_samples = X.shape[0]

            pred_vals = sess.run(
                [preds], feed_dict=feed_dict
            )[0]  # [0] since sess.run([preds]) returns a list of len 1 in this case

            duration = time.time() - start_time

            # Print status to stdout.
            print('Prediction took: %f for %d samples  --> %f per sample' %
                  (duration, nb_samples, (duration / nb_samples)))

            # Debug output: inspect the raw predictions.
            print('pred_vals type: %s' % str(type(pred_vals)))
            print('pred_vals.shape: %s' % str(pred_vals.shape))
            print(pred_vals)

            for samp_num in range(X.shape[0]):
                img = X[samp_num]
                scores = pred_vals[samp_num]
                probs = self.softmax(scores)
                class_pred_idx = np.argmax(probs)
                class_pred = idx2label[class_pred_idx]
                plt.imshow(deprocess_image(img))
                txt = 'predicted class dist: %s\n' \
                      'predicted class: %s\n' \
                      % (str(probs), class_pred)
                plt.text(0, 0, txt, color='b', fontsize=15, fontweight='bold')
                plt.show()
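
A usage sketch for this predict method, assuming model is an instance of the surrounding class; the checkpoint path is hypothetical and should point at a file written by tf.train.Saver.save during training:

# Hypothetical call: X_test is a batch of preprocessed images with shape
# (N, dim_img_int, dim_img_int, nb_channels_int).
model.predict(X=X_test, restore_path='checkpoints/model_checkpoint-42')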
Example #3
    def train(self, **kwargs):
        """Train model for a number of steps."""
        # unpack args:
        data_train = kwargs.get('data_train', None)
        if data_train is None:
            raise ValueError(
                'data_train cannot be None. At least training data must be supplied in order to train the model.'
            )
        data_val = kwargs.get('data_val', None)
        if data_val is None:
            print('Warning: no val data has been supplied.')
        batch_size_int = self.batch_size_int
        save_summaries_every = kwargs.get('save_summaries_every', 500)
        display_every = kwargs.get('display_every', 1)
        display = kwargs.get('display', False)
        nb_to_display = kwargs.get('nb_to_display', 5)
        nb_epochs = kwargs.get('nb_epochs', 100)
        save_best_only = kwargs.get('save_best_only', 'save_all')
        lr = kwargs.get('lr', 0.001)
        l2 = kwargs.get('l2', 0.0001)
        SAVE_PATH = kwargs.get('save_path', None)
        if SAVE_PATH is None:
            print('Warning: no SAVE_PATH has been specified.')
        WEIGHTS_PATH = kwargs.get('weights_path', None)
        RESTORE_PATH = kwargs.get('restore_path', None)
        load_encoder = kwargs.get('load_encoder', True)

        # ensure batch_size is set appropriately:
        if batch_size_int is None:
            raise ValueError(
                'batch_size must be specified in model instantiation.')

        # Tell TensorFlow that the model will be built into the default Graph.
        with tf.Graph().as_default():
            # Generate placeholders for the input images and the gold-standard labels
            inputs_pl = tf.placeholder(tf.float32, [
                None, self.dim_img_int, self.dim_img_int, self.nb_channels_int
            ])
            targets_pl = tf.placeholder(tf.float32, [None, self.nb_classes])

            # Create a variable to track the global step.
            global_step = tf.Variable(0, name='global_step', trainable=False)

            # Build a Graph that computes predictions
            preds = self._encode(inputs_pl)

            # Add to the Graph the Ops for loss calculation.
            loss = self._loss(predictions=preds, targets=targets_pl)

            # Add to the Graph the Ops that calculate and apply gradients.
            train_op = self._training(total_loss=loss,
                                      lr=lr,
                                      global_step=global_step)

            # Add the Op to compute the avg cross-entropy loss for evaluation purposes.
            eval_correct = self._evaluate(predictions=preds,
                                          targets=targets_pl)

            # Build the summary Tensor based on the TF collection of Summaries.
            summary = tf.merge_all_summaries()

            # Create a saver for writing training checkpoints.
            saver = tf.train.Saver(max_to_keep=10)

            # Create a session for running Ops on the Graph.
            sess = tf.Session()

            # Instantiate a SummaryWriter to output summaries and the Graph.
            if SAVE_PATH is not None:
                summary_writer = tf.train.SummaryWriter(SAVE_PATH, sess.graph)
            else:
                print(
                    'WARNING: SAVE_PATH is not specified...cannot save model file'
                )

            if RESTORE_PATH not in {None, ''}:
                saver.restore(sess, RESTORE_PATH)
                print('model restored from checkpoint file: %s' %
                      str(RESTORE_PATH))
            else:
                print(
                    'No RESTORE_PATH specified, so initializing the model with random weights for training...'
                )
                # Add the variable initializer Op.
                init = tf.initialize_all_variables()
                # Run the Op to initialize the variables.
                sess.run(init)

            # load pretrained weights if desired:
            if WEIGHTS_PATH not in {None, ''} and sess is not None:
                self.load_weights(WEIGHTS_PATH,
                                  sess,
                                  load_encoder=load_encoder)

            # Start the training loop: train for nb_epochs, where each epoch iterates over the entire training set once.
            # history saves train_acc and val_acc after each epoch's evaluation
            history = []
            train_acc_best = 0.
            val_acc_best = 0.
            batch_tot = 0
            # START ALL EPOCHS
            for epoch_num in range(nb_epochs):
                # START ONE EPOCH
                print('starting epoch %d / %d' % (epoch_num + 1, nb_epochs))
                nb_batches_per_epoch = (data_train.nb_samples //
                                        batch_size_int) + 1
                batch_num = 0
                end_of_epoch = False
                epoch_tot_loss = 0.0
                while not end_of_epoch:
                    # iterate over all the training data once, batch by batch:
                    batch_start_time = time.time()
                    batch_num += 1
                    batch_tot += 1
                    # Fill a feed dictionary with the actual set of images and labels
                    # for this particular training step.
                    next_batch = data_train.fill_feed_dict(
                        inputs_pl, targets_pl, batch_size_int)

                    # Run one step of the model.  The return values are the activations
                    # from the `train_op` (which is discarded) and the `loss` Op.  To
                    # inspect the values of your Ops or variables, you may include them
                    # in the list passed to sess.run() and the value tensors will be
                    # returned in the tuple from the call.
                    _, loss_value, pred_vals = sess.run(
                        [train_op, loss, preds], feed_dict=next_batch)
                    batch_duration = time.time() - batch_start_time
                    assert not np.isnan(
                        loss_value), 'Model diverged with loss = NaN'

                    epoch_tot_loss += loss_value
                    epoch_avg_loss = epoch_tot_loss / batch_num
                    # Print status to stdout.
                    print(
                        '\tbatch_num %d / %d : batch_loss = %.3f \tepoch_avg_loss = %.3f \t(%.3f sec)'
                        % (batch_num, nb_batches_per_epoch, loss_value,
                           epoch_avg_loss, batch_duration))

                    # Write the summaries and print an overview fairly often.
                    if batch_num % save_summaries_every == 0:
                        # Print status to stdout.
                        # Update the events file.
                        summary_str = sess.run(summary, feed_dict=next_batch)
                        if SAVE_PATH is not None:
                            summary_writer.add_summary(summary_str,
                                                       global_step=batch_tot)
                            summary_writer.flush()
                    end_of_epoch = data_train.end_of_epoch()
                    if end_of_epoch and epoch_num % display_every == 0:
                        print('pred_vals')
                        pred_vals = pred_vals[0:nb_to_display]
                        print(pred_vals.shape)
                        print(pred_vals)
                        cur_batch = data_train.get_cur_batch()
                        if cur_batch['X'].shape[0] < nb_to_display:
                            nb_to_display = cur_batch['X'].shape[0]
                        for samp_num in range(nb_to_display):
                            img = cur_batch['X'][samp_num]
                            class_true_idx = np.argmax(
                                cur_batch['y'][samp_num])
                            class_true = idx2label[class_true_idx]
                            scores = pred_vals[samp_num]
                            probs = self.softmax(scores)
                            class_pred_idx = np.argmax(probs)
                            class_pred = idx2label[class_pred_idx]
                            txt = '\tpredicted class dist: %s\n' \
                                  '\tpredicted class: \t%s\n' \
                                  '\ttrue class: \t\t%s' % (str(probs), class_pred, class_true)
                            print(txt)
                            if display:
                                fig, ax = plt.subplots(figsize=(10, 10),
                                                       nrows=1,
                                                       ncols=1)
                                ax.imshow(deprocess_image(img))
                                ax.text(0,
                                        0,
                                        txt,
                                        color='r',
                                        fontsize=15,
                                        fontweight='bold')
                                plt.show()
                    # END ONE EPOCH

                # After epoch:
                #   (1) evaluate the model
                #   (2) save a checkpoint (possibly only if the val accuracy has improved)

                # (1a) Evaluate against the training set.
                print('Training Data Eval:')
                new_best_train = False
                train_acc = self.evaluate(sess,
                                          eval_correct,
                                          inputs_pl,
                                          targets_pl,
                                          data_train,
                                          batch_size_int,
                                          lim=100)
                if train_acc > train_acc_best:
                    train_acc_best = train_acc
                    new_best_train = True
                # (1b) Evaluate against the validation set if val data was provided.
                if data_val is not None:
                    print('Validation Data Eval:')
                    val_acc = self.evaluate(sess,
                                            eval_correct,
                                            inputs_pl,
                                            targets_pl,
                                            data_val,
                                            batch_size_int,
                                            lim=100)
                    new_best_val = False
                    if val_acc > val_acc_best:
                        val_acc_best = val_acc
                        new_best_val = True
                    history.append({
                        'train_acc': train_acc,
                        'val_acc': val_acc
                    })
                else:
                    history.append({'train_acc': train_acc})

                # (2) save checkpoint file
                if save_best_only == 'save_best_train':
                    if new_best_train:
                        checkpoint_file = os.path.join(
                            SAVE_PATH, '%s_checkpoint' % self.name)
                        print(
                            'new_best_train_acc: %f \tsaving checkpoint to file: %s'
                            % (train_acc_best, str(checkpoint_file)))
                        saver.save(sess,
                                   checkpoint_file,
                                   global_step=epoch_num)
                elif save_best_only == 'save_best_val' and data_val is not None:
                    if new_best_val:
                        checkpoint_file = os.path.join(
                            SAVE_PATH, '%s_checkpoint' % self.name)
                        print(
                            'new_best_val_acc: %f \tsaving checkpoint to file: %s'
                            % (val_acc_best, str(checkpoint_file)))
                        saver.save(sess,
                                   checkpoint_file,
                                   global_step=epoch_num)
                else:
                    checkpoint_file = os.path.join(SAVE_PATH,
                                                   '%s_checkpoint' % self.name)
                    if data_val is not None:
                        print(
                            'train_acc: %f \t val_acc: %f \tsaving checkpoint to file: %s'
                            % (train_acc, val_acc, str(checkpoint_file)))
                    else:
                        print('train_acc: %f \tsaving checkpoint to file: %s' %
                              (train_acc, str(checkpoint_file)))
                    saver.save(sess, checkpoint_file, global_step=epoch_num)
            # END ALL EPOCHS
        return history, train_acc_best, val_acc_best
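
A usage sketch for this epoch-based training loop; model, data_train, and data_val are hypothetical, and the data objects are assumed to expose nb_samples, fill_feed_dict(), get_cur_batch(), and end_of_epoch() as used above:

# Hypothetical call: checkpoints and TF summaries are written under save_path.
history, train_acc_best, val_acc_best = model.train(
    data_train=data_train,
    data_val=data_val,
    nb_epochs=50,
    lr=0.001,
    save_best_only='save_best_val',  # checkpoint only when val accuracy improves
    save_path='checkpoints/')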
Example #4
def preprocess_coco(raw_dir,
                    save_dir,
                    proc_shape,
                    target_H,
                    train_ratio,
                    val_ratio,
                    test_ratio,
                    random_split=False,
                    random_seed=13,
                    lim=-1,
                    chunk_size=5000,
                    targets3d=False):
    start_time = time.time()
    raw_files = [
        os.path.join(raw_dir, f) for f in os.listdir(raw_dir)
        if (os.path.isfile(os.path.join(raw_dir, f)) and f.lower().endswith((
            '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.gif', '.bmp')))
    ]
    if lim is None:
        lim = -1
    if lim > 0:
        raw_files = raw_files[0:lim]
    # if there are more than 'chunk_size' raw files, we split the data into chunks of chunk_size to reduce the memory footprint
    # and storage space needed to save each training/val/test file on disk
    raw_files_l = list()
    for start in range(0, len(raw_files), chunk_size):
        end = start + chunk_size
        raw_files_l.append(raw_files[start:end])
    for chunk_num, raw_files in enumerate(raw_files_l):
        print('chunk:%d\tpreprocessing %d raw files' %
              (chunk_num, len(raw_files)))
        processed = list()
        targets = list()
        for r, raw_file in enumerate(raw_files):
            if r % 10 == 0:
                print('chunk:%d\tpreprocessing file %d' % (chunk_num, r))
            raw_img = imread(raw_file)
            img = preprocess_image(raw_img, proc_shape=proc_shape)
            if img is None:
                continue
            deproc = deprocess_image(img)
            deproc_ratio = deproc.shape[1] / deproc.shape[0]  # ratio of W:H
            if targets3d:
                targ = transform.resize(deproc, output_shape=(3, 112, 112))
                targets.append(targ)
            else:
                small_shape = (target_H, int(target_H * deproc_ratio))
                small_gray = downsample_image(deproc,
                                              new_size=small_shape,
                                              grey=True)
                targets.append(small_gray.flatten())
            processed.append(img)
        processed = np.asarray(processed)
        targets = np.asarray(targets)
        print('chunk:%d\tprocessed.shape:%s' %
              (chunk_num, str(processed.shape)))
        print('chunk:%d\ttargets.shape:%s' % (chunk_num, str(targets.shape)))
        X_train, X_val, X_test = split_data(processed,
                                            train_ratio=train_ratio,
                                            val_ratio=val_ratio,
                                            test_ratio=test_ratio,
                                            random=random_split,
                                            random_seed=random_seed)
        y_train, y_val, y_test = split_data(targets,
                                            train_ratio=train_ratio,
                                            val_ratio=val_ratio,
                                            test_ratio=test_ratio,
                                            random=random_split,
                                            random_seed=random_seed)
        print('chunk:%d\tX_train_%d.shape%s' %
              (chunk_num, chunk_num, str(X_train.shape)))
        print('chunk:%d\ty_train_%d.shape%s' %
              (chunk_num, chunk_num, str(y_train.shape)))
        if X_val is not None:
            print('chunk:%d\tX_val_%d.shape%s' %
                  (chunk_num, chunk_num, str(X_val.shape)))
        if y_val is not None:
            print('chunk:%d\ty_val_%d.shape%s' %
                  (chunk_num, chunk_num, str(y_val.shape)))
        if X_test is not None:
            print('chunk:%d\tX_test_%d.shape%s' %
                  (chunk_num, chunk_num, str(X_test.shape)))
        if y_test is not None:
            print('chunk:%d\ty_test_%d.shape%s' %
                  (chunk_num, chunk_num, str(y_test.shape)))

        save_data(X_train,
                  save_dir,
                  'X_train_%d' % chunk_num,
                  save_format='npy')
        if X_val is not None:
            save_data(X_val,
                      save_dir,
                      'X_val_%d' % chunk_num,
                      save_format='npy')
        if X_test is not None:
            save_data(X_test,
                      save_dir,
                      'X_test_%d' % chunk_num,
                      save_format='npy')
        save_data(y_train,
                  save_dir,
                  'y_train_%d' % chunk_num,
                  save_format='npy')
        if y_val is not None:
            save_data(y_val,
                      save_dir,
                      'y_val_%d' % chunk_num,
                      save_format='npy')
        if y_test is not None:
            save_data(y_test,
                      save_dir,
                      'y_test_%d' % chunk_num,
                      save_format='npy')
        del raw_files
        del processed
        del targets
    end_time = time.time()
    duration = (end_time - start_time) / 60.
    print('total duration: %f mins' % (duration))
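
A usage sketch with hypothetical directories; this writes chunked X_train_0.npy, y_train_0.npy, ... files into save_dir, which Example #1 can then load and visualize. The exact proc_shape convention depends on preprocess_image; channel-first is assumed here:

# Hypothetical call: 80/10/10 train/val/test split, chunks of 5000 images.
preprocess_coco(raw_dir='data/coco/raw',
                save_dir='data/coco/processed',
                proc_shape=(3, 224, 224),  # assumed channel-first (C, H, W)
                target_H=56,
                train_ratio=0.8,
                val_ratio=0.1,
                test_ratio=0.1,
                chunk_size=5000)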
Example #5
    def predict(self, **kwargs):
        """Use a trained model for prediction."""
        # unpack args:
        X = kwargs.get('X', None)
        LOAD_PATH = kwargs.get('load_path', None)
        if LOAD_PATH is None:
            input(
                'Error: no LOAD_PATH has been specified. A randomly initialized model should not be used for prediction.'
                '\nPress enter to continue anyway:')

        # Tell TensorFlow that the model will be built into the default Graph.
        with tf.Graph().as_default():
            # Generate placeholders for the input images and the gold-standard class labels
            inputs_pl = tf.placeholder(tf.float32, [
                None, self.dim_img_int, self.dim_img_int, self.nb_channels_int
            ])

            # Build a Graph that computes encodings
            preds = self._encode(inputs_pl)

            # Create a saver for restoring the model from the latest/best training checkpoints.
            saver = tf.train.Saver()

            # Create a session for running Ops on the Graph.
            sess = tf.Session()

            # Restore the model from the newest checkpoint under LOAD_PATH,
            # or fall back to randomly initialized weights.
            if LOAD_PATH is not None:
                checkpoint_path = tf.train.latest_checkpoint(LOAD_PATH)
                saver.restore(sess, checkpoint_path)
                print('model restored from checkpoint file: %s' %
                      str(checkpoint_path))
            else:
                print(
                    'No LOAD_PATH specified, so initializing the model with random weights'
                )
                # Add the variable initializer Op.
                init = tf.initialize_all_variables()
                # Run the Op to initialize the variables.
                sess.run(init)

            # Do Prediction:
            start_time = time.time()

            # Fill a feed dictionary with the actual set of images and labels
            # for this particular training step.
            feed_dict = {inputs_pl: X}

            pred_vals = sess.run(
                [preds], feed_dict=feed_dict
            )[0]  # [0] since sess.run([preds]) returns a list of len 1
            duration = time.time() - start_time

            # Print status to stdout.
            nb_samples = X.shape[0]
            print('Prediction took: %f for %d samples  --> %f per sample' %
                  (duration, nb_samples, (duration / nb_samples)))

            # Debug output: inspect the raw predictions.
            print('pred_vals type: %s' % str(type(pred_vals)))
            print('pred_vals.shape: %s' % str(pred_vals.shape))
            print(pred_vals)
            input('...inspect pred_vals above; press enter to continue:')

            for samp_num in range(X.shape[0]):
                img = X[samp_num]
                scores = pred_vals[samp_num]
                probs = self.softmax(scores)
                class_pred_idx = np.argmax(probs)
                class_pred = idx2label[class_pred_idx]
                fig, ax = plt.subplots(figsize=(10, 10), nrows=1, ncols=2)
                ax[0].imshow(deprocess_image(img))
                txt = 'predicted class dist: %s\n' \
                      'predicted class: %s\n' \
                      % (str(probs), class_pred)
                ax[0].text(0,
                           0,
                           txt,
                           color='b',
                           fontsize=15,
                           fontweight='bold')
                plt.show()
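
A usage sketch for this variant, which takes a checkpoint directory rather than a file and resolves the newest checkpoint via tf.train.latest_checkpoint; model and the path are hypothetical:

# Hypothetical call: load_path is a directory containing checkpoint files.
model.predict(X=X_test, load_path='checkpoints/')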
Example #6
    def train(self, **kwargs):
        """Train model for a number of steps."""
        # unpack args:
        data_train = kwargs.get('data_train', None)
        if data_train is None:
            raise ValueError(
                'data_train cannot be None. At least training data must be supplied in order to train the model.'
            )
        data_val = kwargs.get('data_val', None)
        if data_val is None:
            print('Warning: no val data has been supplied.')
        batch_size_int = kwargs.get('batch_size', None)
        epoch_size = kwargs.get('epoch_size', None)  # TODO
        nb_epochs = kwargs.get('nb_epochs', 100)
        max_iters = kwargs.get('max_iters', 5000)  # TODO
        lr = kwargs.get('lr', 0.001)
        l2 = kwargs.get('l2', 0.0001)
        SAVE_PATH = kwargs.get('save_path', None)
        if SAVE_PATH is None:
            print('Warning: no SAVE_PATH has been specified.')
        weights = kwargs.get('weights', None)
        load_encoder = kwargs.get('load_encoder', True)

        # ensure batch_size is set appropriately:
        if batch_size_int is None:
            if self.batch_size_int is None:
                raise ValueError(
                    'batch_size must be specified either in model instantiation or passed into this training method, '
                    'but batch_size cannot be None in both cases.')
            batch_size_int = self.batch_size_int

        # Tell TensorFlow that the model will be built into the default Graph.
        with tf.Graph().as_default():
            # Generate placeholders for the input images and the gold-standard labels
            inputs_pl = tf.placeholder(tf.float32, [
                None, self.dim_img_int, self.dim_img_int, self.nb_channels_int
            ])
            targets_pl = tf.placeholder(tf.float32, [None, self.nb_classes])

            # Create a variable to track the global step.
            global_step = tf.Variable(0, name='global_step', trainable=False)

            # Build a Graph that computes predictions
            preds = self._encode(inputs_pl)

            # Add to the Graph the Ops for loss calculation.
            loss = self._loss(predictions=preds, targets=targets_pl)

            # Add to the Graph the Ops that calculate and apply gradients.
            train_op = self._training(total_loss=loss,
                                      lr=lr,
                                      global_step=global_step)

            # Add the Op to compute the avg cross-entropy loss for evaluation purposes.
            eval_correct = self._evaluate(predictions=preds,
                                          targets=targets_pl)

            # Build the summary Tensor based on the TF collection of Summaries.
            summary = tf.merge_all_summaries()

            # Add the variable initializer Op.
            init = tf.initialize_all_variables()

            # Create a saver for writing training checkpoints.
            saver = tf.train.Saver()

            # Create a session for running Ops on the Graph.
            sess = tf.Session()

            # Instantiate a SummaryWriter to output summaries and the Graph.
            if SAVE_PATH is not None:
                summary_writer = tf.train.SummaryWriter(SAVE_PATH, sess.graph)
            else:
                print(
                    'WARNING: SAVE_PATH is not specified...cannot save model file'
                )

            # And then after everything is built:
            # Run the Op to initialize the variables.
            sess.run(init)

            # load pretrained weights if desired:
            if weights is not None and sess is not None:
                self.load_weights(weights, sess, load_encoder=load_encoder)

            steps_per_epoch = data_train.nb_samples // batch_size_int
            # TODO: make the training loop a nested for loop that iterates over all training data (inner loop) nb_epochs times (outer loop)
            # Start the training loop.
            for step in range(max_iters):
                start_time = time.time()

                # Fill a feed dictionary with the actual set of images and labels
                # for this particular training step.
                feed_dict = data_train.fill_feed_dict(inputs_pl, targets_pl,
                                                      batch_size_int)

                # Run one step of the model.  The return values are the activations
                # from the `train_op` (which is discarded) and the `loss` Op.  To
                # inspect the values of your Ops or variables, you may include them
                # in the list passed to sess.run() and the value tensors will be
                # returned in the tuple from the call.
                _, loss_value, pred_vals = sess.run([train_op, loss, preds],
                                                    feed_dict=feed_dict)
                duration = time.time() - start_time

                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'

                # Print status to stdout.
                print('Step %d: loss = %.2f (%.3f sec)' %
                      (step, loss_value, duration))

                # Write the summaries and display some predictions fairly often.
                if step % 50 == 0:
                    # Debug output: inspect the raw predictions.
                    print('pred_vals type: %s' % str(type(pred_vals)))
                    print('pred_vals.shape: %s' % str(pred_vals.shape))
                    print(pred_vals)
                    cur_batch = data_train.get_cur_batch()
                    for samp_num in range(cur_batch['X'].shape[0]):
                        img = cur_batch['X'][samp_num]
                        class_true_idx = np.argmax(cur_batch['y'][samp_num])
                        class_true = idx2label[class_true_idx]
                        scores = pred_vals[samp_num]
                        probs = self.softmax(scores)
                        class_pred_idx = np.argmax(probs)
                        class_pred = idx2label[class_pred_idx]
                        fig, ax = plt.subplots(figsize=(10, 10),
                                               nrows=1,
                                               ncols=1)
                        ax.imshow(deprocess_image(img))
                        txt = 'predicted class dist: %s\n' \
                              'predicted class: \t%s\n' \
                              'true class: \t\t%s' % (str(probs), class_pred, class_true)
                        ax.text(0,
                                0,
                                txt,
                                color='r',
                                fontsize=15,
                                fontweight='bold')
                        plt.show()

                    summary_str = sess.run(summary, feed_dict=feed_dict)
                    if SAVE_PATH is not None:
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                # Save a checkpoint and evaluate the model periodically.
                if (step + 1) % 10 == 0 or (step + 1) == max_iters:
                    checkpoint_file = os.path.join(SAVE_PATH,
                                                   '%s_checkpoint' % self.name)
                    saver.save(sess, checkpoint_file, global_step=step)

                    # Evaluate against the training set.
                    print('Training Data Eval:')
                    self.evaluate(sess,
                                  eval_correct,
                                  inputs_pl,
                                  targets_pl,
                                  data_train,
                                  batch_size_int,
                                  lim=100)

                    # Evaluate against the validation set, if provided.
                    if data_val is not None:
                        print('Validation Data Eval:')
                        self.evaluate(sess,
                                      eval_correct,
                                      inputs_pl,
                                      targets_pl,
                                      data_val,
                                      batch_size_int,
                                      lim=100)
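
A usage sketch for this step-based variant, which trains for max_iters batches rather than a fixed number of epochs; all names here are hypothetical:

# Hypothetical call: evaluates and checkpoints every 10 steps.
model.train(data_train=data_train,
            data_val=data_val,
            batch_size=32,
            max_iters=5000,
            lr=0.001,
            save_path='checkpoints/')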