示例#1
0
class Train(BasicTrain):
    """
    Trainer class
    """

    def __init__(self, args, sess, train_model, test_model):
        """
        Build the trainer.

        Calls the base-class constructor, creates the TensorBoard summary
        placeholders/ops and the summary writer, loads the data selected by
        ``args.data_mode`` (choosing the matching batch generator), then
        creates the metrics and reporter helpers.
        :param args: run configuration (reads data_mode, mode, batch_size,
                     summary_dir, out_dir, num_classes, data_dir, ...)
        :param sess: TensorFlow session, forwarded to the base class
        :param train_model: model used for training, forwarded to the base class
        :param test_model: model used for evaluation, forwarded to the base class
        :return:
        """
        super().__init__(args, sess, train_model, test_model)
        ##################################################################################
        # Init summaries

        # Summary variables
        self.scalar_summary_tags = ['mean_iou_on_val',
                                    'train-loss-per-epoch', 'val-loss-per-epoch',
                                    'train-acc-per-epoch', 'val-acc-per-epoch']
        # Image summaries are twice the input width (presumably image and
        # prediction side by side -- TODO confirm against segmented_summary).
        self.images_summary_tags = [
            ('train_prediction_sample', [None, self.params.img_height, self.params.img_width * 2, 3]),
            ('val_prediction_sample', [None, self.params.img_height, self.params.img_width * 2, 3])]
        self.summary_tags = []
        self.summary_placeholders = {}
        self.summary_ops = {}
        # init summaries and their operators
        self.init_summaries()
        # Create summary writer
        self.summary_writer = tf.summary.FileWriter(self.args.summary_dir, self.sess.graph)
        ##################################################################################
        # Init load data and generator
        self.generator = None
        if self.args.data_mode == "experiment_tfdata":
            self.data_session = None
            self.train_next_batch, self.train_data_len = self.init_tfdata(self.args.batch_size, self.args.abs_data_dir,
                                                                          (self.args.img_height, self.args.img_width),
                                                                          mode='train')
            self.num_iterations_training_per_epoch = self.train_data_len // self.args.batch_size
            self.generator = self.train_tfdata_generator
        elif self.args.data_mode == "experiment_h5":
            self.train_data = None
            self.train_data_len = None
            self.val_data = None
            self.val_data_len = None
            self.num_iterations_training_per_epoch = None
            self.num_iterations_validation_per_epoch = None
            self.load_train_data_h5()
            self.generator = self.train_h5_generator
        elif self.args.data_mode == "experiment_v2":
            self.targets_resize = self.args.targets_resize
            self.train_data = None
            self.train_data_len = None
            self.val_data = None
            self.val_data_len = None
            self.num_iterations_training_per_epoch = None
            self.num_iterations_validation_per_epoch = None
            self.load_train_data(v2=True)
            self.generator = self.train_generator
        elif self.args.data_mode == "experiment":
            self.train_data = None
            self.train_data_len = None
            self.val_data = None
            self.val_data_len = None
            self.num_iterations_training_per_epoch = None
            self.num_iterations_validation_per_epoch = None
            self.load_train_data()
            self.generator = self.train_generator
        elif self.args.data_mode == "test_tfdata":
            self.test_data = None
            self.test_data_len = None
            self.num_iterations_testing_per_epoch = None
            self.load_val_data()
            self.generator = self.test_tfdata_generator
        elif self.args.data_mode == "test":
            self.test_data = None
            self.test_data_len = None
            self.num_iterations_testing_per_epoch = None
            self.load_val_data()
            self.generator = self.test_generator
        elif self.args.data_mode == "test_eval":
            self.test_data = None
            self.test_data_len = None
            self.num_iterations_testing_per_epoch = None
            self.names_mapper = None
            self.load_test_data()
            self.generator = self.test_generator
        elif self.args.data_mode == "test_v2":
            self.targets_resize = self.args.targets_resize
            self.test_data = None
            self.test_data_len = None
            self.num_iterations_testing_per_epoch = None
            self.load_val_data(v2=True)
            self.generator = self.test_generator
        elif self.args.data_mode == "video":
            # Video runs reuse the "test" code paths from here on.
            self.args.data_mode = "test"
            self.test_data = None
            self.test_data_len = None
            self.num_iterations_testing_per_epoch = None
            self.load_vid_data()
            self.generator = self.test_generator
        elif self.args.data_mode == "debug":
            print("Debugging photo loading..")
            self.debug_x = np.load('data/debug/debug_x.npy')
            self.debug_y = np.load('data/debug/debug_y.npy')
            print("Debugging photo loaded")
        else:
            print("ERROR Please select a proper data_mode BYE")
            # Same exit status as the site-provided exit(-1), without depending on `site`.
            raise SystemExit(-1)
        ##################################################################################
        # Init metrics class
        self.metrics = Metrics(self.args.num_classes)
        # Init reporter class.
        # Only mode 'test' writes the test report; 'train', 'overfit' and every
        # other mode write the train report, so self.reporter is always defined.
        if self.args.mode == 'test':
            report_file = 'report_test.json'
        else:
            report_file = 'report_train.json'
        self.reporter = Reporter(self.args.out_dir + report_file, self.args)

    def crop(self):
        sh = self.val_data['X'].shape
        temp_val_data = {'X': np.zeros((sh[0] * 2, sh[1], sh[2] // 2, sh[3]), self.val_data['X'].dtype),
                         'Y': np.zeros((sh[0] * 2, sh[1], sh[2] // 2), self.val_data['Y'].dtype)}
        for i in range(sh[0]):
            temp_val_data['X'][i * 2, :, :, :] = self.val_data['X'][i, :, :sh[2] // 2, :]
            temp_val_data['X'][i * 2 + 1, :, :, :] = self.val_data['X'][i, :, sh[2] // 2:, :]
            temp_val_data['Y'][i * 2, :, :] = self.val_data['Y'][i, :, :sh[2] // 2]
            temp_val_data['Y'][i * 2 + 1, :, :] = self.val_data['Y'][i, :, sh[2] // 2:]

        self.val_data = temp_val_data

    def init_tfdata(self, batch_size, main_dir, resize_shape, mode='train'):
        """
        Build the tf.data training iterator and load the validation arrays into memory.

        A new ``tf.Session`` is stored in ``self.data_session`` and immediately runs
        the iterator initializer (kept as ``self.init_op``). The training file list is
        the hard-coded ``data/cityscapes_tfdata/train.txt``. Validation arrays are read
        from ``args.data_dir`` and their usable length is rounded down to a multiple
        of ``args.batch_size``.
        :param batch_size: training batch size given to ``SegDataLoader``
        :param main_dir: dataset root given to ``SegDataLoader``
        :param resize_shape: (height, width) given to ``SegDataLoader``
        :param mode: currently unused
        :return: (next-batch op of the iterator, ``data_len`` reported by the loader)
        """
        self.data_session = tf.Session()
        print("Creating the iterator for training data")
        with tf.device('/cpu:0'):
            height, width = resize_shape[0], resize_shape[1]
            loader = SegDataLoader(main_dir, batch_size, (height, width), resize_shape,
                                   'data/cityscapes_tfdata/train.txt')
            iterator = Iterator.from_structure(loader.data_tr.output_types, loader.data_tr.output_shapes)
            next_batch = iterator.get_next()

            self.init_op = iterator.make_initializer(loader.data_tr)
            self.data_session.run(self.init_op)

        print("Loading Validation data in memoryfor faster training..")
        data_dir = self.args.data_dir
        self.val_data = {'X': np.load(data_dir + "X_val.npy"),
                         'Y': np.load(data_dir + "Y_val.npy")}

        val_batch = self.args.batch_size
        n_val = self.val_data['X'].shape[0]
        # drop the incomplete trailing batch
        self.val_data_len = n_val - n_val % val_batch
        self.num_iterations_validation_per_epoch = self.val_data_len // val_batch

        print("Val-shape-x -- {} {}".format(self.val_data['X'].shape, self.val_data_len))
        print("Val-shape-y -- {}".format(self.val_data['Y'].shape))
        print("Num of iterations on validation data in one epoch -- {}".format(
            self.num_iterations_validation_per_epoch))
        print("Validation data is loaded")

        return next_batch, loader.data_len

    @timeit
    def load_overfit_data(self):
        """
        Load the training arrays and reuse them as validation data (overfitting runs).

        Reads ``X_train.npy`` / ``Y_train.npy`` from ``args.data_dir``. Both lengths
        are the sample count rounded down to a multiple of ``args.batch_size``, and
        ``self.val_data`` is the very same dict object as ``self.train_data``.
        """
        print("Loading data..")
        base = self.args.data_dir
        self.train_data = {'X': np.load(base + "X_train.npy"),
                           'Y': np.load(base + "Y_train.npy")}
        bs = self.args.batch_size
        n_train = self.train_data['X'].shape[0]
        self.train_data_len = n_train - n_train % bs
        self.num_iterations_training_per_epoch = (self.train_data_len + bs - 1) // bs
        print("Train-shape-x -- {}".format(self.train_data['X'].shape))
        print("Train-shape-y -- {}".format(self.train_data['Y'].shape))
        print("Num of iterations in one epoch -- {}".format(self.num_iterations_training_per_epoch))
        print("Overfitting data is loaded")

        print("Loading Validation data..")
        # validation shares the training arrays on purpose
        self.val_data = self.train_data
        n_val = self.val_data['X'].shape[0]
        self.val_data_len = n_val - n_val % bs
        self.num_iterations_validation_per_epoch = (self.val_data_len + bs - 1) // bs
        print("Val-shape-x -- {} {}".format(self.val_data['X'].shape, self.val_data_len))
        print("Val-shape-y -- {}".format(self.val_data['Y'].shape))
        print("Num of iterations on validation data in one epoch -- {}".format(
            self.num_iterations_validation_per_epoch))
        print("Validation data is loaded")

    def overfit_generator(self):
        start = 0
        new_epoch_flag = True
        idx = None
        while True:
            # init index array if it is a new_epoch
            if new_epoch_flag:
                if self.args.shuffle:
                    idx = np.random.choice(self.train_data_len, self.train_data_len, replace=False)
                else:
                    idx = np.arange(self.train_data_len)
                new_epoch_flag = False

            # select the mini_batches
            mask = idx[start:start + self.args.batch_size]
            x_batch = self.train_data['X'][mask]
            y_batch = self.train_data['Y'][mask]

            start += self.args.batch_size
            if start >= self.train_data_len:
                start = 0
                new_epoch_flag = True

            yield x_batch, y_batch

    def init_summaries(self):
        """
        Create the summary part of the graph.

        For every tag in ``self.scalar_summary_tags`` a float32 scalar placeholder
        and a ``tf.summary.scalar`` op are created. For every (tag, shape) pair in
        ``self.images_summary_tags`` a float32 placeholder of that shape and a
        ``tf.summary.image`` op (at most 10 images) are created. Placeholders and ops
        are stored in ``self.summary_placeholders`` / ``self.summary_ops`` keyed by
        tag, and each tag is appended to ``self.summary_tags``.
        :return: None
        """
        with tf.variable_scope('train-summary-per-epoch'):
            for tag in self.scalar_summary_tags:
                # append the whole tag; `list += str` would add it character by character
                self.summary_tags.append(tag)
                self.summary_placeholders[tag] = tf.placeholder('float32', None, name=tag)
                self.summary_ops[tag] = tf.summary.scalar(tag, self.summary_placeholders[tag])
            for tag, shape in self.images_summary_tags:
                self.summary_tags.append(tag)
                self.summary_placeholders[tag] = tf.placeholder('float32', shape, name=tag)
                self.summary_ops[tag] = tf.summary.image(tag, self.summary_placeholders[tag], max_outputs=10)

    def add_summary(self, step, summaries_dict=None, summaries_merged=None):
        """
        Add the summaries to tensorboard
        :param step:
        :param summaries_dict:
        :param summaries_merged:
        :return:
        """
        if summaries_dict is not None:
            summary_list = self.sess.run([self.summary_ops[tag] for tag in summaries_dict.keys()],
                                         {self.summary_placeholders[tag]: value for tag, value in
                                          summaries_dict.items()})
            for summary in summary_list:
                self.summary_writer.add_summary(summary, step)
        if summaries_merged is not None:
            self.summary_writer.add_summary(summaries_merged, step)

    @timeit
    def load_train_data(self, v2=False):
        """
        Load the training and validation arrays from ``args.data_dir``.

        Training data (``X_train.npy`` / ``Y_train.npy``) is passed through
        ``self.resize``; validation data (``X_val.npy`` / ``Y_val.npy``) is used as
        stored. The validation labels as loaded are also kept in
        ``val_data['Y_large']``. The validation length is rounded down to a
        multiple of ``args.batch_size``; the training length is not.
        :param v2: when True, the label maps of both splits are additionally
                   downsampled by ``self.targets_resize`` (nearest-neighbour).
        """
        def downsample_labels(labels):
            # Shrink label maps indexed as (N, H, W) by self.targets_resize.
            # Nearest-neighbour avoids blending label ids; each split keeps its own dtype.
            out_shape = (labels.shape[1] // self.targets_resize,
                         labels.shape[2] // self.targets_resize)
            small = np.zeros((labels.shape[0], out_shape[0], out_shape[1]), dtype=labels.dtype)
            for i in range(labels.shape[0]):
                small[i, ...] = misc.imresize(labels[i, ...], out_shape, interp='nearest')
            return small

        print("Loading Training data..")
        self.train_data = {'X': np.load(self.args.data_dir + "X_train.npy"),
                           'Y': np.load(self.args.data_dir + "Y_train.npy")}
        self.train_data = self.resize(self.train_data)

        if v2:
            self.train_data['Y'] = downsample_labels(self.train_data['Y'])
        self.train_data_len = self.train_data['X'].shape[0]

        self.num_iterations_training_per_epoch = (
            self.train_data_len + self.args.batch_size - 1) // self.args.batch_size

        print("Train-shape-x -- " + str(self.train_data['X'].shape) + " " + str(self.train_data_len))
        print("Train-shape-y -- " + str(self.train_data['Y'].shape))
        print("Num of iterations on training data in one epoch -- " + str(self.num_iterations_training_per_epoch))
        print("Training data is loaded")

        print("Loading Validation data..")
        # NOTE(review): unlike the training split, validation data is not passed through
        # self.resize -- confirm X_val.npy / Y_val.npy are stored at (img_height, img_width).
        self.val_data = {'X': np.load(self.args.data_dir + "X_val.npy"),
                         'Y': np.load(self.args.data_dir + "Y_val.npy")}
        # test_per_epoch compares upsampled predictions against these in experiment_v2
        self.val_data['Y_large'] = self.val_data['Y']
        if v2:
            self.val_data['Y'] = downsample_labels(self.val_data['Y'])

        self.val_data_len = self.val_data['X'].shape[0] - self.val_data['X'].shape[0] % self.args.batch_size
        self.num_iterations_validation_per_epoch = (
            self.val_data_len + self.args.batch_size - 1) // self.args.batch_size
        print("Val-shape-x -- " + str(self.val_data['X'].shape) + " " + str(self.val_data_len))
        print("Val-shape-y -- " + str(self.val_data['Y'].shape))
        print("Num of iterations on validation data in one epoch -- " + str(self.num_iterations_validation_per_epoch))
        print("Validation data is loaded")

    @timeit
    def load_train_data_h5(self):
        """
        Open the training set as an HDF5 file and load the validation arrays into memory.

        ``self.train_data`` is an ``h5py.File`` opened read-only at
        ``data_dir + h5_train_file``. It is left open because train_h5_generator
        indexes into it. The training length comes from ``args.h5_train_len``, not
        from the file. The validation length is rounded down to a multiple of
        ``args.batch_size``.
        """
        cfg = self.args
        print("Loading Training data..")
        self.train_data = h5py.File(cfg.data_dir + cfg.h5_train_file, 'r')
        self.train_data_len = cfg.h5_train_len
        self.num_iterations_training_per_epoch = (self.train_data_len + cfg.batch_size - 1) // cfg.batch_size
        print("Train-shape-x -- {} {}".format(self.train_data['X'].shape, self.train_data_len))
        print("Train-shape-y -- {}".format(self.train_data['Y'].shape))
        print("Num of iterations on training data in one epoch -- {}".format(
            self.num_iterations_training_per_epoch))
        print("Training data is loaded")

        print("Loading Validation data..")
        self.val_data = {'X': np.load(cfg.data_dir + "X_val.npy"),
                         'Y': np.load(cfg.data_dir + "Y_val.npy")}
        n_val = self.val_data['X'].shape[0]
        self.val_data_len = n_val - n_val % cfg.batch_size
        self.num_iterations_validation_per_epoch = (self.val_data_len + cfg.batch_size - 1) // cfg.batch_size
        print("Val-shape-x -- {} {}".format(self.val_data['X'].shape, self.val_data_len))
        print("Val-shape-y -- {}".format(self.val_data['Y'].shape))
        print("Num of iterations on validation data in one epoch -- {}".format(
            self.num_iterations_validation_per_epoch))
        print("Validation data is loaded")

    @timeit
    def load_vid_data(self):
        """
        Load video frames from ``X_vid.npy`` as test data.

        ``test_data['Y']`` is a placeholder all-zero array (NumPy default float64)
        shaped like the first three axes of the frames. Unlike the other test
        loaders, the length is not rounded down to a multiple of the batch size.
        """
        print("Loading Video data..")
        frames = np.load(self.args.data_dir + "X_vid.npy")
        self.test_data = {'X': frames, 'Y': np.zeros(frames.shape[:3])}
        self.test_data_len = frames.shape[0]
        print("Vid-shape-x -- {}".format(frames.shape))
        print("Vid-shape-y -- {}".format(self.test_data['Y'].shape))
        batch = self.args.batch_size
        self.num_iterations_testing_per_epoch = (self.test_data_len + batch - 1) // batch
        print("Video data is loaded")

    @timeit
    def load_val_data(self, v2=False):
        """
        Load the validation arrays as test data.

        Images and labels from ``X_val.npy`` / ``Y_val.npy`` are passed through
        ``self.resize``. The resized labels are also kept in ``test_data['Y_large']``.
        The test length is rounded down to a multiple of ``args.batch_size``.
        :param v2: when True, ``test_data['Y']`` is further downsampled by
                   ``self.targets_resize`` with nearest-neighbour resizing.
        """
        print("Loading Validation data..")
        src = self.args.data_dir
        self.test_data = {'X': np.load(src + "X_val.npy"),
                          'Y': np.load(src + "Y_val.npy")}
        self.test_data = self.resize(self.test_data)
        self.test_data['Y_large'] = self.test_data['Y']
        if v2:
            factor = self.targets_resize
            labels = self.test_data['Y']
            small_shape = (labels.shape[1] // factor, labels.shape[2] // factor)
            small = np.zeros((labels.shape[0], small_shape[0], small_shape[1]), dtype=labels.dtype)
            for i, label in enumerate(labels):
                small[i, ...] = misc.imresize(label, small_shape, interp='nearest')
            self.test_data['Y'] = small

        n_samples = self.test_data['X'].shape[0]
        bs = self.args.batch_size
        self.test_data_len = n_samples - n_samples % bs
        print("Validation-shape-x -- {}".format(self.test_data['X'].shape))
        print("Validation-shape-y -- {}".format(self.test_data['Y'].shape))
        self.num_iterations_testing_per_epoch = (self.test_data_len + bs - 1) // bs
        print("Validation data is loaded")

    @timeit
    def load_test_data(self):
        """
        Load the test images and the accompanying name arrays.

        ``test_data`` only receives an 'X' entry (no labels are loaded).
        ``names_mapper`` holds the arrays from ``xnames_test.npy`` / ``ynames_test.npy``
        (presumably file names -- confirm). The test length is rounded down to a
        multiple of ``args.batch_size``.
        """
        # NOTE(review): test_data has no 'Y' key, yet the "test_eval" data mode pairs
        # this loader with test_generator, which reads test_data['Y'] -- confirm that
        # path never drives the generator.
        print("Loading Testing data..")
        folder = self.args.data_dir
        images = np.load(folder + "X_test.npy")
        self.test_data = {'X': images}
        self.names_mapper = {'X': np.load(folder + "xnames_test.npy"),
                             'Y': np.load(folder + "ynames_test.npy")}
        n_test = images.shape[0]
        bs = self.args.batch_size
        self.test_data_len = n_test - n_test % bs
        print("Test-shape-x -- {}".format(images.shape))
        self.num_iterations_testing_per_epoch = (self.test_data_len + bs - 1) // bs
        print("Test data is loaded")

    def test_generator(self):
        start = 0
        new_epoch_flag = True
        idx = None
        while True:
            # init index array if it is a new_epoch
            if new_epoch_flag:
                if self.args.shuffle:
                    idx = np.random.choice(self.test_data_len, self.test_data_len, replace=False)
                else:
                    idx = np.arange(self.test_data_len)
                new_epoch_flag = False

            # select the mini_batches
            mask = idx[start:start + self.args.batch_size]
            x_batch = self.test_data['X'][mask]
            y_batch = self.test_data['Y'][mask]

            # update start idx
            start += self.args.batch_size

            if start >= self.test_data_len:
                start = 0
                new_epoch_flag = True

            yield x_batch, y_batch

    def train_generator(self):
        start = 0
        idx = np.random.choice(self.train_data_len, self.num_iterations_training_per_epoch * self.args.batch_size,
                               replace=True)
        while True:
            # select the mini_batches
            mask = idx[start:start + self.args.batch_size]
            x_batch = self.train_data['X'][mask]
            y_batch = self.train_data['Y'][mask]

            # update start idx
            start += self.args.batch_size

            yield x_batch, y_batch

            if start >= self.train_data_len:
                return

    def train_tfdata_generator(self):
        """
        Endlessly yield batches fetched from the tf.data training iterator.

        Each step runs ``self.train_next_batch`` in ``self.data_session``. Only
        channel 0 of axis 3 of the label batch is kept (assumes labels arrive as
        (batch, h, w, channels) -- confirm).
        """
        with tf.device('/cpu:0'):
            while True:
                images, labels = self.data_session.run(self.train_next_batch)
                first_channel = labels[:, :, :, 0]
                yield images, first_channel

    def train_h5_generator(self):
        start = 0
        idx = np.random.choice(self.train_data_len, self.train_data_len,
                               replace=False)
        while True:
            # select the mini_batches
            mask = idx[start:start + self.args.batch_size]
            x_batch = self.train_data['X'][sorted(mask.tolist())]
            y_batch = self.train_data['Y'][sorted(mask.tolist())]

            # update start idx
            start += self.args.batch_size

            if start >= self.train_data_len:
                return

            yield x_batch, y_batch

    def resize(self, data):
        """
        Resize every image and label map in ``data`` to (img_height, img_width).

        Images use ``misc.imresize``'s default interpolation; label maps use
        'nearest'. ``data['X']`` and ``data['Y']`` are replaced in place by the
        stacked results, and the same dict is returned.
        """
        size = (self.args.img_height, self.args.img_width)
        pairs = [(misc.imresize(data['X'][k, ...], size),
                  misc.imresize(data['Y'][k, ...], size, 'nearest'))
                 for k in range(data['X'].shape[0])]
        data['X'] = np.asarray([img for img, _ in pairs])
        data['Y'] = np.asarray([lbl for _, lbl in pairs])
        return data

    def train(self):
        """
        Run the training loop from the epoch after ``global_epoch_tensor`` up to
        ``args.num_epochs`` (inclusive).

        Each epoch consumes batches from ``self.generator()`` and runs
        ``model.train_op``. On the epoch's last iteration (index
        ``num_iterations_training_per_epoch - 1``) it also averages the epoch's
        loss/accuracy, writes them through ``self.reporter``, advances the global
        step and sets ``global_epoch`` to ``cur_epoch + 1``. A checkpoint is saved
        every ``args.save_every`` epochs, and validation runs every
        ``args.test_every`` epochs.

        NOTE(review): if the generator yields fewer than
        ``num_iterations_training_per_epoch`` batches, the last-iteration branch
        never runs, so the epoch is neither reported nor counted in global_epoch.
        NOTE(review): the ``add_summary`` calls are commented out, so the collected
        TensorBoard summaries are currently not written.
        """
        print("Training mode will begin NOW ..")
        # curr_lr= self.model.args.learning_rate
        for cur_epoch in range(self.model.global_epoch_tensor.eval(self.sess) + 1, self.args.num_epochs + 1, 1):

            # init tqdm over this epoch's batch generator
            tt = tqdm(self.generator(), total=self.num_iterations_training_per_epoch,
                      desc="epoch-" + str(cur_epoch) + "-")

            # init the current iterations
            cur_iteration = 0

            # init acc and loss lists
            loss_list = []
            acc_list = []

            # loop over the batches of this epoch
            for x_batch, y_batch in tt:

                # get the cur_it for the summary
                cur_it = self.model.global_step_tensor.eval(self.sess)

                # Feed these variables to the network
                feed_dict = {self.model.x_pl: x_batch,
                             self.model.y_pl: y_batch,
                             self.model.is_training: True
                             #                             self.model.curr_learning_rate:curr_lr
                             }

                # Every iteration but the last only trains; the last one also finalizes the epoch
                if cur_iteration < self.num_iterations_training_per_epoch - 1:

                    # run the feed_forward
                    _, loss, acc, summaries_merged = self.sess.run(
                        [self.model.train_op, self.model.loss, self.model.accuracy, self.model.merged_summaries],
                        feed_dict=feed_dict)
                    # log loss and acc
                    loss_list += [loss]
                    acc_list += [acc]
                    # summarize
                #                    self.add_summary(cur_it, summaries_merged=summaries_merged)

                else:
                    # run the feed_forward; the segmented image summary is skipped
                    # in experiment_v2 (see the matching branch in test_per_epoch)
                    if self.args.data_mode == 'experiment_v2':
                        _, loss, acc, summaries_merged = self.sess.run(
                            [self.model.train_op, self.model.loss, self.model.accuracy,
                             self.model.merged_summaries],
                            feed_dict=feed_dict)
                    else:
                        _, loss, acc, summaries_merged, segmented_imgs = self.sess.run(
                            [self.model.train_op, self.model.loss, self.model.accuracy,
                             self.model.merged_summaries, self.model.segmented_summary],
                            feed_dict=feed_dict)

                    # log loss and acc
                    loss_list += [loss]
                    acc_list += [acc]
                    # epoch averages over all batches seen in this epoch
                    total_loss = np.mean(loss_list)
                    total_acc = np.mean(acc_list)
                    # summarize (dict is built but not written while add_summary is disabled)
                    summaries_dict = dict()
                    summaries_dict['train-loss-per-epoch'] = total_loss
                    summaries_dict['train-acc-per-epoch'] = total_acc

                    if self.args.data_mode != 'experiment_v2':
                        summaries_dict['train_prediction_sample'] = segmented_imgs
                    # self.add_summary(cur_it, summaries_dict=summaries_dict, summaries_merged=summaries_merged)

                    # report
                    self.reporter.report_experiment_statistics('train-acc', 'epoch-' + str(cur_epoch), str(total_acc))
                    self.reporter.report_experiment_statistics('train-loss', 'epoch-' + str(cur_epoch), str(total_loss))
                    self.reporter.finalize()

                    # Update the Global step (the break below skips the per-iteration update)
                    self.model.global_step_assign_op.eval(session=self.sess,
                                                          feed_dict={self.model.global_step_input: cur_it + 1})

                    # Update the Cur Epoch tensor
                    # it is the last thing because if it is interrupted it repeat this
                    self.model.global_epoch_assign_op.eval(session=self.sess,
                                                           feed_dict={self.model.global_epoch_input: cur_epoch + 1})

                    # print in console
                    tt.close()
                    print("epoch-" + str(cur_epoch) + "-" + "loss:" + str(total_loss) + "-" + " acc:" + str(total_acc)[
                                                                                                        :6])

                    # Break the loop to finalize this epoch
                    break

                # Update the Global step
                self.model.global_step_assign_op.eval(session=self.sess,
                                                      feed_dict={self.model.global_step_input: cur_it + 1})

                # update the cur_iteration
                cur_iteration += 1

            # Save the current checkpoint
            if cur_epoch % self.args.save_every == 0:
                self.save_model()

            # Test the model on validation
            if cur_epoch % self.args.test_every == 0:
                self.test_per_epoch(step=self.model.global_step_tensor.eval(self.sess),
                                    epoch=self.model.global_epoch_tensor.eval(self.sess))
        #            if cur_epoch % self.args.learning_decay_every == 0:
        #                curr_lr= curr_lr*self.args.learning_decay
        #                print('Current learning rate is ', curr_lr)

        print("Training Finished")

    def test_per_epoch(self, step, epoch):
        """
        Evaluate on ``self.val_data`` and keep the best model by mean IoU.

        The validation arrays are walked sequentially in ``args.batch_size`` slices
        for ``num_iterations_validation_per_epoch`` iterations. Every iteration but
        the last runs ``self.model`` (prediction, loss, accuracy) and times it. The
        last iteration runs ``self.test_model`` (prediction, accuracy and, outside
        experiment_v2, the segmented image summary). It then computes mean accuracy,
        mean IoU and mean inference time, reports them, and, when the IoU beats the
        stored ``best_iou``, saves the best model and updates ``best_iou``.
        :param step: global step, used only in the log line here
        :param epoch: epoch number used in log output and report keys

        NOTE(review): in experiment_v2 the non-final iterations upsample predictions
        to the ``Y_large`` resolution before updating metrics, but the final
        iteration compares low-resolution predictions with low-resolution labels,
        so the metrics mix both resolutions -- confirm this is intended.
        NOTE(review): the final iteration runs ``self.test_model`` ops while feeding
        ``self.model``'s placeholders; that only works if both share input
        placeholders -- confirm.
        NOTE(review): ``loss_list`` is filled but never reported, and the final batch
        contributes no loss. If ``num_iterations_validation_per_epoch`` is 0 the loop
        body never runs and nothing is reported.
        """
        print("Validation at step:" + str(step) + " at epoch:" + str(epoch) + " ..")

        # init tqdm over the validation iterations
        tt = tqdm(range(self.num_iterations_validation_per_epoch), total=self.num_iterations_validation_per_epoch,
                  desc="Val-epoch-" + str(epoch) + "-")

        # init acc, loss and inference-time lists
        loss_list = []
        acc_list = []
        inf_list = []

        # idx of minibatch
        idx = 0

        # reset metrics
        self.metrics.reset()

        # get the maximum iou to compare with and save the best model
        max_iou = self.model.best_iou_tensor.eval(self.sess)

        # loop by the number of iterations
        for cur_iteration in tt:
            # load minibatches (sequential slices, no shuffling)
            x_batch = self.val_data['X'][idx:idx + self.args.batch_size]
            y_batch = self.val_data['Y'][idx:idx + self.args.batch_size]
            if self.args.data_mode == 'experiment_v2':
                y_batch_large = self.val_data['Y_large'][idx:idx + self.args.batch_size]

            # update idx of minibatch
            idx += self.args.batch_size

            # Feed these variables to the network
            feed_dict = {self.model.x_pl: x_batch,
                         self.model.y_pl: y_batch,
                         self.model.is_training: False
                         }

            # Every iteration but the last only evaluates; the last one also finalizes the epoch
            if cur_iteration < self.num_iterations_validation_per_epoch - 1:

                start = time.time()
                # run the feed_forward

                out_argmax, loss, acc, summaries_merged = self.sess.run(
                    [self.model.out_argmax, self.model.loss, self.model.accuracy, self.model.merged_summaries],
                    feed_dict=feed_dict)

                end = time.time()
                # log loss, acc and wall-clock time of the sess.run call
                loss_list += [loss]
                acc_list += [acc]
                inf_list += [end - start]
                if self.args.data_mode == 'experiment_v2':
                    # Upsample predictions (nearest) to the full-resolution labels and
                    # score against those. Predictions are cast to uint8 first, so class
                    # ids above 255 would wrap.
                    yy = np.zeros((out_argmax.shape[0], y_batch_large.shape[1], y_batch_large.shape[2]),
                                  dtype=np.uint32)
                    out_argmax = np.asarray(out_argmax, dtype=np.uint8)
                    for y in range(out_argmax.shape[0]):
                        yy[y, ...] = misc.imresize(out_argmax[y, ...], y_batch_large.shape[1:], interp='nearest')
                    y_batch = y_batch_large
                    out_argmax = yy

                # log metrics
                self.metrics.update_metrics_batch(out_argmax, y_batch)

            else:
                start = time.time()
                # run the feed_forward
                if self.args.data_mode == 'experiment_v2':  # Issues in concatenating gt and img with diff sizes now for segmented_imgs
                    out_argmax, acc = self.sess.run(
                        [self.test_model.out_argmax, self.test_model.accuracy],
                        feed_dict=feed_dict)
                else:
                    out_argmax, acc, segmented_imgs = self.sess.run(
                        [self.test_model.out_argmax, self.test_model.accuracy, self.test_model.segmented_summary],
                        feed_dict=feed_dict)

                end = time.time()
                # log acc and inference time (no loss is fetched for this batch)
                acc_list += [acc]
                inf_list += [end - start]
                # log metrics
                self.metrics.update_metrics_batch(out_argmax, y_batch)
                # mean over batches
                total_acc = np.mean(acc_list)
                mean_iou = self.metrics.compute_final_metrics(self.num_iterations_validation_per_epoch)
                mean_iou_arr = self.metrics.iou
                mean_inference = str(np.mean(inf_list)) + '-seconds'
                # summarize (dict is built but not written while add_summary is disabled)
                summaries_dict = dict()
                summaries_dict['val-acc-per-epoch'] = total_acc
                summaries_dict['mean_iou_on_val'] = mean_iou
                if self.args.data_mode != 'experiment_v2':  # Issues in concatenating gt and img with diff sizes now for segmented_imgs
                    summaries_dict['val_prediction_sample'] = segmented_imgs
                #                self.add_summary(step, summaries_dict=summaries_dict, summaries_merged=summaries_merged)

                # report
                self.reporter.report_experiment_statistics('validation-acc', 'epoch-' + str(epoch), str(total_acc))
                self.reporter.report_experiment_statistics('avg_inference_time_on_validation', 'epoch-' + str(epoch),
                                                           str(mean_inference))
                self.reporter.report_experiment_validation_iou('epoch-' + str(epoch), str(mean_iou), mean_iou_arr)
                self.reporter.finalize()

                # print in console
                tt.close()
                print("Val-epoch-" + str(epoch) + "-" +
                      "acc:" + str(total_acc)[:6] + "-mean_iou:" + str(mean_iou))
                print("Last_max_iou: " + str(max_iou))
                if mean_iou > max_iou:
                    print("This validation got a new best iou. so we will save this one")
                    # save the best model
                    self.save_best_model()
                    # Set the new maximum
                    self.model.best_iou_assign_op.eval(session=self.sess,
                                                       feed_dict={self.model.best_iou_input: mean_iou})
                else:
                    print("hmm not the best validation epoch :/..")
                # Break the loop to finalize this epoch
                break

    def linknet_postprocess(self, gt):
        """
        Shift LinkNet-style labels down by one, sending label 0 to 19.

        Every label ``k >= 1`` becomes ``k - 1`` and label ``0`` becomes ``19``, so the label
        range ``0..19`` is permuted onto ``0..19``. A new array is returned; ``gt`` itself is
        not modified.

        :param gt: integer numpy array of labels (e.g. an argmax prediction map)
        :return: new array holding the remapped labels
        """
        gt2 = gt - 1
        # Select on the *input* value 0 (the entries that became -1 after the shift). The old
        # mask `gt == -1` tested the unshifted array, which never holds -1 for argmax outputs,
        # so label 0 leaked through as -1 (or wrapped to a huge value for unsigned dtypes).
        gt2[gt == 0] = 19
        return gt2

    def test(self, pkl=False):
        """
        Run the test model over the test split one image at a time and report accuracy and mean IoU.

        With ``args.data_mode == 'test_v2'`` predictions are upsampled (nearest neighbour) to the
        size of ``test_data['Y_large']`` before scoring. With ``args.data_mode == 'test'``
        every colour prediction is saved to ``<out_dir>/imgs/test_<i>.png``.

        :param pkl: if True, skip loading the best checkpoint (the session is expected to hold
                    the weights already) and remap predictions with ``linknet_postprocess``
                    before colouring them with ``decode_labels``.
        """
        print("Testing mode will begin NOW..")

        # load the best model checkpoint to test on it
        if not pkl:
            self.load_best_model()

        # one tqdm step per test image
        tt = tqdm(range(self.test_data_len))

        # per-image accuracies returned by the model
        acc_list = []

        # idx of image
        idx = 0

        # reset metrics
        self.metrics.reset()

        for cur_iteration in tt:
            # batches of exactly one image
            x_batch = self.test_data['X'][idx:idx + 1]
            y_batch = self.test_data['Y'][idx:idx + 1]
            if self.args.data_mode == 'test_v2':
                # full-size labels; only defined (and only used) in test_v2 mode
                y_batch_large = self.test_data['Y_large'][idx:idx + 1]

            idx += 1

            # Feed this variables to the network
            if self.args.random_cropping:
                feed_dict = {self.test_model.x_pl_before: x_batch,
                             self.test_model.y_pl_before: y_batch,
                             self.test_model.is_training: False,
                             }
            else:
                feed_dict = {self.test_model.x_pl: x_batch,
                             self.test_model.y_pl: y_batch,
                             self.test_model.is_training: False
                             }

            # run the feed_forward; test_v2 does not fetch the colour summary, so
            # segmented_imgs only exists in the other modes (or after the pkl branch below)
            if self.args.data_mode == 'test_v2':
                out_argmax, acc = self.sess.run(
                    [self.test_model.out_argmax, self.test_model.accuracy],
                    feed_dict=feed_dict)
            else:
                out_argmax, acc, segmented_imgs = self.sess.run(
                    [self.test_model.out_argmax, self.test_model.accuracy,
                     self.test_model.segmented_summary],
                    feed_dict=feed_dict)

            if self.args.data_mode == 'test_v2':
                # upsample the prediction to the full label size and score against that instead
                yy = np.zeros((out_argmax.shape[0], y_batch_large.shape[1], y_batch_large.shape[2]), dtype=np.uint32)
                out_argmax = np.asarray(out_argmax, dtype=np.uint8)
                for y in range(out_argmax.shape[0]):
                    yy[y, ...] = misc.imresize(out_argmax[y, ...], y_batch_large.shape[1:], interp='nearest')
                y_batch = y_batch_large
                out_argmax = yy

            if pkl:
                out_argmax[0] = self.linknet_postprocess(out_argmax[0])
                segmented_imgs = decode_labels(out_argmax, 20)

            if self.args.data_mode == 'test':
                plt.imsave(self.args.out_dir + 'imgs/' + 'test_' + str(cur_iteration) + '.png', segmented_imgs[0])

            # log acc
            acc_list += [acc]

            # log metrics
            if self.args.random_cropping:
                # split the label at column 512 into two samples to match the cropped prediction
                # NOTE(review): assumes exactly two 512-wide crops side by side — confirm.
                y1 = np.expand_dims(y_batch[0, :, :512], axis=0)
                y2 = np.expand_dims(y_batch[0, :, 512:], axis=0)
                y_batch = np.concatenate((y1, y2), axis=0)
                self.metrics.update_metrics(out_argmax, y_batch, 0, 0)
            else:
                self.metrics.update_metrics(out_argmax[0], y_batch[0], 0, 0)

        # no loss is computed in this loop; reported as 0 to keep the output format
        total_loss = 0
        total_acc = np.mean(acc_list)
        mean_iou = self.metrics.compute_final_metrics(self.test_data_len)

        # print in console
        tt.close()
        print("Here the statistics")
        print("Total_loss: " + str(total_loss))
        print("Total_acc: " + str(total_acc)[:6])
        print("mean_iou: " + str(mean_iou))

    def test_eval(self, pkl=False):
        """
        Predict every test image and write the files needed for official evaluation.

        For each image two files are written, both named after ``self.names_mapper['Y'][idx]``:
        the colour visualization under ``<out_dir>/imgs/`` and the ``postprocess``-ed label map,
        resized to 1024x2048 with nearest-neighbour interpolation, under ``<out_dir>/results/``.
        Missing parent directories are created.

        :param pkl: if True, skip loading the best checkpoint and remap predictions with
                    ``linknet_postprocess`` before colouring them with ``decode_labels``.
        """
        print("Testing mode will begin NOW..")

        # load the best model checkpoint to test on it
        if not pkl:
            self.load_best_model()

        # one tqdm step per test image
        tt = tqdm(range(self.test_data_len))

        # idx of image
        idx = 0

        for _ in tt:
            # batch of exactly one image
            x_batch = self.test_data['X'][idx:idx + 1]

            # Feed this variables to the network
            if self.args.random_cropping:
                feed_dict = {self.test_model.x_pl_before: x_batch,
                             self.test_model.is_training: False,
                             }
            else:
                feed_dict = {self.test_model.x_pl: x_batch,
                             self.test_model.is_training: False
                             }

            # run the feed_forward
            out_argmax, segmented_imgs = self.sess.run(
                [self.test_model.out_argmax,
                 self.test_model.segmented_summary],
                feed_dict=feed_dict)

            if pkl:
                out_argmax[0] = self.linknet_postprocess(out_argmax[0])
                segmented_imgs = decode_labels(out_argmax, 20)

            file_name = str(self.names_mapper['Y'][idx])

            # Colored results for visualization.
            # exist_ok avoids the race between an os.path.exists check and os.makedirs.
            colored_save_path = self.args.out_dir + 'imgs/' + file_name
            os.makedirs(os.path.dirname(colored_save_path), exist_ok=True)
            plt.imsave(colored_save_path, segmented_imgs[0])

            # Results for official evaluation
            save_path = self.args.out_dir + 'results/' + file_name
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            output = postprocess(out_argmax[0])
            misc.imsave(save_path, misc.imresize(output, [1024, 2048], 'nearest'))

            idx += 1

        # print in console
        tt.close()

    def test_inference(self, export_graph_only=True):
        """
        Export the graph definition and, optionally, benchmark inference speed.

        The graph definition is always written to ``./graph.pb`` after loading the best
        checkpoint. With ``export_graph_only=True`` (the default, matching the previous
        behaviour) the process then exits with status 1. With ``export_graph_only=False`` every
        test image is run through ``out_argmax`` one at a time and the timings are reported
        through ``FPSMeter``.

        :param export_graph_only: stop right after writing the graph instead of benchmarking.
        """
        print("INFERENCE mode will begin NOW..")

        # load the best model checkpoint to test on it
        self.load_best_model()

        # output_node: network/output/Argmax
        # input_node: network/input/Placeholder
        print("Saving graph...")
        tf.train.write_graph(self.sess.graph_def, ".", 'graph.pb')
        print("Graph saved successfully.\n\n")
        if export_graph_only:
            # An unconditional exit(1) used to sit here, making the benchmark below unreachable.
            exit(1)

        # one tqdm step per test image
        tt = tqdm(range(self.test_data_len))

        # idx of image
        idx = 0

        # create the FPS Meter
        fps_meter = FPSMeter()

        for _ in tt:
            # batches of exactly one image
            x_batch = self.test_data['X'][idx:idx + 1]
            y_batch = self.test_data['Y'][idx:idx + 1]

            # update idx of mini_batch
            idx += 1

            # NOTE(review): is_training is not fed (it was commented out in the original) —
            # confirm the graph provides a default for it.
            if self.args.random_cropping:
                feed_dict = {self.test_model.x_pl_before: x_batch,
                             self.test_model.y_pl_before: y_batch
                             }
            else:
                feed_dict = {self.test_model.x_pl: x_batch,
                             self.test_model.y_pl: y_batch
                             }

            # time exactly one forward pass
            start = time.time()
            self.sess.run(
                [self.test_model.out_argmax],
                feed_dict=feed_dict)
            fps_meter.update(time.time() - start)

        tt.close()
        fps_meter.print_statistics()

    def finalize(self):
        """Shut down in order: flush the reporter, close the summary writer, save a checkpoint."""
        reporter = self.reporter
        reporter.finalize()
        writer = self.summary_writer
        writer.close()
        self.save_model()

    def debug_layers(self):
        """
        Run the debug sample through every tensor in the 'debug_layers' collection and dump the outputs.

        The layer outputs are pickled (protocol 2) to
        ``out_networks_layers/out_linknet_layers.pkl`` (the directory is created if needed).
        The colour prediction is saved to ``<out_dir>/imgs/debug.png``, and the mean IoU of the
        prediction against the debug labels is printed.

        Side effect: ``self.debug_y`` is replaced by ``self.linknet_preprocess_gt(self.debug_y)``.
        """
        print("Debugging mode will begin NOW..")

        layers = tf.get_collection('debug_layers')
        print("ALL Layers in the collection that i wanna to run {} layer".format(len(layers)))
        for layer in layers:
            print(layer)

        # reset metrics
        self.metrics.reset()

        print('mean image ', self.debug_x.mean())
        print('mean gt ', self.debug_y.mean())

        # NOTE(review): this overwrites the stored labels, so a second call would
        # preprocess them twice.
        self.debug_y = self.linknet_preprocess_gt(self.debug_y)

        feed_dict = {self.test_model.x_pl: self.debug_x,
                     self.test_model.y_pl: self.debug_y,
                     self.test_model.is_training: False
                     }

        # run the feed_forward for every debug layer
        out_layers = self.sess.run(layers, feed_dict=feed_dict)
        for layer in out_layers:
            print(layer.shape)

        # dump them in a pickle; open() does not create missing directories
        os.makedirs("out_networks_layers", exist_ok=True)
        with open("out_networks_layers/out_linknet_layers.pkl", "wb") as f:
            pickle.dump(out_layers, f, protocol=2)

        # run the feed_forward again to see argmax and segmented
        out_argmax, segmented_imgs = self.sess.run(
            [self.test_model.out_argmax,
             self.test_model.segmented_summary],
            feed_dict=feed_dict)

        print('mean preds ', out_argmax[0].mean())

        plt.imsave(self.args.out_dir + 'imgs/' + 'debug.png', segmented_imgs[0])

        self.metrics.update_metrics(out_argmax[0], self.debug_y, 0, 0)

        mean_iou = self.metrics.compute_final_metrics(1)

        print("mean_iou_of_debug: " + str(mean_iou))
示例#2
0
class Test(BasicTest):
    """
    Evaluation / inference runner built on ``BasicTest``.

    The constructor loads the data split selected by ``args.data_mode`` and binds
    ``self.run`` to the method selected by ``args.task`` ("test", "realsense" or
    "realsense_imgs").
    """
    name = 'Test'

    def __init__(self, args, sess, model):
        """
        Call the base-class constructor, load data for ``args.data_mode`` and pick the entry point.

        :param args: parsed configuration (data_mode, task, data_dir, out_dir, num_classes, ...)
        :param sess: TensorFlow session used to run ``model``
        :param model: model exposing the placeholders and output tensors used by the methods
        """
        super().__init__(args, sess, model)
        # Init load data and generator
        self.generator = None
        self.run = None

        # Load the data split for data_mode; unknown modes load nothing and leave the
        # data attributes untouched.
        data_loaders = {
            "realsense": self.load_realsence_data,
            "cityscapes_val": self.load_val_data,
            "cityscapes_test": self.load_test_data,
            "video": self.load_vid_data,
        }
        loader = data_loaders.get(self.args.data_mode)
        if loader is not None:
            self.test_data = None
            self.test_data_len = None
            self.num_iterations_testing_per_epoch = None
            loader()

        # Select the entry point for the requested task.
        task_runners = {
            "test": self.test,
            "realsense": self.realsense_inference,
            "realsense_imgs": self.realsense_imgs,
        }
        if self.args.task not in task_runners:
            # It is args.task that is invalid here (the old message asked for a data_mode).
            print("ERROR Please select a proper task BYE")
            exit(-1)
        self.run = task_runners[self.args.task]

        # Init metrics class
        self.metrics = Metrics(self.args.num_classes)
        # Init reporter class
        self.reporter = Reporter(self.args.out_dir + 'report_test.json', self.args)

    def resize(self, data):
        """
        Resize ``data['X']`` and ``data['Y']`` to (img_height, img_width) and return ``data``.

        ``data`` is updated in place. Labels are resized with 'nearest' interpolation.
        """
        size = (self.args.img_height, self.args.img_width)
        count = data['X'].shape[0]
        images = [misc.imresize(data['X'][i, ...], size) for i in range(count)]
        labels = [misc.imresize(data['Y'][i, ...], size, 'nearest') for i in range(count)]
        data['X'] = np.asarray(images)
        data['Y'] = np.asarray(labels)
        return data

    @timeit
    def load_vid_data(self):
        """Load ``X_vid.npy`` from data_dir; labels are all-zero placeholders of the image size."""
        print("Loading Video data..")
        self.test_data = {'X': np.load(self.args.data_dir + "X_vid.npy")}
        self.test_data['Y'] = np.zeros(self.test_data['X'].shape[:3])
        self.test_data_len = self.test_data['X'].shape[0]
        print("Vid-shape-x -- " + str(self.test_data['X'].shape))
        print("Vid-shape-y -- " + str(self.test_data['Y'].shape))
        self.num_iterations_testing_per_epoch = (self.test_data_len + self.args.batch_size - 1) // self.args.batch_size
        print("Video data is loaded")

    @timeit
    def load_val_data(self):
        """Load and resize the validation split; the length is truncated to a multiple of batch_size."""
        print("Loading Validation data..")
        self.test_data = {'X': np.load(self.args.data_dir + "X_val.npy"),
                          'Y': np.load(self.args.data_dir + "Y_val.npy")}
        self.test_data = self.resize(self.test_data)
        # NOTE(review): 'Y_large' aliases the already-resized labels, not the original size.
        self.test_data['Y_large'] = self.test_data['Y']

        self.test_data_len = self.test_data['X'].shape[0] - self.test_data['X'].shape[0] % self.args.batch_size
        print("Validation-shape-x -- " + str(self.test_data['X'].shape))
        print("Validation-shape-y -- " + str(self.test_data['Y'].shape))
        self.num_iterations_testing_per_epoch = (self.test_data_len + self.args.batch_size - 1) // self.args.batch_size
        print("Validation data is loaded")

    @timeit
    def load_realsence_data(self):
        """Load RealSense images and their names; no labels ('Y') are loaded for this split."""
        print("Loading RealSense data..")
        self.test_data = {'X': np.load(self.args.data_dir + "/realsense/x_inference.npy"),
                          'names': np.load(self.args.data_dir + "/realsense/name_inference.npy")}

        self.test_data_len = self.test_data['X'].shape[0] - self.test_data['X'].shape[0] % self.args.batch_size
        print("RealSense-shape-x -- " + str(self.test_data['X'].shape))
        print("RealSense-shape-name -- " + str(self.test_data['names'].shape))
        self.num_iterations_testing_per_epoch = (self.test_data_len + self.args.batch_size - 1) // self.args.batch_size
        print("RealSense data is loaded")

    @timeit
    def load_test_data(self):
        """Load test images plus the input/output file names used by ``test_eval``."""
        print("Loading Testing data..")
        self.test_data = {'X': np.load(self.args.data_dir + "X_test.npy")}
        self.names_mapper = {'X': np.load(self.args.data_dir + "xnames_test.npy"),
                             'Y': np.load(self.args.data_dir + "ynames_test.npy")}
        self.test_data_len = self.test_data['X'].shape[0] - self.test_data['X'].shape[0] % self.args.batch_size
        print("Test-shape-x -- " + str(self.test_data['X'].shape))
        self.num_iterations_testing_per_epoch = (self.test_data_len + self.args.batch_size - 1) // self.args.batch_size
        print("Test data is loaded")

    def test_generator(self):
        """Yield ``(x_batch, y_batch)`` forever, drawing a new order (shuffled if args.shuffle) each pass."""
        start = 0
        new_epoch_flag = True
        idx = None
        while True:
            # init index array if it is a new_epoch
            if new_epoch_flag:
                if self.args.shuffle:
                    idx = np.random.choice(self.test_data_len, self.test_data_len, replace=False)
                else:
                    idx = np.arange(self.test_data_len)
                new_epoch_flag = False

            # select the mini_batches
            mask = idx[start:start + self.args.batch_size]
            x_batch = self.test_data['X'][mask]
            y_batch = self.test_data['Y'][mask]

            # update start idx
            start += self.args.batch_size

            if start >= self.test_data_len:
                start = 0
                new_epoch_flag = True

            yield x_batch, y_batch

    @staticmethod
    def linknet_postprocess(gt):
        """
        Shift LinkNet-style labels down by one, sending label 0 to 19.

        Every label ``k >= 1`` becomes ``k - 1`` and label ``0`` becomes ``19``. A new array
        is returned; ``gt`` itself is not modified.
        """
        gt2 = gt - 1
        # Select on the input value 0 (the entries that became -1). The old mask `gt == -1`
        # tested the unshifted array and never matched argmax outputs.
        gt2[gt == 0] = 19
        return gt2

    def test(self, pkl=False):
        """
        Run the model over the test split one image at a time and report accuracy and mean IoU.

        Every colour prediction is saved to ``<out_dir>/imgs/test_<i>.png``. No checkpoint is
        loaded here; the weights currently in the session are used.

        :param pkl: if True, remap predictions with ``linknet_postprocess`` and re-colour them
                    with ``decode_labels`` before saving.
        """
        print("Testing will begin NOW..")

        # one tqdm step per test image
        tt = tqdm(range(self.test_data_len))

        # per-image accuracies returned by the model
        acc_list = []

        # idx of image
        idx = 0

        # reset metrics
        self.metrics.reset()

        for cur_iteration in tt:
            # batches of exactly one image
            x_batch = self.test_data['X'][idx:idx + 1]
            y_batch = self.test_data['Y'][idx:idx + 1]

            idx += 1

            # Feed this variables to the network
            if self.args.random_cropping:
                feed_dict = {self.model.x_pl_before: x_batch,
                             self.model.y_pl_before: y_batch,
                             self.model.is_training: False,
                             }
            else:
                feed_dict = {self.model.x_pl: x_batch,
                             self.model.y_pl: y_batch,
                             self.model.is_training: False
                             }

            # run the feed_forward
            out_argmax, acc, segmented_imgs = self.sess.run(
                [self.model.out_argmax, self.model.accuracy,
                 self.model.segmented_summary],
                feed_dict=feed_dict)

            if pkl:
                out_argmax[0] = self.linknet_postprocess(out_argmax[0])
                segmented_imgs = decode_labels(out_argmax, 20)

            misc.imsave(self.args.out_dir + 'imgs/' + 'test_' + str(cur_iteration) + '.png', segmented_imgs[0])

            # log acc
            acc_list += [acc]

            # log metrics
            if self.args.random_cropping:
                # split the label at column 512 into two samples to match the cropped prediction
                # NOTE(review): assumes exactly two 512-wide crops side by side — confirm.
                y1 = np.expand_dims(y_batch[0, :, :512], axis=0)
                y2 = np.expand_dims(y_batch[0, :, 512:], axis=0)
                y_batch = np.concatenate((y1, y2), axis=0)
                self.metrics.update_metrics(out_argmax, y_batch, 0, 0)
            else:
                self.metrics.update_metrics(out_argmax[0], y_batch[0], 0, 0)

        # no loss is computed in this loop; reported as 0 to keep the output format
        total_loss = 0
        total_acc = np.mean(acc_list)
        mean_iou = self.metrics.compute_final_metrics(self.test_data_len)

        # print in console
        tt.close()
        print("Here the statistics")
        print("Total_loss: " + str(total_loss))
        print("Total_acc: " + str(total_acc)[:6])
        print("mean_iou: " + str(mean_iou))

    def realsense_imgs(self):
        """Save the colour prediction of every loaded image to ``<out_dir>/imgs/test_<i>.png``."""
        print("realsense_imgs will begin NOW..")

        # one tqdm step per image
        tt = tqdm(range(self.test_data_len))

        # idx of image
        idx = 0

        for cur_iteration in tt:
            # batch of exactly one image
            x_batch = self.test_data['X'][idx:idx + 1]

            idx += 1
            feed_dict = {self.model.x_pl: x_batch,
                         self.model.is_training: False
                         }

            # sess.run with a one-element fetch list returns a one-element list
            segmented_imgs = self.sess.run([self.model.segmented_summary], feed_dict=feed_dict)

            misc.imsave(self.args.out_dir + 'imgs/' + 'test_' + str(cur_iteration) + '.png', segmented_imgs[0][0])

        tt.close()
        print("realsense_imgs finished~")

    def test_eval(self, pkl=False):
        """
        Predict every test image and write the files needed for official evaluation.

        For each image, both files are named after ``self.names_mapper['Y'][idx]`` (set by
        ``load_test_data``): the colour visualization under ``<out_dir>/imgs/`` and the
        ``postprocess``-ed label map, resized to 1024x2048 (nearest), under
        ``<out_dir>/results/``. Missing parent directories are created.

        :param pkl: if True, skip loading the best checkpoint and remap predictions with
                    ``linknet_postprocess`` before colouring them with ``decode_labels``.
        """
        print("Testing will begin NOW..")

        # load the best model checkpoint to test on it
        if not pkl:
            self.load_best_model()

        # one tqdm step per test image
        tt = tqdm(range(self.test_data_len))

        # idx of image
        idx = 0

        for _ in tt:
            # batch of exactly one image
            x_batch = self.test_data['X'][idx:idx + 1]

            # Feed this variables to the network
            if self.args.random_cropping:
                feed_dict = {self.model.x_pl_before: x_batch,
                             self.model.is_training: False,
                             }
            else:
                feed_dict = {self.model.x_pl: x_batch,
                             self.model.is_training: False
                             }

            # run the feed_forward
            out_argmax, segmented_imgs = self.sess.run(
                [self.model.out_argmax,
                 self.model.segmented_summary],
                feed_dict=feed_dict)

            if pkl:
                out_argmax[0] = self.linknet_postprocess(out_argmax[0])
                segmented_imgs = decode_labels(out_argmax, 20)

            file_name = str(self.names_mapper['Y'][idx])

            # Colored results for visualization (exist_ok avoids a check-then-create race)
            colored_save_path = self.args.out_dir + 'imgs/' + file_name
            os.makedirs(os.path.dirname(colored_save_path), exist_ok=True)
            misc.imsave(colored_save_path, segmented_imgs[0])

            # Results for official evaluation
            save_path = self.args.out_dir + 'results/' + file_name
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            output = postprocess(out_argmax[0])
            misc.imsave(save_path, misc.imresize(output, [1024, 2048], 'nearest'))

            idx += 1

        # print in console
        tt.close()

    def realsense_inference(self):
        """Time one ``out_argmax`` forward pass per loaded image and print FPS statistics."""
        print("INFERENCE will begin NOW..")

        # one tqdm step per image
        tt = tqdm(range(self.test_data_len))

        # idx of image
        idx = 0

        # create the FPS Meter
        fps_meter = FPSMeter()

        for _ in tt:
            # batch of exactly one image (the RealSense split has no labels)
            x_batch = self.test_data['X'][idx:idx + 1]

            # update idx of mini_batch
            idx += 1

            feed_dict = {self.model.x_pl: x_batch,
                         self.model.is_training: False}

            # time exactly one forward pass
            start = time.time()
            self.sess.run(
                [self.model.out_argmax],
                feed_dict=feed_dict)
            fps_meter.update(time.time() - start)

        fps_meter.print_statistics()

    def test_inference(self):
        """
        Time one ``out_argmax`` forward pass per test image and print FPS statistics.

        Unlike ``realsense_inference`` this also feeds labels, so ``test_data`` must contain 'Y'
        (``load_realsence_data`` and ``load_test_data`` do not provide it).
        """
        print("INFERENCE will begin NOW..")

        # one tqdm step per test image
        tt = tqdm(range(self.test_data_len))

        # idx of image
        idx = 0

        # create the FPS Meter
        fps_meter = FPSMeter()

        for _ in tt:
            # batches of exactly one image
            x_batch = self.test_data['X'][idx:idx + 1]
            y_batch = self.test_data['Y'][idx:idx + 1]

            # update idx of mini_batch
            idx += 1

            # Feed this variables to the network
            if self.args.random_cropping:
                feed_dict = {self.model.x_pl_before: x_batch,
                             self.model.y_pl_before: y_batch,
                             self.model.is_training: False,
                             }
            else:
                feed_dict = {self.model.x_pl: x_batch,
                             self.model.y_pl: y_batch,
                             self.model.is_training: False
                             }

            # time exactly one forward pass
            start = time.time()
            self.sess.run(
                [self.model.out_argmax],
                feed_dict=feed_dict)
            fps_meter.update(time.time() - start)

        fps_meter.print_statistics()

    def finalize(self):
        """Flush the test report."""
        self.reporter.finalize()
示例#3
0
class NewTrain(object):
    """
    Trainer for a segmentation model running in a TensorFlow 1.x session.

    Training and validation pull batches through the model's dataset handle
    (``model.handle`` fed with ``training_handle`` / ``validation_handle``);
    testing feeds images loaded from ``X_val.npy`` / ``Y_val.npy`` one at a
    time. Progress goes to TensorBoard, a JSON reporter and checkpoints.
    """

    def __init__(self, args, sess, model):
        """
        Initialise variables, restore the latest checkpoint, build the
        per-epoch summaries and set up data bookkeeping, metrics and reporter.

        :param args: experiment configuration
        :param sess: tf.Session used for every graph run
        :param model: model exposing ``params``, the train/eval ops and the
            global step / epoch / best-iou bookkeeping tensors
        """
        print("\nTraining is initializing itself\n")

        self.args = args
        self.sess = sess
        self.model = model

        # shortcut for model params
        self.params = self.model.params

        # Initialise all global and local variables
        self.init = None
        self.init_model()

        # Rolling saver for regular checkpoints ...
        self.saver = tf.train.Saver(max_to_keep=self.args.max_to_keep,
                                    keep_checkpoint_every_n_hours=10,
                                    save_relative_paths=True)
        # ... and a separate saver that keeps only the best model
        self.saver_best = tf.train.Saver(max_to_keep=1,
                                         save_relative_paths=True)

        # Restore the latest checkpoint if found; this runs after init_model,
        # so restored values overwrite the fresh initialisation
        self.load_model()

        ##################################################################################
        # Init summaries
        self.scalar_summary_tags = [
            'mean_iou_on_val', 'train-loss-per-epoch', 'val-loss-per-epoch',
            'train-acc-per-epoch', 'val-acc-per-epoch'
        ]
        self.images_summary_tags = [
            ('train_prediction_sample',
             [None, self.params.img_height, self.params.img_width * 2, 3]),
            ('val_prediction_sample',
             [None, self.params.img_height, self.params.img_width * 2, 3])
        ]

        self.summary_tags = []
        self.summary_placeholders = {}
        self.summary_ops = {}
        # init summaries and their operators
        self.init_summaries()
        # Create summary writer
        self.summary_writer = tf.summary.FileWriter(self.args.summary_dir,
                                                    self.sess.graph)
        ##################################################################################
        # Data bookkeeping: train/val sizes come from args, test data is loaded
        if self.args.mode == 'train':
            self.num_iterations_training_per_epoch = self.args.tfrecord_train_len // self.args.batch_size
            self.num_iterations_validation_per_epoch = self.args.tfrecord_val_len // self.args.batch_size
        else:
            self.test_data = None
            self.test_data_len = None
            self.num_iterations_testing_per_epoch = None
            self.load_test_data()
        ##################################################################################
        # Init metrics class
        self.metrics = Metrics(self.args.num_classes)
        # Init reporter class. The previous condition
        # `mode == 'train' or 'overfit'` was always true, so test runs were
        # written into report_train.json; now only 'test' uses report_test.json.
        if self.args.mode == 'test':
            report_name = 'report_test.json'
        else:
            report_name = 'report_train.json'
        self.reporter = Reporter(self.args.out_dir + report_name, self.args)

    @timeit
    def load_test_data(self):
        """
        Load ``X_val.npy`` / ``Y_val.npy`` from ``args.data_dir`` as test data.

        ``test_data_len`` is the sample count truncated down to a multiple of
        ``args.batch_size``.
        """
        print("Loading Testing data..")
        self.test_data = {
            'X': np.load(self.args.data_dir + "X_val.npy"),
            'Y': np.load(self.args.data_dir + "Y_val.npy")
        }
        num_samples = self.test_data['X'].shape[0]
        self.test_data_len = num_samples - num_samples % self.args.batch_size
        print("Test-shape-x -- " + str(self.test_data['X'].shape))
        print("Test-shape-y -- " + str(self.test_data['Y'].shape))
        self.num_iterations_testing_per_epoch = (self.test_data_len +
                                                 self.args.batch_size -
                                                 1) // self.args.batch_size
        print("Test data is loaded")

    @timeit
    def init_model(self):
        """Run the initializers of all global and local variables."""
        print("Initializing the variables of the model")
        self.init = tf.group(tf.global_variables_initializer(),
                             tf.local_variables_initializer())
        self.sess.run(self.init)
        print("Initialization finished")

    def save_model(self):
        """
        Save a regular checkpoint to ``args.checkpoint_dir``, tagged with the
        model's global step.
        """
        print("saving a checkpoint")
        self.saver.save(self.sess, self.args.checkpoint_dir,
                        self.model.global_step_tensor)
        print("Saved a checkpoint")

    def save_best_model(self):
        """
        Save the best-so-far checkpoint to ``args.checkpoint_best_dir``
        (the best saver keeps only one checkpoint).
        """
        print("saving a checkpoint for the best model")
        self.saver_best.save(self.sess, self.args.checkpoint_best_dir,
                             self.model.global_step_tensor)
        print("Saved a checkpoint for the best model")

    def load_best_model(self):
        """
        Restore the latest checkpoint from ``args.checkpoint_best_dir``.

        Exits the process with status -1 when no checkpoint is found.
        """
        import sys

        print("loading a checkpoint for BEST ONE")
        latest_checkpoint = tf.train.latest_checkpoint(
            self.args.checkpoint_best_dir)
        if latest_checkpoint:
            print(
                "Loading model checkpoint {} ...\n".format(latest_checkpoint))
            self.saver_best.restore(self.sess, latest_checkpoint)
        else:
            print("ERROR NO best checkpoint found")
            # sys.exit works without the `site` module, unlike the builtin exit()
            sys.exit(-1)
        print("BEST MODEL LOADED..")

    def init_summaries(self):
        """
        Create a float32 placeholder and a summary op for every scalar and
        image tag, under the 'train-summary-per-epoch' variable scope.
        """
        with tf.variable_scope('train-summary-per-epoch'):
            for tag in self.scalar_summary_tags:
                # append, not +=: `+= tag` added the tag's characters one by one
                self.summary_tags.append(tag)
                self.summary_placeholders[tag] = tf.placeholder('float32',
                                                                None,
                                                                name=tag)
                self.summary_ops[tag] = tf.summary.scalar(
                    tag, self.summary_placeholders[tag])
            for tag, shape in self.images_summary_tags:
                self.summary_tags.append(tag)
                self.summary_placeholders[tag] = tf.placeholder('float32',
                                                                shape,
                                                                name=tag)
                self.summary_ops[tag] = tf.summary.image(
                    tag, self.summary_placeholders[tag], max_outputs=10)

    def add_summary(self, step, summaries_dict=None, summaries_merged=None):
        """
        Write summaries to TensorBoard at ``step``.

        :param step: global step the summaries are recorded at
        :param summaries_dict: optional mapping tag -> value; each value is fed
            into the tag's placeholder and the matching summary op is run
        :param summaries_merged: optional already-evaluated summary to write
        """
        if summaries_dict is not None:
            summary_list = self.sess.run(
                [self.summary_ops[tag] for tag in summaries_dict.keys()], {
                    self.summary_placeholders[tag]: value
                    for tag, value in summaries_dict.items()
                })
            for summary in summary_list:
                self.summary_writer.add_summary(summary, step)
        if summaries_merged is not None:
            self.summary_writer.add_summary(summaries_merged, step)

    @timeit
    def load_model(self):
        """
        Load pretrained encoder weights when the model provides a loader, then
        restore the latest checkpoint from ``args.checkpoint_dir`` if one exists.
        """
        # Only models whose encoder defines load_pretrained_weights use this.
        # Looked up explicitly so that AttributeErrors raised *inside* the
        # loader are no longer silently swallowed.
        encoder = getattr(self.model, 'encoder', None)
        load_pretrained = getattr(encoder, 'load_pretrained_weights', None)
        if load_pretrained is not None:
            load_pretrained(self.sess)

        print("Searching for a checkpoint")
        latest_checkpoint = tf.train.latest_checkpoint(
            self.args.checkpoint_dir)
        if latest_checkpoint:
            print(
                "Loading model checkpoint {} ...\n".format(latest_checkpoint))
            self.saver.restore(self.sess, latest_checkpoint)
            print("Model loaded from the latest checkpoint\n")
        else:
            print("\n.. No ckpt, SO First time to train :D ..\n")

    def train(self):
        """
        Run training from the epoch after the stored global epoch up to
        ``args.num_epochs``.

        Every step runs the train op and writes the merged summaries. The last
        step of an epoch also fetches a prediction sample, writes the per-epoch
        summaries, reports statistics and advances the stored epoch counter.
        After each epoch the following happen on their own schedules:
        - a checkpoint is saved every ``save_every`` epochs;
        - validation runs every ``test_every`` epochs;
        - the learning rate is multiplied by ``learning_decay`` every
          ``learning_decay_every`` epochs.
        """
        print("Training mode will begin NOW ..")
        tf.train.start_queue_runners(sess=self.sess)
        # NOTE(review): on resume the learning rate restarts from
        # args.learning_rate; earlier decays are not replayed. Confirm intent.
        curr_lr = self.model.args.learning_rate
        num_iterations = self.num_iterations_training_per_epoch
        for cur_epoch in range(
                self.model.global_epoch_tensor.eval(self.sess) + 1,
                self.args.num_epochs + 1, 1):

            tt = tqdm(range(num_iterations),
                      total=num_iterations,
                      desc="epoch-" + str(cur_epoch) + "-")

            loss_list = []
            acc_list = []

            for cur_iteration in tt:
                # current global step, used as the summary step
                cur_it = self.model.global_step_tensor.eval(self.sess)

                feed_dict = {
                    self.model.handle: self.model.training_handle,
                    self.model.is_training: True,
                    self.model.curr_learning_rate: curr_lr
                }

                is_last = cur_iteration == num_iterations - 1
                fetches = [
                    self.model.train_op, self.model.loss,
                    self.model.accuracy, self.model.merged_summaries
                ]
                if is_last:
                    # the final step of the epoch also renders a prediction sample
                    fetches.append(self.model.segmented_summary)
                results = self.sess.run(fetches, feed_dict=feed_dict)
                loss, acc, summaries_merged = results[1:4]
                loss_list.append(loss)
                acc_list.append(acc)

                if not is_last:
                    self.add_summary(cur_it, summaries_merged=summaries_merged)
                    self.model.global_step_assign_op.eval(
                        session=self.sess,
                        feed_dict={self.model.global_step_input: cur_it + 1})
                    continue

                # ---- end of epoch ----
                segmented_imgs = results[4]
                total_loss = np.mean(loss_list)
                total_acc = np.mean(acc_list)
                summaries_dict = {
                    'train-loss-per-epoch': total_loss,
                    'train-acc-per-epoch': total_acc,
                    'train_prediction_sample': segmented_imgs,
                }
                self.add_summary(cur_it,
                                 summaries_dict=summaries_dict,
                                 summaries_merged=summaries_merged)

                self.reporter.report_experiment_statistics(
                    'train-acc', 'epoch-' + str(cur_epoch), str(total_acc))
                self.reporter.report_experiment_statistics(
                    'train-loss', 'epoch-' + str(cur_epoch), str(total_loss))
                self.reporter.finalize()

                self.model.global_step_assign_op.eval(
                    session=self.sess,
                    feed_dict={self.model.global_step_input: cur_it + 1})

                # Update the epoch tensor last so an interruption before this
                # point repeats the epoch
                self.model.global_epoch_assign_op.eval(
                    session=self.sess,
                    feed_dict={self.model.global_epoch_input: cur_epoch + 1})

                tt.close()
                print("epoch-" + str(cur_epoch) + "-" + "loss:" +
                      str(total_loss) + "-" + " acc:" + str(total_acc)[:6])
                break

            if cur_epoch % self.args.save_every == 0:
                self.save_model()

            if cur_epoch % self.args.test_every == 0:
                self.test_per_epoch(
                    step=self.model.global_step_tensor.eval(self.sess),
                    epoch=self.model.global_epoch_tensor.eval(self.sess))

            if cur_epoch % self.args.learning_decay_every == 0:
                curr_lr = curr_lr * self.args.learning_decay
                print('Current learning rate is ', curr_lr)

        print("Training Finished")

    def test_per_epoch(self, step, epoch):
        """
        Evaluate on the validation iterator and save the model if mean IoU
        beats the stored best IoU.

        Per-step loss/accuracy and the ``sess.run`` wall-clock times are
        averaged. Metrics are updated with the predictions and ``next_img[1]``
        (presumably the batch labels — confirm against the model's iterator).
        On the last step the results are summarized, reported and printed.

        :param step: global step used for the TensorBoard summaries
        :param epoch: epoch number used in labels and reports
        """
        print("Validation at step:" + str(step) + " at epoch:" + str(epoch) +
              " ..")

        num_iterations = self.num_iterations_validation_per_epoch
        tt = tqdm(range(num_iterations),
                  total=num_iterations,
                  desc="Val-epoch-" + str(epoch) + "-")

        loss_list = []
        acc_list = []
        inf_list = []

        self.metrics.reset()

        # best IoU so far, used to decide whether to save a new best model
        max_iou = self.model.best_iou_tensor.eval(self.sess)

        # restart the validation dataset
        self.sess.run(self.model.validation_iterator.initializer)

        for cur_iteration in tt:
            feed_dict = {
                self.model.handle: self.model.validation_handle,
                self.model.is_training: False
            }

            is_last = cur_iteration == num_iterations - 1
            fetches = [
                self.model.next_img, self.model.out_argmax,
                self.model.loss, self.model.accuracy
            ]
            if is_last:
                fetches.append(self.model.segmented_summary)

            start = time.time()
            results = self.sess.run(fetches, feed_dict=feed_dict)
            inf_list.append(time.time() - start)

            next_img, out_argmax, loss, acc = results[:4]
            loss_list.append(loss)
            acc_list.append(acc)
            self.metrics.update_metrics_batch(out_argmax, next_img[1])

            if not is_last:
                continue

            # ---- end of validation ----
            segmented_imgs = results[4]
            total_loss = np.mean(loss_list)
            total_acc = np.mean(acc_list)
            mean_iou = self.metrics.compute_final_metrics(num_iterations)
            mean_iou_arr = self.metrics.iou
            mean_inference = str(np.mean(inf_list)) + '-seconds'

            summaries_dict = {
                'val-loss-per-epoch': total_loss,
                'val-acc-per-epoch': total_acc,
                'mean_iou_on_val': mean_iou,
                'val_prediction_sample': segmented_imgs,
            }
            self.add_summary(step, summaries_dict=summaries_dict)
            self.summary_writer.flush()

            self.reporter.report_experiment_statistics(
                'validation-acc', 'epoch-' + str(epoch), str(total_acc))
            self.reporter.report_experiment_statistics(
                'validation-loss', 'epoch-' + str(epoch), str(total_loss))
            self.reporter.report_experiment_statistics(
                'avg_inference_time_on_validation', 'epoch-' + str(epoch),
                str(mean_inference))
            self.reporter.report_experiment_validation_iou(
                'epoch-' + str(epoch), str(mean_iou), mean_iou_arr)
            self.reporter.finalize()

            tt.close()
            print("Val-epoch-" + str(epoch) + "-" + "loss:" +
                  str(total_loss) + "-" + "acc:" + str(total_acc)[:6] +
                  "-mean_iou:" + str(mean_iou))
            print("Last_max_iou: " + str(max_iou))
            if mean_iou > max_iou:
                print(
                    "This validation got a new best iou. so we will save this one"
                )
                self.save_best_model()
                self.model.best_iou_assign_op.eval(
                    session=self.sess,
                    feed_dict={self.model.best_iou_input: mean_iou})
            else:
                print("hmm not the best validation epoch :/..")

            break

    def test(self):
        """
        Restore the best checkpoint and evaluate it on the test data, one image
        at a time.

        For every image, the argmax prediction is saved to
        ``<out_dir>npy/<i>.npy`` and the segmentation sample to
        ``<out_dir>imgs/test_<i>.png``. Both directories are created if they
        are missing. Mean loss, accuracy and mean IoU are printed at the end.
        """
        import os

        print("Testing mode will begin NOW..")

        self.load_best_model()

        npy_dir = self.args.out_dir + 'npy/'
        imgs_dir = self.args.out_dir + 'imgs/'
        os.makedirs(npy_dir, exist_ok=True)
        os.makedirs(imgs_dir, exist_ok=True)

        tt = tqdm(range(self.test_data_len))

        loss_list = []
        acc_list = []

        self.metrics.reset()

        for cur_iteration in tt:
            # one-image mini-batch (slicing keeps the batch dimension)
            x_batch = self.test_data['X'][cur_iteration:cur_iteration + 1]
            y_batch = self.test_data['Y'][cur_iteration:cur_iteration + 1]

            feed_dict = {
                self.model.x_pl: x_batch,
                self.model.y_pl: y_batch,
                self.model.is_training: False
            }

            out_argmax, loss, acc, _, segmented_imgs = self.sess.run(
                [
                    self.model.out_argmax, self.model.loss,
                    self.model.accuracy, self.model.merged_summaries,
                    self.model.segmented_summary
                ],
                feed_dict=feed_dict)

            np.save(npy_dir + str(cur_iteration) + '.npy', out_argmax[0])
            plt.imsave(imgs_dir + 'test_' + str(cur_iteration) + '.png',
                       segmented_imgs[0])

            loss_list.append(loss)
            acc_list.append(acc)

            self.metrics.update_metrics(out_argmax[0], y_batch[0], 0, 0)

        total_loss = np.mean(loss_list)
        total_acc = np.mean(acc_list)
        mean_iou = self.metrics.compute_final_metrics(self.test_data_len)

        tt.close()
        print("Here the statistics")
        print("Total_loss: " + str(total_loss))
        print("Total_acc: " + str(total_acc)[:6])
        print("mean_iou: " + str(mean_iou))

        print("Plotting imgs")

    def finalize(self):
        """Finalize the reporter, close the summary writer and save a checkpoint."""
        self.reporter.finalize()
        self.summary_writer.close()
        self.save_model()
示例#4
0
class Train(BasicTrain):
    """
    Trainer class
    """
    name = 'Train'

    def __init__(self, args, sess, model):
        """
        Set up a trainer on top of BasicTrain.

        This constructor does three things:
        - builds the per-epoch TensorBoard summaries and their writer;
        - prepares the data source selected by ``args.data_mode``
          ("experiment_tfdata", "experiment" or "debug"; any other value
          prints an error and exits);
        - creates the Metrics and Reporter helpers.

        :param args: experiment configuration
        :param sess: tf.Session the graph runs in
        :param model: the model being trained
        """
        super().__init__(args, sess, model)

        # --- per-epoch TensorBoard summaries --------------------------------
        self.scalar_summary_tags = [
            'mean_iou_on_val', 'train-loss-per-epoch', 'val-loss-per-epoch',
            'train-acc-per-epoch', 'val-acc-per-epoch'
        ]
        img_h = self.params.img_height
        img_w = self.params.img_width
        # image placeholders have shape [None, img_height, 2 * img_width, 3]
        self.images_summary_tags = [
            (name, [None, img_h, img_w * 2, 3])
            for name in ('train_prediction_sample', 'val_prediction_sample')
        ]
        self.summary_placeholders = {}
        self.summary_ops = {}
        self.init_summaries()
        self.summary_writer = tf.summary.FileWriter(self.args.summary_dir,
                                                    self.sess.graph)

        # --- data source ------------------------------------------------------
        self.generator = None
        self.run = None
        data_mode = self.args.data_mode
        if data_mode == "experiment_tfdata":
            # both are overwritten by init_tfdata
            self.data_session = None
            self.init_op = None
            self.train_next_batch, self.train_data_len = self.init_tfdata(
                self.args.batch_size,
                self.args.abs_data_dir,
                (self.args.img_height, self.args.img_width),
                mode='train')
            self.num_iterations_training_per_epoch = (
                self.train_data_len // self.args.batch_size)
            self.generator = self.train_tfdata_generator
        elif data_mode == "experiment":
            # placeholders, filled in by load_train_data
            self.train_data = None
            self.train_data_len = None
            self.val_data = None
            self.val_data_len = None
            self.num_iterations_training_per_epoch = None
            self.num_iterations_validation_per_epoch = None
            self.load_train_data()
            self.generator = self.train_generator
            self.run = self.train
        elif data_mode == "debug":
            print("Debugging photo loading..")
            self.debug_x = np.load('data/debug/debug_x.npy')
            self.debug_y = np.load('data/debug/debug_y.npy')
            print("Debugging photo loaded")
        else:
            print("ERROR Please select a proper data_mode BYE")
            exit(-1)

        # --- evaluation helpers -----------------------------------------------
        self.metrics = Metrics(self.args.num_classes)
        self.reporter = Reporter(self.args.out_dir + 'report_train.json',
                                 self.args)

    def crop(self):
        """
        Split every validation sample into its left and right halves along
        axis 2 (the width axis), doubling the number of samples.

        Sample ``i`` becomes sample ``2*i`` (left half) and sample ``2*i + 1``
        (right half). ``self.val_data`` is replaced by a new dict holding only
        ``'X'`` and ``'Y'``; ``self.val_data_len`` is not updated.

        Both arrays are split at half of ``X``'s width, so ``Y`` is assumed
        to share ``X``'s width (TODO confirm). When the width is odd, the last
        column is dropped so both halves have equal width; the previous
        per-sample loop crashed with a shape mismatch in that case.
        """
        half = self.val_data['X'].shape[2] // 2

        def _split_halves(arr):
            # (N, H, W, ...) -> (2N, H, half, ...), left/right interleaved
            left = arr[:, :, :half]
            right = arr[:, :, half:2 * half]
            stacked = np.stack([left, right], axis=1)
            # explicit leading size keeps this valid for empty arrays
            return stacked.reshape((arr.shape[0] * 2,) + stacked.shape[2:])

        self.val_data = {
            'X': _split_halves(self.val_data['X']),
            'Y': _split_halves(self.val_data['Y']),
        }

    def init_tfdata(self, batch_size, main_dir, resize_shape, mode='train'):
        """
        Build the tf.data training iterator and load validation data into memory.

        The iterator is built on the CPU device and initialised in a dedicated
        ``self.data_session``. The initializer is kept in ``self.init_op``.
        Validation arrays are read from ``args.data_dir``. The validation
        length is truncated to a multiple of ``args.batch_size``.

        :param batch_size: training batch size passed to SegDataLoader
        :param main_dir: dataset root passed to SegDataLoader
        :param resize_shape: target (height, width)
        :param mode: currently unused
        :return: tuple (next-batch op, ``data_len`` reported by SegDataLoader)
        """
        self.data_session = tf.Session()
        print("Creating the iterator for training data")
        target_hw = (resize_shape[0], resize_shape[1])
        with tf.device('/cpu:0'):
            loader = SegDataLoader(main_dir, batch_size, target_hw,
                                   resize_shape,
                                   'data/cityscapes_tfdata/train.txt')
            iterator = tf.data.Iterator.from_structure(
                loader.data_tr.output_types, loader.data_tr.output_shapes)
            batch_op = iterator.get_next()

            self.init_op = iterator.make_initializer(loader.data_tr)
            self.data_session.run(self.init_op)

        print("Loading Validation data in memoryfor faster training..")
        self.val_data = {
            'X': np.load(self.args.data_dir + "X_val.npy"),
            'Y': np.load(self.args.data_dir + "Y_val.npy")
        }

        bs = self.args.batch_size
        val_count = self.val_data['X'].shape[0]
        self.val_data_len = val_count - val_count % bs
        self.num_iterations_validation_per_epoch = self.val_data_len // bs

        print("Val-shape-x -- " + str(self.val_data['X'].shape) + " " +
              str(self.val_data_len))
        print("Val-shape-y -- " + str(self.val_data['Y'].shape))
        print("Num of iterations on validation data in one epoch -- " +
              str(self.num_iterations_validation_per_epoch))
        print("Validation data is loaded")

        return batch_op, loader.data_len

    @timeit
    def load_overfit_data(self):
        """
        Load the training arrays and reuse them as the validation set.

        Both lengths are truncated to a multiple of ``args.batch_size``. The
        iteration counts are the ceiling of length / batch size.
        ``self.val_data`` is the same dict object as ``self.train_data``.
        """
        print("Loading data..")
        self.train_data = {
            'X': np.load(self.args.data_dir + "X_train.npy"),
            'Y': np.load(self.args.data_dir + "Y_train.npy")
        }
        bs = self.args.batch_size
        train_count = self.train_data['X'].shape[0]
        self.train_data_len = train_count - train_count % bs
        # -(-a // b) is ceil(a / b) for a positive batch size
        self.num_iterations_training_per_epoch = -(-self.train_data_len // bs)
        print("Train-shape-x -- " + str(self.train_data['X'].shape))
        print("Train-shape-y -- " + str(self.train_data['Y'].shape))
        print("Num of iterations in one epoch -- " +
              str(self.num_iterations_training_per_epoch))
        print("Overfitting data is loaded")

        print("Loading Validation data..")
        self.val_data = self.train_data
        val_count = self.val_data['X'].shape[0]
        self.val_data_len = val_count - val_count % bs
        self.num_iterations_validation_per_epoch = -(-self.val_data_len // bs)
        print("Val-shape-x -- " + str(self.val_data['X'].shape) + " " +
              str(self.val_data_len))
        print("Val-shape-y -- " + str(self.val_data['Y'].shape))
        print("Num of iterations on validation data in one epoch -- " +
              str(self.num_iterations_validation_per_epoch))
        print("Validation data is loaded")

    def overfit_generator(self):
        """
        Endlessly yield ``(x_batch, y_batch)`` mini-batches from ``self.train_data``.

        At the start of every pass over the data the sample order is re-drawn.
        It is a random permutation when ``args.shuffle`` is set, otherwise the
        natural order. The last batch of a pass is smaller when
        ``train_data_len`` is not a multiple of ``args.batch_size``.
        """
        while True:
            count = self.train_data_len
            if self.args.shuffle:
                order = np.random.choice(count, count, replace=False)
            else:
                order = np.arange(count)

            # max(count, 1): an empty dataset still yields one (empty) batch per pass
            for begin in range(0, max(count, 1), self.args.batch_size):
                picked = order[begin:begin + self.args.batch_size]
                yield self.train_data['X'][picked], self.train_data['Y'][picked]

    def init_summaries(self):
        """
        Create one float32 placeholder and one summary op per tag, inside the
        'train-summary-per-epoch' variable scope.

        Scalar tags get a ``tf.summary.scalar`` op. Image tags (with their
        placeholder shape) get a ``tf.summary.image`` op that shows up to 10
        images. Scalars are created before images.
        """
        specs = [(tag, None, False) for tag in self.scalar_summary_tags]
        specs += [(tag, shape, True) for tag, shape in self.images_summary_tags]

        with tf.variable_scope('train-summary-per-epoch'):
            for tag, shape, is_image in specs:
                placeholder = tf.placeholder('float32', shape, name=tag)
                self.summary_placeholders[tag] = placeholder
                if is_image:
                    self.summary_ops[tag] = tf.summary.image(
                        tag, placeholder, max_outputs=10)
                else:
                    self.summary_ops[tag] = tf.summary.scalar(tag, placeholder)

    def add_summary(self, step, summaries_dict=None, summaries_merged=None):
        """
        Write summaries to TensorBoard at ``step``.

        :param step: global step the summaries are recorded at
        :param summaries_dict: optional mapping tag -> value; each value is fed
            into the tag's placeholder and the matching summary op is run
        :param summaries_merged: optional already-evaluated summary to write
        """
        writer = self.summary_writer
        if summaries_dict is not None:
            tags = list(summaries_dict)
            fetches = [self.summary_ops[tag] for tag in tags]
            feed = {self.summary_placeholders[tag]: summaries_dict[tag]
                    for tag in tags}
            for evaluated in self.sess.run(fetches, feed):
                writer.add_summary(evaluated, step)
        if summaries_merged is not None:
            writer.add_summary(summaries_merged, step)

    @timeit
    def load_train_data(self):
        """
        Load and resize the training arrays and load the validation arrays.

        The training length is the full (resized) sample count, and its
        iteration count is the ceiling of length / batch size. The validation
        length is truncated to a multiple of ``args.batch_size``.
        ``val_data['Y_large']`` aliases ``val_data['Y']``.
        """
        data_dir = self.args.data_dir

        print("Loading Training data..")
        self.train_data = {
            'X': np.load(data_dir + "X_train.npy"),
            'Y': np.load(data_dir + "Y_train.npy")
        }
        self.train_data = self.resize(self.train_data)
        self.train_data_len = self.train_data['X'].shape[0]

        bs = self.args.batch_size
        # -(-a // b) is ceil(a / b) for a positive batch size
        self.num_iterations_training_per_epoch = -(-self.train_data_len // bs)

        print("Train-shape-x -- " + str(self.train_data['X'].shape) + " " +
              str(self.train_data_len))
        print("Train-shape-y -- " + str(self.train_data['Y'].shape))
        print("Num of iterations on training data in one epoch -- " +
              str(self.num_iterations_training_per_epoch))
        print("Training data is loaded")

        print("Loading Validation data..")
        validation = {
            'X': np.load(data_dir + "X_val.npy"),
            'Y': np.load(data_dir + "Y_val.npy")
        }
        validation['Y_large'] = validation['Y']
        self.val_data = validation

        val_count = validation['X'].shape[0]
        self.val_data_len = val_count - val_count % bs
        self.num_iterations_validation_per_epoch = -(-self.val_data_len // bs)
        print("Val-shape-x -- " + str(self.val_data['X'].shape) + " " +
              str(self.val_data_len))
        print("Val-shape-y -- " + str(self.val_data['Y'].shape))
        print("Num of iterations on validation data in one epoch -- " +
              str(self.num_iterations_validation_per_epoch))
        print("Validation data is loaded")

    def train_generator(self):
        """
        Yield ``(x_batch, y_batch)`` mini-batches from ``self.train_data``.

        Indices are drawn up front with ``np.random.choice(..., replace=True)``.
        Sampling is therefore *with replacement*: within one epoch some samples
        may repeat and others may not appear. The generator stops once the
        batch offset reaches ``self.train_data_len``.
        :return: generator of (images, labels) numpy arrays
        """
        batch = self.args.batch_size
        order = np.random.choice(self.train_data_len,
                                 self.num_iterations_training_per_epoch * batch,
                                 replace=True)
        offset = 0
        while True:
            picked = order[offset:offset + batch]
            images = self.train_data['X'][picked]
            labels = self.train_data['Y'][picked]
            offset += batch
            yield images, labels
            if offset >= self.train_data_len:
                return

    def train_tfdata_generator(self):
        """
        Endlessly yield batches pulled from the tf.data pipeline.

        Each step evaluates ``self.train_next_batch`` in ``self.data_session``.
        Only channel 0 of the last axis of the label batch is kept.
        The label batch is expected to be 4-D; indexing fails otherwise.
        :return: infinite generator of (images, labels)
        """
        with tf.device('/cpu:0'):
            while True:
                images, labels = self.data_session.run(self.train_next_batch)
                first_channel = labels[:, :, :, 0]
                yield images, first_channel

    def resize(self, data):
        """
        Resize every image and label in ``data`` to (img_height, img_width).

        Images use ``misc.imresize``'s default interpolation. Labels use
        ``'nearest'`` so no new (interpolated) label values are invented.
        ``data`` is updated in place and also returned.
        NOTE(review): imresize goes through PIL and may rescale non-uint8
        arrays to 0-255 -- confirm the label arrays are uint8 so class ids
        survive.
        :param data: dict with 'X' (images) and 'Y' (labels) arrays indexed
                     along axis 0
        :return: the same dict with resized 'X' and 'Y'
        """
        target = (self.args.img_height, self.args.img_width)
        count = data['X'].shape[0]
        images = [misc.imresize(data['X'][k, ...], target) for k in range(count)]
        labels = [
            misc.imresize(data['Y'][k, ...], target, 'nearest')
            for k in range(count)
        ]
        data['X'] = np.asarray(images)
        data['Y'] = np.asarray(labels)
        return data

    def train(self):
        """
        Run the training loop up to ``self.args.num_epochs`` (inclusive).

        Training resumes at the epoch after the value stored in
        ``global_epoch_tensor``, i.e. that tensor holds the last *completed*
        epoch. Each epoch pulls batches from ``self.generator()``.

        On the last iteration of an epoch the loop also:
        - fetches the merged summaries and a segmentation sample;
        - writes the epoch-mean loss and accuracy to TensorBoard and to
          the reporter;
        - advances the global step;
        - stores the finished epoch number in ``global_epoch_tensor``.

        A checkpoint is saved every ``args.save_every`` epochs, and
        validation runs every ``args.test_every`` epochs.
        :return:
        """
        print("Training will begin NOW ..")
        for cur_epoch in range(
                self.model.global_epoch_tensor.eval(self.sess) + 1,
                self.args.num_epochs + 1, 1):

            # init tqdm and get the epoch value
            tt = tqdm(self.generator(),
                      total=self.num_iterations_training_per_epoch,
                      desc="epoch-" + str(cur_epoch) + "-")

            # init the current iterations
            cur_iteration = 0

            # init acc and loss lists
            loss_list = []
            acc_list = []

            # loop by the number of iterations
            for x_batch, y_batch in tt:

                # get the cur_it for the summary
                cur_it = self.model.global_step_tensor.eval(self.sess)

                # Feed this variables to the network
                feed_dict = {
                    self.model.x_pl: x_batch,
                    self.model.y_pl: y_batch,
                    self.model.is_training: True
                }

                # Every iteration but the last one only trains and logs
                if cur_iteration < self.num_iterations_training_per_epoch - 1:

                    _, loss, acc = self.sess.run(
                        [
                            self.model.train_op, self.model.loss,
                            self.model.accuracy
                        ],
                        feed_dict=feed_dict)
                    # log loss and acc
                    loss_list += [loss]
                    acc_list += [acc]

                else:
                    # last iteration: train and also fetch the summaries
                    _, loss, acc, summaries_merged, segmented_imgs = self.sess.run(
                        [
                            self.model.train_op, self.model.loss,
                            self.model.accuracy, self.model.merged_summaries,
                            self.model.segmented_summary
                        ],
                        feed_dict=feed_dict)

                    # log loss and acc
                    loss_list += [loss]
                    acc_list += [acc]
                    total_loss = np.mean(loss_list)
                    total_acc = np.mean(acc_list)
                    # summarize
                    summaries_dict = dict()
                    summaries_dict['train-loss-per-epoch'] = total_loss
                    summaries_dict['train-acc-per-epoch'] = total_acc
                    summaries_dict['train_prediction_sample'] = segmented_imgs

                    self.add_summary(cur_it,
                                     summaries_dict=summaries_dict,
                                     summaries_merged=summaries_merged)

                    # report
                    self.reporter.report_experiment_statistics(
                        'train-acc', 'epoch-' + str(cur_epoch), str(total_acc))
                    self.reporter.report_experiment_statistics(
                        'train-loss', 'epoch-' + str(cur_epoch),
                        str(total_loss))
                    self.reporter.finalize()

                    # Update the Global step
                    self.model.global_step_assign_op.eval(
                        session=self.sess,
                        feed_dict={self.model.global_step_input: cur_it + 1})

                    # Persist the finished epoch. This is done last, so an
                    # interruption before this point repeats the epoch.
                    # Store cur_epoch itself: the loop resumes at stored + 1,
                    # so storing cur_epoch + 1 skipped an epoch on restart and
                    # mislabelled this epoch's validation run.
                    self.model.global_epoch_assign_op.eval(
                        session=self.sess,
                        feed_dict={self.model.global_epoch_input: cur_epoch})

                    # print in console
                    tt.close()
                    print("epoch-" + str(cur_epoch) + "-" + "loss:" +
                          str(total_loss) + "-" + " acc:" + str(total_acc)[:6])

                    # Break the loop to finalize this epoch
                    break

                # Update the Global step
                self.model.global_step_assign_op.eval(
                    session=self.sess,
                    feed_dict={self.model.global_step_input: cur_it + 1})

                # update the cur_iteration
                cur_iteration += 1

            # Save the current checkpoint
            if cur_epoch % self.args.save_every == 0:
                self.save_model()

            # Test the model on validation
            if cur_epoch % self.args.test_every == 0:
                self.test_per_epoch(
                    step=self.model.global_step_tensor.eval(self.sess),
                    epoch=self.model.global_epoch_tensor.eval(self.sess))

        print("Training Finished")

    def test_per_epoch(self, step, epoch):
        """
        Evaluate the model on the validation split and keep the best model.

        Runs ``num_iterations_validation_per_epoch`` batches taken sequentially
        from ``self.val_data``. It accumulates loss, accuracy, inference time
        and segmentation metrics across those batches.

        On the last batch it:
        - writes the summaries at ``step``: validation accuracy, validation
          loss, mean IoU and a prediction sample;
        - reports the results to the reporter;
        - calls ``save_best_model`` and updates ``best_iou_tensor`` when the
          mean IoU exceeds the stored best.
        :param step: global step used as the TensorBoard x-axis
        :param epoch: epoch number used in logs and reports
        :return:
        """
        print("Validation at step:" + str(step) + " at epoch:" + str(epoch) +
              " ..")

        # init tqdm and get the epoch value
        tt = tqdm(range(self.num_iterations_validation_per_epoch),
                  total=self.num_iterations_validation_per_epoch,
                  desc="Val-epoch-" + str(epoch) + "-")

        # init acc, loss and inference-time lists
        loss_list = []
        acc_list = []
        inf_list = []

        # idx of minibatch
        idx = 0

        # reset metrics
        self.metrics.reset()

        # get the maximum iou to compare with and save the best model
        max_iou = self.model.best_iou_tensor.eval(self.sess)

        # loop by the number of iterations
        for cur_iteration in tt:
            # load minibatches
            x_batch = self.val_data['X'][idx:idx + self.args.batch_size]
            y_batch = self.val_data['Y'][idx:idx + self.args.batch_size]

            # update idx of minibatch
            idx += self.args.batch_size

            # Feed this variables to the network
            feed_dict = {
                self.model.x_pl: x_batch,
                self.model.y_pl: y_batch,
                self.model.is_training: False
            }

            # Run the feed forward but the last iteration finalize what you want to do
            if cur_iteration < self.num_iterations_validation_per_epoch - 1:

                start = time.time()
                # run the feed_forward
                out_argmax, loss, acc = self.sess.run(
                    [
                        self.model.out_argmax, self.model.loss,
                        self.model.accuracy
                    ],
                    feed_dict=feed_dict)
                end = time.time()

                # log loss, acc and inference time
                loss_list += [loss]
                acc_list += [acc]
                inf_list += [end - start]

                # log metrics
                self.metrics.update_metrics_batch(out_argmax, y_batch)

            else:
                start = time.time()
                # run the feed_forward; the loss is fetched here too so the
                # epoch-mean validation loss covers every batch
                out_argmax, loss, acc, segmented_imgs = self.sess.run(
                    [
                        self.model.out_argmax, self.model.loss,
                        self.model.accuracy, self.model.segmented_summary
                    ],
                    feed_dict=feed_dict)
                end = time.time()

                # log loss, acc and inference time
                loss_list += [loss]
                acc_list += [acc]
                inf_list += [end - start]
                # log metrics
                self.metrics.update_metrics_batch(out_argmax, y_batch)
                # mean over batches
                total_acc = np.mean(acc_list)
                total_loss = np.mean(loss_list)
                mean_iou = self.metrics.compute_final_metrics(
                    self.num_iterations_validation_per_epoch)
                mean_iou_arr = self.metrics.iou
                mean_inference = str(np.mean(inf_list)) + '-seconds'
                # summarize ('val-loss-per-epoch' is registered in __init__
                # but was previously never written)
                summaries_dict = dict()
                summaries_dict['val-acc-per-epoch'] = total_acc
                summaries_dict['val-loss-per-epoch'] = total_loss
                summaries_dict['mean_iou_on_val'] = mean_iou
                summaries_dict['val_prediction_sample'] = segmented_imgs
                self.add_summary(step, summaries_dict=summaries_dict)

                # report
                self.reporter.report_experiment_statistics(
                    'validation-acc', 'epoch-' + str(epoch), str(total_acc))
                self.reporter.report_experiment_statistics(
                    'avg_inference_time_on_validation', 'epoch-' + str(epoch),
                    str(mean_inference))
                self.reporter.report_experiment_validation_iou(
                    'epoch-' + str(epoch), str(mean_iou), mean_iou_arr)
                self.reporter.finalize()

                # print in console
                tt.close()
                print("Val-epoch-" + str(epoch) + "-" + "acc:" +
                      str(total_acc)[:6] + "-mean_iou:" + str(mean_iou))
                print("Last_max_iou: " + str(max_iou))
                if mean_iou > max_iou:
                    print(
                        "This validation got a new best iou. so we will save this one"
                    )
                    # save the best model
                    self.save_best_model()
                    # Set the new maximum
                    self.model.best_iou_assign_op.eval(
                        session=self.sess,
                        feed_dict={self.model.best_iou_input: mean_iou})
                else:
                    print("hmm not the best validation epoch :/..")
                break

    def finalize(self):
        """
        Shut down the trainer's outputs.

        Finalizes the reporter, closes the TensorBoard summary writer and
        saves one last checkpoint, in that order.
        :return:
        """
        self.reporter.finalize()
        writer = self.summary_writer
        writer.close()
        self.save_model()