Example #1
    def _full_validation(self, vali_data, sess):
        """Run one pass over the validation set and return mean loss and accuracy."""
        tflearn_dev.is_training(False, session=sess)  # validation runs with training mode off
        num_batches_vali = FLAGS.num_eval_images // FLAGS.test_batch_size

        loss_list = []
        accuracy_list = []

        for step_vali in range(num_batches_vali):
            loss, accuracy = sess.run([self.loss, self.accuracy],
                                      feed_dict={self.am_training: False})

            loss_list.append(loss)
            accuracy_list.append(accuracy)

        vali_loss_value = np.mean(np.array(loss_list))
        vali_accuracy_value = np.mean(np.array(accuracy_list))

        return vali_loss_value, vali_accuracy_value
Example #2
    def _full_validation(self, sess):
        """ Validation.
        After each epoch training, the loss value and accuracy will be 
        calculated for the validation dataset.
        The output can be used to implement early stopping.
        """
        tflearn.is_training(False, session=sess)
        num_batches_vali = FLAGS.num_val_images // FLAGS.batch_size

        loss_list = []
        accuracy_list = []

        for step_vali in range(num_batches_vali):
            _, _, loss, accuracy = sess.run(
                [self.batch_data, self.batch_labels, self.loss, self.accuracy],
                feed_dict={
                    self.am_training: False,
                    self.prob_fc: 1,
                    self.prob_conv: 1
                })
            loss_list.append(loss)
            accuracy_list.append(accuracy)

        vali_loss_value = np.mean(np.array(loss_list))
        vali_accuracy_value = np.mean(np.array(accuracy_list))

        return vali_loss_value, vali_accuracy_value
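Note: the docstring above mentions early stopping, but none of these examples implement it. A minimal, self-contained stopping rule built on the returned validation losses might look like this (a sketch assuming the caller keeps a history of vali_loss_value; the names here are hypothetical, not from these repositories):

def should_stop(vali_loss_history, patience=5):
    """Return True once the best loss of the last `patience` epochs
    no longer improves on the best loss seen before them."""
    if len(vali_loss_history) <= patience:
        return False
    best_before = min(vali_loss_history[:-patience])
    return min(vali_loss_history[-patience:]) >= best_before

Appending each vali_loss_value after _full_validation and breaking out of the training loop once should_stop returns True gives the behavior the docstring describes.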
Example #3
    def _full_validation(self, vali_data, sess):
        tflearn.is_training(False, session=sess)  # validation runs with training mode off
        num_batches_vali = FLAGS.num_eval_images // FLAGS.train_batch_size

        loss_list = []
        accuracy_list = []

        for step_vali in range(num_batches_vali):
            loss, accuracy = sess.run([self.loss, self.accuracy],
                                      feed_dict={self.am_training: False})

            loss_list.append(loss)
            accuracy_list.append(accuracy)

        vali_loss_value = np.mean(np.array(loss_list))
        vali_accuracy_value = np.mean(np.array(accuracy_list))

        return vali_loss_value, vali_accuracy_value
Example #4
	def _full_validation(self, sess):
		tflearn.is_training(False, session=sess)
		num_batches_vali = FLAGS.num_val_images // FLAGS.batch_size

		loss_list = []
		accuracy_list = []

		for step_vali in range(num_batches_vali):
			_, _, loss, accuracy = sess.run(
				[self.batch_data, self.batch_labels, self.loss, self.accuracy],
				feed_dict={self.am_training: False, self.prob_fc: 1, self.prob_conv: 1})

			loss_list.append(loss)
			accuracy_list.append(accuracy)

		vali_loss_value = np.mean(np.array(loss_list))
		vali_accuracy_value = np.mean(np.array(accuracy_list))

		return vali_loss_value, vali_accuracy_value
Example #5
    def train(self, **kwargs):
        ops.reset_default_graph()
        sess = tf.Session(config=self.tf_config)

        with sess.as_default():
            # Data Reading objects
            train_data = ReadData('train', self.img_size, self.set_id,
                                  self.train_range)
            vali_data = ReadData('validation', self.img_size, self.set_id,
                                 self.vali_range)

            train_batch_data, train_batch_labels = train_data.read_from_files()
            vali_batch_data, vali_batch_labels = vali_data.read_from_files()
            self.am_training = tf.placeholder(dtype=bool, shape=())
            self.batch_data = tf.cond(self.am_training,
                                      lambda: train_batch_data,
                                      lambda: vali_batch_data)
            self.batch_labels = tf.cond(self.am_training,
                                        lambda: train_batch_labels,
                                        lambda: vali_batch_labels)

            self._build_graph()
            self.saver = tf.train.Saver(tf.global_variables())
            # Build an initialization operation to run below
            init = tf.global_variables_initializer()
            sess.run(init)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)

            # This summary writer object helps write summaries on tensorboard
            summary_writer = tf.summary.FileWriter(FLAGS.log_dir + self.run_id)
            summary_writer.add_graph(sess.graph)

            train_error_list = []
            val_error_list = []

            print('Start training...')
            print('----------------------------------')

            train_steps_per_epoch = FLAGS.num_train_images // FLAGS.train_batch_size
            report_freq = train_steps_per_epoch

            train_steps = FLAGS.train_epoch * train_steps_per_epoch

            durations = []
            train_loss_list = []
            train_total_loss_list = []
            train_accuracy_list = []

            best_accuracy = 0

            for step in range(train_steps):
                tflearn.is_training(True, session=sess)

                start_time = time.time()

                _, summary_str, loss_value, total_loss, accuracy = sess.run(
                    [
                        self.train_op, self.summary_op, self.loss,
                        self.total_loss, self.accuracy
                    ],
                    feed_dict={self.am_training: True})

                duration = time.time() - start_time

                durations.append(duration)
                train_loss_list.append(loss_value)
                train_total_loss_list.append(total_loss)
                train_accuracy_list.append(accuracy)

                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'

                if step % report_freq == 0:
                    start_time = time.time()

                    summary_writer.add_summary(summary_str, step)

                    sec_per_report = np.sum(np.array(durations))
                    train_loss = np.mean(np.array(train_loss_list))
                    train_total_loss = np.mean(np.array(train_total_loss_list))
                    train_accuracy_value = np.mean(
                        np.array(train_accuracy_list))

                    train_loss_list = []
                    train_total_loss_list = []
                    train_accuracy_list = []
                    durations = []

                    train_summ = tf.Summary()
                    train_summ.value.add(tag="train_loss",
                                         simple_value=float(train_loss))
                    train_summ.value.add(tag="train_total_loss",
                                         simple_value=float(train_total_loss))
                    train_summ.value.add(tag="train_accuracy",
                                         simple_value=float(train_accuracy_value))

                    summary_writer.add_summary(train_summ, step)

                    vali_loss_value, vali_accuracy_value = self._full_validation(
                        vali_data, sess)

                    if vali_accuracy_value > best_accuracy:
                        best_accuracy = vali_accuracy_value

                        model_dir = os.path.join(FLAGS.log_dir, self.run_id,
                                                 'model')
                        if not os.path.isdir(model_dir):
                            os.mkdir(model_dir)
                        checkpoint_path = os.path.join(
                            model_dir,
                            'vali_{:.3f}'.format(vali_accuracy_value))

                        self.saver.save(sess,
                                        checkpoint_path,
                                        global_step=step)

                    vali_summ = tf.Summary()
                    vali_summ.value.add(tag="vali_loss",
                                        simple_value=float(vali_loss_value))
                    vali_summ.value.add(tag="vali_accuracy",
                                        simple_value=float(vali_accuracy_value))

                    summary_writer.add_summary(vali_summ, step)
                    summary_writer.flush()

                    vali_duration = time.time() - start_time

                    format_str = (
                        'Epoch %d, loss = %.4f, total_loss = %.4f, accuracy = %.4f, vali_loss = %.4f, vali_accuracy = %.4f (%.3f '
                        'sec/report)')
                    print(
                        format_str %
                        (step // report_freq, train_loss, train_total_loss,
                         train_accuracy_value, vali_loss_value,
                         vali_accuracy_value, sec_per_report + vali_duration))
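Example #5 switches between the training and validation input pipelines with a boolean placeholder and tf.cond. A minimal, self-contained TF 1.x sketch of that pattern (constants stand in for the real ReadData pipelines):

import tensorflow as tf

am_training = tf.placeholder(dtype=bool, shape=())
train_batch = tf.constant([[1.0, 2.0]])  # stand-in for train_data.read_from_files()
vali_batch = tf.constant([[3.0, 4.0]])   # stand-in for vali_data.read_from_files()
batch_data = tf.cond(am_training, lambda: train_batch, lambda: vali_batch)

with tf.Session() as sess:
    print(sess.run(batch_data, feed_dict={am_training: True}))   # [[1. 2.]]
    print(sess.run(batch_data, feed_dict={am_training: False}))  # [[3. 4.]]

Note that tf.cond builds both branches into the graph; with queue-based readers both pipelines stay alive, and the fed boolean only selects which tensors flow downstream.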
Example #6
def main(model_path, data_path, output_path, test_range=None):
    h5f = h5py.File(data_path, 'r')
    datatype = 'val'

    if not os.path.isdir(output_path):
        os.mkdir(output_path)

    log_f = open(os.path.join(output_path, 'result_info'), 'a')
    for dataset in ['val']:
        # Find patient ID in each dataset
        if datatype == 'Sunny':
            subjects = list(h5f[dataset + '/location/'].keys())
        else:
            subjects = list(h5f['location/'].keys())
            if test_range is None:
                test_range = range(len(subjects))  # fall back to all subjects
            test_range = list(test_range)
            subjects = [subjects[i] for i in test_range]
            print(subjects)

        # Read MRI slices for each patient
        sess = tf.Session(config=tf_config)
        save_root = FLAGS.save_root_for_prediction
        with sess.as_default():
            tflearn.is_training(False, session=sess)
            batch_data = tf.placeholder(tf.float32,
                                        shape=[1, 256, 256, 1],
                                        name='batch_data')
            batch_label = tf.placeholder(tf.float32,
                                         shape=[1, 256, 256, 3],
                                         name='batch_label')

            with tf.variable_scope(FLAGS.net_name):
                logits = getattr(network, FLAGS.net_name)(inputs=batch_data,
                                                          prob_fc=1,
                                                          prob_conv=1,
                                                          wd=0,
                                                          wd_scale=0,
                                                          training_phase=False)
                logits = logits[0]

            print(logits.shape)

            label = batch_label[:, :, :, 1:3]
            label = tf.greater(label, tf.ones(label.shape) * 0.5)
            label = tf.cast(label, tf.float32)

            al_true = getattr(network, 'adversatial_2')(batch_data,
                                                        label,
                                                        reuse=False)
            al_true = tf.nn.softmax(al_true)

            y_pred = tf.nn.softmax(logits[0])
            y_pred = tf.reshape(y_pred, [1, 256, 256, 3])
            maxvalue = tf.reduce_max(y_pred, axis=-1)
            print(y_pred.shape)
            y_pred = tf.equal(
                y_pred, tf.stack([maxvalue, maxvalue, maxvalue], axis=-1))
            pred = tf.cast(y_pred[:, :, :, 1:], tf.float32)

            al_false = getattr(network, 'adversatial_2')(batch_data,
                                                         pred,
                                                         reuse=True)
            al_false = tf.nn.softmax(al_false)

            saver = tf.train.Saver(tf.global_variables())
            saver.restore(sess, model_path)
            print('Model restored from ', model_path)

            for Pid in range(len(subjects)):
                path = '/home/spc/Documents/AD/' + str(Pid) + '/'
                if not os.path.isdir(path):
                    os.mkdir(path)
                print(Pid)
                Patient_No = subjects[Pid]
                # Read input, label and location for each patient
                if datatype == 'Sunny':
                    Input = h5f[dataset + '/input/%s/' % Patient_No][:, :, :, 0]
                    Input = Input / np.max(Input)
                    Label = h5f[dataset + '/label/%s/' % Patient_No][:]
                else:
                    Input = h5f['/input/%s/' % Patient_No][:, :, :, 0]
                    Input = Input / np.max(Input)
                    Label = h5f['/label/%s/' % Patient_No][:]

                # Make the prediction and turn the predicted scores into a final mask
                Input = Input[..., np.newaxis]

                num_batches = Input.shape[0]

                true_list = np.zeros(num_batches)
                false_list = np.zeros(num_batches)
                for step in range(num_batches):
                    true_label, false_label, Labels, Preds, Images = sess.run(
                        [al_true, al_false, label, pred, batch_data],
                        feed_dict={
                            batch_data: Input[step:step + 1],
                            batch_label: Label[step:step + 1]
                        })
                    true_list[step] = np.argmax(true_label)
                    false_list[step] = 1 - np.argmax(false_label)

                    name = '{0}_{1:.03}_{2:.03}_{3:.03}'.format(
                        step, Jaccard(Labels[0, :, :, 0], Preds[0, :, :, 0]),
                        true_label[0, 0], false_label[0, 0])
                    """
                    plt.imshow(np.hstack((Images[0,:,:,0],Images[0,:,:,0],Images[0,:,:,0])), 'gray', interpolation='none')
                    label_all = np.hstack((np.zeros(Images[0,:,:,0].shape), Labels[0,:,:,0], Preds[0,:,:,0]))
                    mask = label_all.astype(np.int32)
                    masked = np.ma.masked_where(mask == 0, mask)
                    plt.imshow(masked, interpolation='none', alpha=0.4,cmap='hsv')

                    plt.axis('off')
                    plt.savefig(output_path + str(name) + '.png', transparent=True)
                    """

                    concat_img = np.hstack(
                        (Images[0, :, :, 0], Labels[0, :, :, 0],
                         Preds[0, :, :, 0])) * 255
                    cv2.imwrite(path + str(name) + '.png', concat_img)
                    log_f.write('{0:.03}\t{1:.03}\t{2:.03}\n'.format(
                        Jaccard(Labels[0, :, :, 0], Preds[0, :, :, 0]),
                        true_label[0, 0], false_label[0, 0]))
                print(np.mean(true_list), np.mean(false_list))

            log_f.close()
            if datatype != 'Sunny':
                break
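The Jaccard helper called in Example #6 is not shown. For binary masks it is conventionally the intersection-over-union; a minimal NumPy version (an assumption, not the repository's exact implementation) could be:

import numpy as np

def Jaccard(label, pred, eps=1e-7):
    """Intersection-over-union of two binary masks."""
    label = label.astype(bool)
    pred = pred.astype(bool)
    intersection = np.logical_and(label, pred).sum()
    union = np.logical_or(label, pred).sum()
    return intersection / (union + eps)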
Example #7
	def train(self, **kwargs):
		ops.reset_default_graph()
		sess = tf.Session(config=self.tf_config)

		with sess.as_default():
			# Data Reading objects
			tflearn.is_training(True, session=sess)

			self.am_training = tf.placeholder(dtype=bool, shape=())
			self.prob_fc = tf.placeholder_with_default(0.5, shape=())
			self.prob_conv = tf.placeholder_with_default(0.5, shape=())

			data_fn = functools.partial(input_fn, data_dir=os.path.join(FLAGS.data_dir, FLAGS.set_id), 
				num_shards=1, batch_size=FLAGS.batch_size, use_distortion_for_training=True)

			self.batch_data, self.batch_labels = tf.cond(self.am_training, 
				lambda: data_fn(subset='train'),
				lambda: data_fn(subset='test'))

			self.batch_data = self.batch_data[0]
			self.batch_labels = self.batch_labels[0]


			if not kwargs:
				self.dict_widx = None
				self._build_graph()
				self.saver = tf.train.Saver(tf.global_variables())
				# Build an initialization operation to run below
				init = tf.global_variables_initializer()
				sess.run(init)

			else:
				self.dict_widx = kwargs['dict_widx']
				pruned_model = kwargs['pruned_model_path']

				self._build_graph()
				init = tf.global_variables_initializer()
				sess.run(init)

				self.saver = tf.train.Saver(tf.global_variables())
				self.saver.restore(sess, pruned_model)
				print('Pruned model restored from ', pruned_model)


			coord = tf.train.Coordinator()
			threads = tf.train.start_queue_runners(coord=coord)

			# This summary writer object helps write summaries on tensorboard
			summary_writer = tf.summary.FileWriter(FLAGS.log_dir+self.run_id)
			summary_writer.add_graph(sess.graph)

			train_error_list = []
			val_error_list = []

			print('Start training...')
			print('----------------------------------')


			train_steps_per_epoch = FLAGS.num_train_images // FLAGS.batch_size
			report_freq = train_steps_per_epoch

			train_steps = FLAGS.train_epoch * train_steps_per_epoch

			durations = []
			train_loss_list = []
			train_total_loss_list = []
			train_accuracy_list = []
			best_epoch = 0
			best_accuracy = 0
			best_loss = 1

			nparams = calculate_number_of_parameters(tf.trainable_variables())
			print(nparams)


			for step in range(train_steps):

				start_time = time.time()
				tflearn.is_training(True, session=sess)

				_, labels, _, summary_str, loss_value, total_loss, accuracy = sess.run(
					[self.batch_data, self.batch_labels, self.train_op, self.summary_op, self.loss, self.total_loss, self.accuracy], 
					feed_dict={self.am_training: True, self.prob_fc: FLAGS.keep_prob_fc, self.prob_conv: FLAGS.keep_prob_conv})

				tflearn.is_training(False, session=sess)
				duration = time.time() - start_time
				
				durations.append(duration)
				train_loss_list.append(loss_value)
				train_total_loss_list.append(total_loss)
				train_accuracy_list.append(accuracy)

				assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
				
				if step % report_freq == 0:
					start_time = time.time()

					summary_writer.add_summary(summary_str, step)

					sec_per_report = np.sum(np.array(durations))
					train_loss = np.mean(np.array(train_loss_list))
					train_total_loss = np.mean(np.array(train_total_loss_list))
					train_accuracy_value = np.mean(np.array(train_accuracy_list))

					train_loss_list = []
					train_total_loss_list = []
					train_accuracy_list = []
					durations = []

					train_summ = tf.Summary()
					train_summ.value.add(tag="train_loss", simple_value=float(train_loss))
					train_summ.value.add(tag="train_total_loss", simple_value=float(train_total_loss))
					train_summ.value.add(tag="train_accuracy", simple_value=float(train_accuracy_value))

					summary_writer.add_summary(train_summ, step)
                                                      
					vali_loss_value, vali_accuracy_value = self._full_validation(sess)
					
					if step % (report_freq * 50) == 0:
						epoch = step // report_freq  # current epoch number
						model_dir = os.path.join(FLAGS.log_dir, self.run_id, 'model')
						if not os.path.isdir(model_dir):
							os.mkdir(model_dir)
						checkpoint_path = os.path.join(model_dir, 'epoch_{}_acc_{:.3f}'.format(epoch, vali_accuracy_value))

						self.saver.save(sess, checkpoint_path, global_step=step)

					vali_summ = tf.Summary()
					vali_summ.value.add(tag="vali_loss", simple_value=float(vali_loss_value))
					vali_summ.value.add(tag="vali_accuracy", simple_value=float(vali_accuracy_value))


					summary_writer.add_summary(vali_summ, step)
					summary_writer.flush()

					vali_duration = time.time() - start_time

					format_str = ('Epoch %d, loss = %.4f, total_loss = %.4f, acc = %.4f, vali_loss = %.4f, val_acc = %.4f (%.3f sec/report)')
					print(format_str % (step//report_freq, train_loss, train_total_loss, train_accuracy_value, vali_loss_value, vali_accuracy_value, sec_per_report+vali_duration))
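Examples #7 and #9 create the dropout keep-probabilities with tf.placeholder_with_default(0.5, ...). The default only applies when nothing is fed, so validation code must feed 1.0 explicitly, as _full_validation does above. A self-contained TF 1.x sketch of the mechanism:

import tensorflow as tf

keep_prob = tf.placeholder_with_default(0.5, shape=())
x = tf.ones([1, 4])
dropped = tf.nn.dropout(x, keep_prob=keep_prob)

with tf.Session() as sess:
    print(sess.run(dropped))                              # dropout active (default 0.5)
    print(sess.run(dropped, feed_dict={keep_prob: 1.0}))  # dropout disabled for evaluation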
Example #8
def test(tf_config, test_path, probs):
    """ Test function. To evaluate the performance of trained models
    tf_config: tensorflow and gpu configurations
    test_path: url. the path of trained model for test
    probs: boolean. True: calculate the probability. False: not calcualte.
    """
    ops.reset_default_graph()
    sess = tf.Session(config=tf_config)
    save_root = FLAGS.save_root_for_prediction
    with sess.as_default():
        tflearn.is_training(False, session=sess)
        batch_data, batch_labels, subject, index = input_fn(
            data_dir=os.path.join(FLAGS.data_dir, FLAGS.set_id),
            num_shards=1,
            batch_size=FLAGS.test_batch_size,
            data_range=TEST_RANGE,
            subset='test')
        with tf.variable_scope(FLAGS.net_name):
            logits = getattr(network, FLAGS.net_name)(inputs=batch_data[0],
                                                      prob_fc=1,
                                                      prob_conv=1,
                                                      wd=0,
                                                      wd_scale=0,
                                                      training_phase=False)

        preclass = tf.nn.softmax(logits)
        prediction = tf.argmax(preclass, -1)
        label = tf.argmax(batch_labels[0], -1)
        if FLAGS.seg:
            accuracy = dice(labels=batch_labels[0],
                            logits=logits,
                            am_training=False)
        else:
            correct_prediction = tf.equal(prediction, label)
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        saver = tf.train.Saver(tf.global_variables())
        saver.restore(sess, test_path)
        print('Model restored from ', test_path)

        num_batches = FLAGS.num_test_images // FLAGS.test_batch_size

        accuracy_list = []
        prediction_array = []
        label_array = []

        for step in range(num_batches):
            if probs:  # TODO
                s, batch_prediction_array, batch_accuracy, batch_label_array = sess.run(
                    [batch_data, preclass, accuracy, label])
                prediction_array.append(batch_prediction_array)
                label_array = np.concatenate((label_array, batch_label_array))
                accuracy_list.append(batch_accuracy)

            else:
                images, annotations, sub_ind, ind, batch_prediction_array, batch_accuracy = sess.run(
                    [batch_data, label, subject, index, prediction, accuracy])

                accuracy_list.append(batch_accuracy)

        if probs:
            prediction_array = np.concatenate(prediction_array, axis=0)

        accuracy = np.mean(np.array(accuracy_list, dtype=np.float32))
        print('{:.3f}'.format(accuracy))
        return prediction_array, label_array, accuracy
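When FLAGS.seg is set, Example #8 scores segmentation quality with a dice() helper that is not shown. A common soft-Dice formulation it may resemble (an assumption, not the repository's exact definition):

import tensorflow as tf

def soft_dice(labels, logits, eps=1e-7):
    """Mean soft Dice coefficient between one-hot labels and softmax(logits)."""
    probs = tf.nn.softmax(logits)
    axes = list(range(1, labels.shape.ndims))  # reduce over all but the batch axis
    intersection = tf.reduce_sum(labels * probs, axis=axes)
    denom = tf.reduce_sum(labels + probs, axis=axes)
    return tf.reduce_mean((2.0 * intersection + eps) / (denom + eps))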
Example #9
    def train(self, **kwargs):
        """ Training body.
        if the filter prune is used, the input should be:
          dict_widx: the pruned weight matrix
          pruned_model_path: the path to the pruned model.
        """
        ops.reset_default_graph()
        sess = tf.Session(config=self.tf_config)

        with sess.as_default():
            # Data Reading objects
            tflearn.is_training(True, session=sess)

            self.am_training = tf.placeholder(dtype=bool, shape=())
            self.prob_fc = tf.placeholder_with_default(0.5, shape=())
            self.prob_conv = tf.placeholder_with_default(0.5, shape=())

            data_fn = functools.partial(input_fn,
                                        data_dir=os.path.join(
                                            FLAGS.data_dir, FLAGS.set_id),
                                        num_shards=FLAGS.num_gpus,
                                        batch_size=FLAGS.batch_size,
                                        use_distortion_for_training=True)

            self.batch_data, self.batch_labels, _, _ = tf.cond(
                self.am_training,
                lambda: data_fn(data_range=self.train_range, subset='train'),
                lambda: data_fn(data_range=self.vali_range, subset='test'))

            if FLAGS.status == 'scratch':
                self.dict_widx = None
                self._build_graph()
                self.saver = tf.train.Saver(tf.global_variables())
                # Build an initialization operation to run below
                init = tf.global_variables_initializer()
                sess.run(init)

            elif FLAGS.status == 'prune':
                self.dict_widx = kwargs['dict_widx']
                pruned_model = kwargs['pruned_model_path']

                self._build_graph()
                init = tf.global_variables_initializer()
                sess.run(init)

                all_name2var = dict(
                    zip(
                        map(lambda x: x.name.split(':')[0],
                            tf.global_variables()), tf.global_variables()))
                v_sel = []

                for name, var in all_name2var.items():
                    if 'Adam' not in name:  # skip Adam optimizer slot variables when restoring
                        v_sel.append(var)
                self.saver = tf.train.Saver(v_sel)
                self.saver.restore(sess, pruned_model)
                print('Pruned model restored from ', pruned_model)

            elif FLAGS.status == 'transfer':
                self._build_graph()
                init = tf.global_variables_initializer()
                sess.run(init)
                v1, v2 = get_trainable_variables(FLAGS.checkpoint_path)
                self.saver = tf.train.Saver(tf.global_variables())
                saver = tf.train.Saver(v2)
                saver.restore(sess, FLAGS.checkpoint_path)
                print('Model restored.')

            # This summary writer object helps write summaries on tensorboard
            summary_writer = tf.summary.FileWriter(FLAGS.log_dir + self.run_id)
            summary_writer.add_graph(sess.graph)

            train_error_list = []
            val_error_list = []

            print('Start training...')
            print('----------------------------------')

            train_steps_per_epoch = FLAGS.num_train_images // FLAGS.batch_size
            report_freq = train_steps_per_epoch

            train_steps = FLAGS.train_epoch * train_steps_per_epoch

            durations = []
            train_loss_list = []
            train_total_loss_list = []
            train_accuracy_list = []
            best_epoch = 0
            best_accuracy = 0
            best_loss = 1

            nparams = op_utils.calculate_number_of_parameters(
                tf.trainable_variables())
            print(nparams)

            for step in range(train_steps):

                start_time = time.time()
                tflearn.is_training(True, session=sess)

                data, labels, _, summary_str, loss_value, total_loss, accuracy = sess.run(
                    [
                        self.batch_data, self.batch_labels, self.train_op,
                        self.summary_op, self.loss, self.total_loss,
                        self.accuracy
                    ],
                    feed_dict={
                        self.am_training: True,
                        self.prob_fc: FLAGS.keep_prob_fc,
                        self.prob_conv: FLAGS.keep_prob_conv
                    })

                tflearn.is_training(False, session=sess)
                duration = time.time() - start_time
                durations.append(duration)
                train_loss_list.append(loss_value)
                train_total_loss_list.append(total_loss)
                train_accuracy_list.append(accuracy)

                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'

                if step % report_freq == 0:
                    start_time = time.time()

                    summary_writer.add_summary(summary_str, step)

                    sec_per_report = np.sum(np.array(durations))
                    train_loss = np.mean(np.array(train_loss_list))
                    train_total_loss = np.mean(np.array(train_total_loss_list))
                    train_accuracy_value = np.mean(
                        np.array(train_accuracy_list))

                    train_loss_list = []
                    train_total_loss_list = []
                    train_accuracy_list = []
                    durations = []

                    train_summ = tf.Summary()
                    train_summ.value.add(tag="train_loss",
                                         simple_value=float(train_loss))
                    train_summ.value.add(tag="train_total_loss",
                                         simple_value=float(train_total_loss))
                    train_summ.value.add(tag="train_accuracy",
                                         simple_value=float(train_accuracy_value))

                    summary_writer.add_summary(train_summ, step)

                    vali_loss_value, vali_accuracy_value = self._full_validation(
                        sess)

                    if step % (report_freq * FLAGS.save_epoch) == 0:
                        epoch = step // (report_freq * FLAGS.save_epoch)
                        model_dir = os.path.join(FLAGS.log_dir, self.run_id,
                                                 'model')
                        if not os.path.isdir(model_dir):
                            os.mkdir(model_dir)
                        checkpoint_path = os.path.join(
                            model_dir, 'epoch_{}_acc_{:.3f}'.format(
                                epoch, vali_accuracy_value))

                        self.saver.save(sess,
                                        checkpoint_path,
                                        global_step=step)

                    vali_summ = tf.Summary()
                    vali_summ.value.add(tag="vali_loss",
                                        simple_value=float(vali_loss_value))
                    vali_summ.value.add(tag="vali_accuracy",
                                        simple_value=float(vali_accuracy_value))

                    summary_writer.add_summary(vali_summ, step)
                    summary_writer.flush()

                    vali_duration = time.time() - start_time

                    format_str = (
                        'Epoch %d, loss = %.4f, total_loss = %.4f, acc = %.4f, vali_loss = %.4f, val_acc = %.4f (%.3f '
                        'sec/report)')
                    print(
                        format_str %
                        (step // report_freq, train_loss, train_total_loss,
                         train_accuracy_value, vali_loss_value,
                         vali_accuracy_value, sec_per_report + vali_duration))
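op_utils.calculate_number_of_parameters (and the bare calculate_number_of_parameters in Example #7) is not shown either. A typical implementation (an assumption, not the repositories' exact code) multiplies out each trainable variable's static shape:

import numpy as np

def calculate_number_of_parameters(variables):
    """Total number of scalar weights across a list of tf.Variable objects."""
    return int(sum(np.prod(v.get_shape().as_list()) for v in variables))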