Example #1
def finalize(self, shapes, batchsize):
    assert not hasattr(self, 'output_shapes')
    self.output_shapes = shapes
    self.batchsize = batchsize
    # prepend a batch dimension to each per-example output shape
    ishapes = [[None] + list(s) for s in self.output_shapes]
    itypes = (tf.float32, tf.float32)
    # feedable handle: the caller picks which dataset's iterator to read from
    handle = tf.placeholder(tf.string, shape=[])
    iterator = Iterator.from_string_handle(handle, itypes, tuple(ishapes))
    getter = iterator.get_next()
    return getter, handle
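A minimal usage sketch for the method above, assuming a hypothetical pipeline object of the enclosing class and a tf.data dataset yielding (float32, float32) pairs (all names and shape values here are illustrative, not from the example):

# Hypothetical wiring for finalize(): feed a real iterator's handle into getter.
getter, handle = pipeline.finalize(shapes=[(32, 32), (10,)], batchsize=64)
train_iterator = train_dataset.make_one_shot_iterator()
with tf.Session() as sess:
    train_handle = sess.run(train_iterator.string_handle())
    x, y = sess.run(getter, feed_dict={handle: train_handle})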
Example #2
def get_dataset_tensors(args):
    with tf.device('/cpu:0'), tf.variable_scope('input_pipeline'):
        # TODO move this to hem.init()
        # find all dataset plugins available
        p = get_dataset(args.dataset)
        # ensure that the dataset exists
        if not p.check_prepared_datasets(args.dataset_dir):
            if not p.check_raw_datasets(args.raw_dataset_dir):
                print('Downloading dataset...')
                # TODO: datasets should be able to be marked as non-downloadable
                p.download(args.raw_dataset_dir)
            print('Converting to tfrecord...')
            p.convert_to_tfrecord(args.raw_dataset_dir, args.dataset_dir)

        # load the dataset
        datasets = p.get_datasets(args)
        dataset_iterators = {}
        # placeholder for the string handle that selects the active split
        handle = tf.placeholder(tf.string, shape=[])
        # add a dataset for train, validation, and testing
        for k, v in datasets.items():
            # skip test set if not needed
            if len(args.test_epochs) == 0 and k == 'test':
                continue
            d = v[0]
            # count the records in this split's tfrecord file
            n = sum(1 for _ in tf.python_io.tf_record_iterator(v[1]))
            cache_fn = '{}.cache.{}'.format(args.dataset, k)
            if args.cache_dir:
                d = d.cache(os.path.join(args.cache_dir, cache_fn))
            else:
                d = d.cache()
            d = d.repeat()
            d = d.shuffle(buffer_size=args.buffer_size, seed=args.seed)
            d = d.batch(args.batch_size * args.n_gpus)
            x_iterator = d.make_initializable_iterator()
            dataset_iterators[k] = {
                'x': x_iterator,
                'n': n,
                'batches': int(n / (args.batch_size * args.n_gpus)),
                'handle': x_iterator.string_handle()
            }
        # feedable iterator that will swap between train/test/val
        # (all splits share the same structure, so the last dataset's
        # types and shapes are used)
        iterator = Iterator.from_string_handle(handle, d.output_types,
                                               d.output_shapes)
        return iterator.get_next(), handle, dataset_iterators
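A short driver sketch for the return values above; the split name and session handling are illustrative, not part of the source:

# Hypothetical driver for get_dataset_tensors(): pull one batch from the train split.
batch_op, handle, iterators = get_dataset_tensors(args)
with tf.Session() as sess:
    sess.run(iterators['train']['x'].initializer)          # initializable iterator
    train_handle = sess.run(iterators['train']['handle'])  # evaluate the handle tensor
    batch = sess.run(batch_op, feed_dict={handle: train_handle})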
Example #3
def train(hps, design):
    """Training loop."""
    train_records = _get_tfrecord_files_from_dir(
        FLAGS.train_data_path)  #get tfrecord files for train
    train_iterator = petct_input.build_input(train_records, hps.batch_size,
                                             hps.num_epochs, FLAGS.mode)
    train_iterator_handle = train_iterator.string_handle()

    if FLAGS.val_data_path != '':  # skip validation if no path
        val_records = _get_tfrecord_files_from_dir(
            FLAGS.val_data_path)  # get tfrecord files for val
        val_iterator = petct_input.build_input(val_records, hps.batch_size,
                                               hps.num_epochs, 'valid')
        val_iterator_handle = val_iterator.string_handle()

    handle = tf.placeholder(tf.string, shape=[], name='data')
    iterator = Iterator.from_string_handle(handle, train_iterator.output_types,
                                           train_iterator.output_shapes)
    ct, pt, ctlb, ptlb, bglb = iterator.get_next()

    model = fuse_cnn_petct.FuseNet(hps, design, ct, pt, ctlb, ptlb, bglb,
                                   FLAGS.mode)
    model.build_cross_modal_model()

    # for use in loading later
    #tf.get_collection('model')
    #tf.add_to_collection('model',model)

    # put get metrics ops here for train and val
    with tf.variable_scope('metrics'):
        tr_summary_op, tr_precision_op, tr_recall_op, tr_accuracy_op, tr_rmse_op = get_metrics_ops(
            model, 'train')
        val_summary_op, val_precision_op, val_recall_op, val_accuracy_op, val_rmse_op = get_metrics_ops(
            model, 'valid')

    # needed for input handlers
    g_init_op = tf.global_variables_initializer()
    l_init_op = tf.local_variables_initializer()

    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True, device_count={'GPU': 1})) as mon_sess:
        # Need a saver to save and restore all the variables.
        saver = tf.train.Saver()

        if FLAGS.DEBUG:
            print('ENABLING DEBUG')
            mon_sess = tf_debug.LocalCLIDebugWrapperSession(mon_sess)
            mon_sess.add_tensor_filter("has_inf_or_nan",
                                       tf_debug.has_inf_or_nan)

        training_handle = mon_sess.run(train_iterator_handle)
        if FLAGS.val_data_path != '':  # skip validation if no path
            validation_handle = mon_sess.run(val_iterator_handle)

        train_writer = tf.summary.FileWriter(FLAGS.log_root + '/train',
                                             mon_sess.graph)

        if FLAGS.val_data_path != '':  # skip validation if no path
            valid_writer = tf.summary.FileWriter(FLAGS.log_root + '/valid')

        mon_sess.run([g_init_op, l_init_op])

        summary = None
        step = None
        val_summary = None
        val_step = None  # set when the validation branch runs; used in the final summary write
        while True:
            try:
                ## FIRST RUN TRAINING OP BASED ON OUTPUT STYLE
                if FLAGS.output_style == fuse_cnn_petct.STYLE_SPLIT:
                    # get PET and CT recons separately
                    _, summary, step, loss, p, r, a, e, cts, pts, trcts, trpts, trbgs, recon_cts, recon_pts, ct_preds, pt_preds = mon_sess.run(
                        [
                            model.train_op, tr_summary_op, model.global_step,
                            model.cost, tr_precision_op, tr_recall_op,
                            tr_accuracy_op, tr_rmse_op, model.ct, model.pt,
                            model.lbct, model.lbpt, model.lbbg, model.ct_pred,
                            model.pt_pred, model.ct_probabilities,
                            model.pt_probabilities
                        ],
                        feed_dict={
                            handle: training_handle,
                            model.is_training: True
                        })
                elif FLAGS.output_style == fuse_cnn_petct.STYLE_SINGLE:
                    # get PET and CT recons together
                    _, summary, step, loss, p, r, a, e, cts, pts, trallpos, trbgs, recon_all, all_preds = mon_sess.run(
                        [
                            model.train_op, tr_summary_op, model.global_step,
                            model.cost, tr_precision_op, tr_recall_op,
                            tr_accuracy_op, tr_rmse_op, model.ct, model.pt,
                            model.lb_pos_gt, model.lbbg, model.all_pred,
                            model.all_probabilities
                        ],
                        feed_dict={
                            handle: training_handle,
                            model.is_training: True
                        })

                if step % FLAGS.train_iter == 0:
                    print(
                        '[TRAIN] STEP: %d, LOSS: %.5f, PRECISION: %.5f, RECALL: %.5f, ACCURACY: %.5f, RMSE: %.5f'
                        % (step, loss, p, r, a, e))
                    train_writer.add_summary(summary, step)
                    train_writer.flush()

                if FLAGS.IMSAVE > 0:
                    if step % FLAGS.IMSAVE == 0:
                        print('SAVING IMAGES')
                        if FLAGS.output_style == fuse_cnn_petct.STYLE_SPLIT:
                            _saveImages(hps.batch_size,
                                        step,
                                        cts,
                                        pts,
                                        trcts=trcts,
                                        trpts=trpts,
                                        trbgs=trbgs,
                                        recon_cts=recon_cts,
                                        recon_pts=recon_pts,
                                        ct_preds=ct_preds,
                                        pt_preds=pt_preds)
                        elif FLAGS.output_style == fuse_cnn_petct.STYLE_SINGLE:
                            _saveImages(hps.batch_size,
                                        step,
                                        cts,
                                        pts,
                                        trallpos=trallpos,
                                        trbgs=trbgs,
                                        recon_all=recon_all,
                                        all_preds=all_preds)

                if FLAGS.val_data_path != '':  # skip validation if no path
                    if step % FLAGS.val_iter == 0:
                        _, val_summary, loss, p, r, a, e = mon_sess.run(
                            [
                                model.val_op, val_summary_op, model.cost,
                                val_precision_op, val_recall_op,
                                val_accuracy_op, val_rmse_op
                            ],
                            feed_dict={
                                handle: validation_handle,
                                model.is_training: False
                            })
                        val_step = step
                        print(
                            '[VALID] STEP: %d, LOSS: %.5f, PRECISION: %.5f, RECALL: %.5f, ACCURACY: %.5f, RMSE: %.5f'
                            % (step, loss, p, r, a, e))
                        valid_writer.add_summary(val_summary, step)
                        valid_writer.flush()

                if step % FLAGS.chkpt_iter == 0:
                    save_loc = FLAGS.log_root + '/' + FLAGS.chkpt_file + str(
                        step) + '.ckpt'
                    save_path = saver.save(mon_sess, save_loc)
                    print('Model saved in path: %s' % save_path)

            except tf.errors.OutOfRangeError:
                print('OUT OF DATA - ENDING')
                # now finished training (either train or validation has run out)
                train_writer.add_summary(summary, step)
                train_writer.flush()
                if FLAGS.val_data_path != '':  # skip validation if no path
                    valid_writer.add_summary(val_summary, val_step)
                    valid_writer.flush()
                save_loc = FLAGS.log_root + '/' + FLAGS.chkpt_file + str(
                    step) + '-end.ckpt'
                save_path = saver.save(mon_sess, save_loc)
                print('Model saved in path: %s' % save_path)
                break
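A note on the pattern above: train_iterator.string_handle() and val_iterator.string_handle() are graph ops, so they are evaluated once up front and the resulting strings (training_handle, validation_handle) are reused in every feed_dict. Swapping between training and validation then only means feeding a different string, which is exactly what Iterator.from_string_handle is for.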
Example #4

if __name__ == '__main__':
    batch_size = 5
    attr_label_num = 30
    img_loader = image_bleach.ImageLoader(
        [global_configs.pic2attr_tfrecord_train_path], batch_size, 100, 40000)
    train_dataset = img_loader.launch_tfrecord_dataset()
    train_iterator = train_dataset.make_one_shot_iterator()

    # ========================= Data loading =========================

    # =================== Feeding via handle (feedable) ===================
    # Build a feedable string-handle placeholder; either the training-set
    # handle or the validation-set handle can be passed in through it.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = Iterator.from_string_handle(handle, train_dataset.output_types,
                                           train_dataset.output_shapes)
    pic_name_batch, pic_class_batch, attr_label_batch, img_batch = \
        iterator.get_next()
    # What comes out of the iterator is 2-D, but the id, effect_len, and
    # label used downstream need to be 1-D arrays, so reshape them.
    pic_name_batch = tf.reshape(pic_name_batch, [batch_size])
    pic_class_batch = tf.reshape(pic_class_batch, [batch_size])
    attr_label_batch = tf.reshape(attr_label_batch,
                                  [batch_size, attr_label_num])
    # ==================/ Feeding via handle (feedable) /==================

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        sess.run(tf.global_variables_initializer())  # initialize all variables
        # Get the string handles for the training and validation sets; they are fed to the model later
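        # --- The original example is truncated here; a plausible continuation
        # --- under the same feedable-handle pattern (hypothetical) would be:
        train_handle = sess.run(train_iterator.string_handle())
        names, classes, attrs, imgs = sess.run(
            [pic_name_batch, pic_class_batch, attr_label_batch, img_batch],
            feed_dict={handle: train_handle})
        coord.request_stop()
        coord.join(threads)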
Example #5
	def __init__(self,
					sess,
					batch_size,
					stage_of_development,
					learning_rate_decay_factor,
					type_of_model,
					summary_dir,
					experiment_folder, 
					type_of_optimizer,
					num_of_classes,
					total_num_of_training_examples,
					dropout=0.5,
					model_path=None,
					beta1=0.9,
					beta2=0.999,
					min_bins=None,
					max_bins=None,
					list_of_tfrecords_for_training=None,
					list_of_tfrecords_for_evaluation=None,
					training_with_eval=False,
					dict_of_filePath_to_num_of_examples_in_tfrecord=None):

		self.training_batch_size = 0
		self.batch_size = batch_size
		self.list_of_tr_datasets = []
		self.list_of_eval_datasets = []
		print("Training with dev", training_with_eval)
		print(sorted(list(dict_of_filePath_to_num_of_examples_in_tfrecord.keys())))

		if stage_of_development == "training":
			for tfrecord_for_training_example_ in list_of_tfrecords_for_training:
				current_tr_data = tf.contrib.data.TFRecordDataset(tfrecord_for_training_example_)
				if type_of_model == 'VGG':
					current_tr_data = current_tr_data.map(parse_example_vgg)
				elif type_of_model == 'ResNet':
					current_tr_data = current_tr_data.map(parse_example_ResNet)
				else:
					current_tr_data = current_tr_data.map(parse_example)
				current_tr_data = current_tr_data.shuffle(buffer_size=20000)
				current_tr_data = current_tr_data.repeat()
				# per-file batch size, proportional to this file's share of all training examples
				current_tfrecord_batch_size = math.ceil(
					float(dict_of_filePath_to_num_of_examples_in_tfrecord[tfrecord_for_training_example_])
					/ float(total_num_of_training_examples) * self.batch_size)
				self.training_batch_size += current_tfrecord_batch_size
				current_tr_data = current_tr_data.batch(current_tfrecord_batch_size)
				print(tfrecord_for_training_example_, dict_of_filePath_to_num_of_examples_in_tfrecord[tfrecord_for_training_example_], current_tfrecord_batch_size)
				self.list_of_tr_datasets.append(current_tr_data)

		if stage_of_development == "training":
			self.batch_size = self.training_batch_size

		self.single_eval_data = None
		if stage_of_development != "training":
			self.single_eval_data = tf.contrib.data.TFRecordDataset(list_of_tfrecords_for_evaluation)
			if type_of_model == 'VGG':
				self.single_eval_data = self.single_eval_data.map(parse_example_vgg)
			elif type_of_model == 'ResNet':
				self.single_eval_data = self.single_eval_data.map(parse_example_ResNet)
			else:
				self.single_eval_data = self.single_eval_data.map(parse_example)
			self.single_eval_data = self.single_eval_data.shuffle(buffer_size=10000)
			self.single_eval_data = self.single_eval_data.repeat(1)
			self.single_eval_data = self.single_eval_data.batch(self.batch_size)


			#for tfrecord_for_evaluation_example_ in list_of_tfrecords_for_evaluation:
			#	current_eval_data = tf.contrib.data.TFRecordDataset(tfrecord_for_evaluation_example_)
			#	if type_of_model == 'VGG':
			#		current_eval_data = current_eval_data.map(parse_example_vgg)
			#	elif type_of_model == 'ResNet':
			#		current_eval_data = current_eval_data.map(parse_example_ResNet)
			#	else:
			#		current_eval_data = current_eval_data.map(parse_example)
			#	current_eval_data = current_eval_data.shuffle(buffer_size=10000)
			#	current_eval_data = current_eval_data.repeat(1)
			#	current_eval_data = current_eval_data.batch(self.batch_size)
			#	self.list_of_eval_datasets.append(current_eval_data)

		self.list_of_handles = []
		self.list_of_iterators = []
		self.list_of_batch_imgs = []
		self.list_of_batch_labels = []
		self.list_of_batch_imgs_and_batch_labels = []

		if stage_of_development == "training":
			for idx_ in range(len(list_of_tfrecords_for_training)):
				self.list_of_handles.append(tf.placeholder(tf.string, shape=[]))
				self.list_of_iterators.append(Iterator.from_string_handle(self.list_of_handles[idx_], self.list_of_tr_datasets[0].output_types, self.list_of_tr_datasets[0].output_shapes))
				batched_imgs, batched_labels = self.list_of_iterators[idx_].get_next()
				if type_of_model == 'DehazeNet':
					self.list_of_batch_imgs.append(tf.reshape(batched_imgs, [-1, 224+15, 224+15, 3]))
				else:
					self.list_of_batch_imgs.append(tf.reshape(batched_imgs, [-1, 224, 224, 3]))
				self.list_of_batch_labels.append(tf.reshape(batched_labels, [-1, 1]))
		else:
			self.single_eval_handle = tf.placeholder(tf.string, shape=[])
			self.single_eval_iterator = Iterator.from_string_handle(self.single_eval_handle, self.single_eval_data.output_types, self.single_eval_data.output_shapes)
			self.eval_batched_imgs, self.eval_batched_labels = self.single_eval_iterator.get_next()
			if type_of_model == 'DehazeNet':
				self.eval_batched_imgs = tf.reshape(self.eval_batched_imgs, [-1, 224+15, 224+15, 3])
			else:
				self.eval_batched_imgs = tf.reshape(self.eval_batched_imgs, [-1, 224, 224, 3])
			self.eval_batched_labels = tf.reshape(self.eval_batched_labels, [-1, 1])


			#for idx_ in range(len(list_of_tfrecords_for_evaluation)):
			#	self.list_of_handles.append(tf.placeholder(tf.string, shape=[]))
			#	self.list_of_iterators.append(Iterator.from_string_handle(self.list_of_handles[idx_], self.list_of_eval_datasets[0].output_types, self.list_of_eval_datasets[0].output_shapes))
			#	batched_imgs, batched_labels = self.list_of_iterators[idx_].get_next()
			#	if type_of_model == 'DehazeNet':
			#		self.list_of_batch_imgs.append(tf.reshape(batched_imgs, [-1, 224+15, 224+15, 3]))
			#	else:
			#		self.list_of_batch_imgs.append(tf.reshape(batched_imgs, [-1, 224, 224, 3]))
			#	self.list_of_batch_labels.append(tf.reshape(batched_labels, [-1, 1]))

		self.list_of_training_iterators = []
		# reused below for the underlying one-shot iterator, whose string handle
		# is what gets fed into single_eval_handle during evaluation
		self.single_eval_iterator = None

		if stage_of_development == "training":
			for tr_dataset_example_ in self.list_of_tr_datasets:
				training_iterator = tr_dataset_example_.make_one_shot_iterator()
				self.list_of_training_iterators.append(training_iterator)

		if stage_of_development == "evaluation":
			self.single_eval_iterator = self.single_eval_data.make_one_shot_iterator()

		self.row_indices = tf.placeholder(tf.int32, (self.batch_size,))
		self.row_indices_reshaped = tf.reshape(self.row_indices, [self.batch_size, 1])


		if stage_of_development == "training":
			self.batch_inputs = tf.gather_nd(tf.concat(self.list_of_batch_imgs, 0), self.row_indices_reshaped)
			self.batch_targets = tf.gather_nd(tf.concat(self.list_of_batch_labels, 0), self.row_indices_reshaped)
		else:
			self.batch_inputs = tf.gather_nd(self.eval_batched_imgs, self.row_indices_reshaped)
			self.batch_targets = tf.gather_nd(self.eval_batched_labels, self.row_indices_reshaped)

		self.stage_of_development = stage_of_development
		self.model_path = model_path
		self.type_of_optimizer = type_of_optimizer
		self.model = None
		self.beta1 = beta1
		self.beta2 = beta2
		self.pm_values = self.batch_targets  # same rows as batch_targets; avoids concatenating an empty list in evaluation mode

		#if num_of_classes > 1:
		#	discrete_targets = tf.cast(self.batch_targets, dtype=tf.float32)
		#	discrete_targets = tf.reshape(discrete_targets, [-1, 1])
		#	min_bins = tf.reshape(tf.cast(min_bins, dtype=tf.float32), [1, -1])
		#	max_bins = tf.reshape(tf.cast(max_bins, dtype=tf.float32), [1, -1])
		#	c_1 =  tf.subtract(discrete_targets, min_bins)
		#	c_1 = tf.add(tf.cast(c_1 < 0, c_1.dtype) * 10000, tf.nn.relu(c_1))
		#	c_2 =  tf.subtract(discrete_targets * -1, max_bins)
		#	c_2 = tf.add(tf.cast(c_2 < 0, c_2.dtype) * 10000, tf.nn.relu(c_2))
		#	c = tf.add(c_1, c_2)
		#	self.batch_targets = tf.reshape(tf.argmin(c, 1), [-1, 1])

		self.is_training = tf.placeholder(tf.bool, shape=[])

		if type_of_model == 'DehazeNet':
			self.model = DehazeNetModel(sess,
										self.batch_inputs,
										self.batch_targets,
										self.stage_of_development,
										num_of_classes,
										min_bins=min_bins,
										max_bins=max_bins)
		elif type_of_model == "VGG":
			self.model = AQPVGGModel(sess,
									self.batch_inputs,
									self.batch_targets,
									self.stage_of_development,
									self.model_path,
									num_of_classes,
									self.is_training,
									min_bins=min_bins,
									max_bins=max_bins)
		elif type_of_model == "ResNet":
			self.model = AQPResNetModel(sess,
										self.batch_inputs,
										self.batch_targets,
										self.stage_of_development,
										self.model_path,
										num_of_classes,
										min_bins=min_bins,
										max_bins=max_bins)    
		else:
			self.model = SimpleCNNModel(sess, self.batch_inputs, self.batch_targets, self.stage_of_development)

		def return_predictions():
			return self.model.predictions
		def return_validation_predictions():
			return self.model.validation_predictions

		def return_MAE():
			return tf.reduce_mean(tf.abs(tf.subtract(self.model.predictions, self.model.labels)))
		def return_validation_MAE():
			return tf.reduce_mean(tf.abs(tf.subtract(self.model.validation_predictions, self.model.labels)))

		def return_MSE():
			return tf.reduce_mean(tf.square(tf.subtract(self.model.predictions, self.model.labels)))
		def return_validation_MSE():
			return tf.reduce_mean(tf.square(tf.subtract(self.model.validation_predictions, self.model.labels)))

		def return_MSLE():
			return tf.reduce_mean(tf.square(tf.subtract(tf.log(tf.add(self.model.predictions, 1.0)), tf.log(tf.add(self.model.labels, 1.0)))))
		def return_validation_MSLE():
			return tf.reduce_mean(tf.square(tf.subtract(tf.log(tf.add(self.model.validation_predictions, 1.0)), tf.log(tf.add(self.model.labels, 1.0)))))

		def return_R2_score():
			numerator = tf.reduce_sum(tf.square(tf.subtract(self.model.labels, self.model.predictions))) #Unexplained Error
			denominator = tf.reduce_sum(tf.square(tf.subtract(self.model.labels, tf.reduce_mean(self.model.labels)))) # Total Error
			return tf.subtract(1.0, tf.divide(numerator, denominator))
		def return_validation_R2_score():
			numerator = tf.reduce_sum(tf.square(tf.subtract(self.model.labels, self.model.validation_predictions))) #Unexplained Error
			denominator = tf.reduce_sum(tf.square(tf.subtract(self.model.labels, tf.reduce_mean(self.model.labels)))) # Total Error
			return tf.subtract(1.0, tf.divide(numerator, denominator))

		self.learning_rate = tf.placeholder(tf.float32, shape=[])
		self.partial_learning_rate = tf.placeholder(tf.float32, shape=[])
		self.global_step = tf.Variable(0, trainable=False)
		self.predictions  = tf.cond(self.is_training, return_predictions, return_validation_predictions)
		self.MAE_ = tf.cond(self.is_training, return_MAE, return_validation_MAE)
		self.R2_score_ = tf.cond(self.is_training, return_R2_score, return_validation_R2_score)
		self.MSE_ = tf.cond(self.is_training, return_MSE, return_validation_MSE)
		self.MSLE_ = tf.cond(self.is_training, return_MSLE, return_validation_MSLE)
		
		if self.stage_of_development == "training":
			self.global_eval_step = tf.Variable(0, trainable=False)
			self.global_eval_update_step_variable = tf.assign(self.global_eval_step, self.global_eval_step+1)
			tf.summary.scalar('MAE', self.MAE_)
			tf.summary.scalar('MSE', self.MSE_)
			tf.summary.scalar('MSLE', self.MSLE_)
			tf.summary.scalar('R2 Coefficient', self.R2_score_)

		if self.stage_of_development == "training" or self.stage_of_development == "resume_training":
			if type_of_model == 'DehazeNet' or type_of_model == 'VGG':
				partial_opt = None
				if self.type_of_optimizer == 'adam':
					partial_opt = tf.train.AdamOptimizer(learning_rate=self.partial_learning_rate, beta1=self.beta1, beta2=self.beta2)
				else:
					partial_opt = tf.train.GradientDescentOptimizer(learning_rate=self.partial_learning_rate)
				partial_gradient = tf.gradients(self.MAE_, self.model.variables_trained_from_scratch)
				self.partial_train_op = partial_opt.apply_gradients(zip(partial_gradient, self.model.variables_trained_from_scratch), global_step=self.global_step)

				if self.type_of_optimizer == 'adam':
					full_opt = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=self.beta1, beta2=self.beta2)
				else:
					full_opt = tf.train.GradientDescentOptimizer(learning_rate=self.learning_rate)
				full_gradient = tf.gradients(self.MAE_, self.model.all_variables)
				self.train_op = full_opt.apply_gradients(zip(full_gradient, self.model.all_variables), global_step=self.global_step)
			elif type_of_model == 'ResNet':
				partial_opt = tf.train.AdamOptimizer(learning_rate=self.partial_learning_rate, beta1=self.beta1, beta2=self.beta2)
				self.partial_train_op = slim.learning.create_train_op(self.MAE_, partial_opt, global_step=self.global_step, variables_to_train=self.model.variables_trained_from_scratch)

				full_opt = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=self.beta1, beta2=self.beta2)
				self.train_op = slim.learning.create_train_op(self.MAE_, full_opt, global_step=self.global_step, variables_to_train=self.model.all_variables)
			else:
				if self.type_of_optimizer == 'adam':
					opt = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=self.beta1, beta2=self.beta2)
				else:
					opt = tf.train.GradientDescentOptimizer(learning_rate=self.learning_rate)
				gradient = tf.gradients(self.MAE_, self.model.all_variables)
				self.train_op = opt.apply_gradients(zip(gradient, self.model.all_variables), global_step=self.global_step)

		self.merged = tf.summary.merge_all()
		self.train_writer = tf.summary.FileWriter(summary_dir +  '/' + experiment_folder + '/train', sess.graph)
		self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=4)
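A note on the design above: each training tfrecord gets its own dataset and feedable handle, with a per-file batch size proportional to that file's share of the total examples; the per-file batches are concatenated, and row_indices then selects rows from the combined batch via tf.gather_nd. A minimal single-step sketch under that reading (loader, sess, and the hyperparameter values are illustrative, not from the source):

# Hypothetical training step for the class above; assumes numpy imported as np.
feed = {loader.is_training: True,
        loader.learning_rate: 1e-4,
        loader.partial_learning_rate: 1e-4,
        loader.row_indices: np.random.permutation(loader.batch_size).astype(np.int32)}
for ph, it in zip(loader.list_of_handles, loader.list_of_training_iterators):
    feed[ph] = sess.run(it.string_handle())  # in a real loop, cache these once
_, mae = sess.run([loader.train_op, loader.MAE_], feed_dict=feed)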