Example #1
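The examples below target the TensorFlow 1.x session API. They assume the module-level imports sketched here, plus a PARAMS configuration object and a label_order list defined elsewhere in the repository; the exact import statements for the project-specific modules are an assumption and depend on the repository layout.

# Imports the snippets rely on (a sketch; TensorFlow 1.x API).
import csv
import datetime
import os

import numpy as np
import tensorflow as tf

# Project-specific modules referenced below; their import paths are assumed:
# import taxonomy_model
# import GraspLoader, Grasp_csv_Loader, Grasp_csv_Loader_v2
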
def train():
	grasp_loader = GraspLoader.csv_loader(data_path=PARAMS.csv_path,
											csv_filename=PARAMS.csv_filename,
											save_folder=PARAMS.save_folder,
											is_saved=True,
											resize_image_size=PARAMS.image_size,
											train_subject_list=PARAMS.train_list,
											val_subject_list=PARAMS.val_list,
											test_subject_list=PARAMS.test_list,
											label_order=label_order,
											batch_size=PARAMS.batch_size,
											trans_range=PARAMS.trans_range,
											rotate_range=PARAMS.rotate_range,
											max_hue_delta=PARAMS.max_hue_delta,
											saturation_range=PARAMS.saturation_range,
											max_bright_delta=PARAMS.max_bright_delta,
											max_contrast_delta=PARAMS.max_contrast_delta,
											is_training=True
											)

	next_element, training_init_op, validation_init_op, test_init_op = \
		grasp_loader.initialization_dataset()

	batch_image, batch_label = next_element

	# Define Model
	model = taxonomy_model.taxonomy_model(inputs=batch_image,
											true_labels=batch_label,
											input_size=PARAMS.image_size,
											batch_size=PARAMS.batch_size,
											taxonomy_nums=len(grasp_loader.label_order), #len(grasp_loader.classes_numbers),
											taxonomy_classes=grasp_loader.classes_numbers,
											resnet_version=PARAMS.resnet_version,
											resnet_pretrained_path=PARAMS.resnet_path,
											resnet_exclude=PARAMS.resnet_exclude,
											trainable_scopes=PARAMS.trainable_scopes,
											extra_global_feature=True,
											taxonomy_loss=True,
											learning_rate=PARAMS.learning_rate,
											num_samples=len(grasp_loader.train_info),
											beta=PARAMS.beta,
											taxonomy_weights=[1.0, 1.0, 1.0, 1.0],
											all_label=None,
											all_value=None,
											batch_weight_range=[1.0, 1.0],
											optimizer=PARAMS.optimizer,
											is_mode='train'
											)
	all_inputs, end_point, losses, eval_value, eval_update, eval_reset = \
		model.build_model()

	train_summary_op = model.get_summary_op()

	test_summary_op = model.get_summary_op_test()

	config = tf.ConfigProto()
	# config.gpu_options.per_process_gpu_memory_fraction = 0.5
	config.gpu_options.allow_growth = True
	config.allow_soft_placement = True
	config.gpu_options.visible_device_list = PARAMS.gpu_num

	# Create a saver to save and restore all the variables.
	saver = tf.train.Saver()

	with tf.Session(config=config) as sess:

		now = datetime.datetime.now()
		folder_log = './' + 'train_%s_%s_%s_%s_%s' % (now.year, now.month, now.day, now.hour, now.minute)
		# folder_log = '.\\' + 'train_%s_%s_%s_%s_%s' % (now.year, now.month, now.day, now.hour, now.minute)
		print('folder_log: ', folder_log)
		if not os.path.exists(folder_log):
			os.makedirs(folder_log + '/code')

		# For windows
		# os.system('copy .\\*.py %s' % (folder_log))
		# For Linux
		os.system('cp ./*.py %s/code' % (folder_log))

		# remember to initialize both global and local variables for training and evaluation
		sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

		# restore the pretrained ResNet weights before training starts
		model.resnet_restore(sess)

		summary_writer = tf.summary.FileWriter('%s' % (folder_log), sess.graph)
		total_step_num, test_step_num = 0, 0
		best_acc = 0.0

		for epoch in range(PARAMS.epochs):

			current_time = datetime.datetime.now()  # measure training time for each epoch
			# initiate the batch extraction using tf.data.Dataset
			sess.run(training_init_op)

			while (True):
				try:
					# extract batch and training
					update = sess.run(
						{'all_inputs': all_inputs, 'all_outputs': end_point, 'update_op': model.update_flag,
						 'losses': losses, 'eval_update': eval_update,
						 'summary': train_summary_op},
						feed_dict={
							# inputs: images,
							# 	true_labels: labels[:,int(-args.layer):],
							model.resnet_training_flag: False,
							model.vgg19_training_flag: True,
							model.vgg_dropout: 0.5})

					if total_step_num % PARAMS.print_freq==0:
						# print(len(update['all_inputs']), len(update['all_outputs']))
						print('losses:', update['losses'], 'accuracy (top1, top3):',
							  update['eval_update']['Accuracy_top1'], update['eval_update']['Accuracy_top3'])
						# print('losses:', update['losses'])

					summary_writer.add_summary(update['summary'], total_step_num)
					summary_writer.flush()
					total_step_num += 1

				# imgs, lbls = sess.run([images, labels])

				except tf.errors.OutOfRangeError:
					break

			print('Epoch %d done. ' % (epoch + 1))
			print('Training time: {}'.format(datetime.datetime.now() - current_time))

			# save model after each epoch
			checkpoint_file = os.path.join(folder_log, 'Grasp.ckpt')
			saver.save(sess, checkpoint_file)
			print('Model has been saved.')

			# reset all local variables so the streaming metrics start a fresh accumulation
			sess.run(eval_reset)

			################################################################################
			# Validate the model with val_dataset
			current_time = datetime.datetime.now()  # measure training time for each epoch
			# initiate the batch extraction using tf.data.Dataset
			sess.run(validation_init_op)

			while (True):
				try:
					# extract batch and evaluate
					update = sess.run({'all_inputs': all_inputs, 'all_outputs': end_point,
									   'losses': losses, 'eval_update': eval_update,
									   'summary':test_summary_op
									   },
									  feed_dict={
										  # inputs: images,
										  # 	true_labels: labels[:,int(-args.layer):],
										  model.resnet_training_flag: False,
										  model.vgg19_training_flag: False,
										  model.vgg_dropout: 1.0})

					summary_writer.add_summary(update['summary'], test_step_num)
					summary_writer.flush()
					test_step_num += 1

				except tf.errors.OutOfRangeError:
					break

			print('Validation time: {}'.format(datetime.datetime.now() - current_time))

			val_metrics = sess.run(eval_value)

			print({name: val_metrics[name]['stage_0'] for name in val_metrics.keys()})

			# convert the metric dict into rows and append them to the CSV log
			val_metrics_arr = [ [name] + [val_metrics[name][stage] for stage in val_metrics[name].keys()]
							   for name in val_metrics.keys()]
			metric_names = list(val_metrics.keys())
			stages = ['Metrics'] + list(val_metrics[metric_names[0]].keys())
			val_metrics_arr = [stages] + val_metrics_arr

			with open(folder_log + '/val_evaluation.csv', "a") as output:
				writer = csv.writer(output, lineterminator='\n')
				writer.writerows([['Epoch %s'% epoch]])
				writer.writerows(val_metrics_arr)

			# reset all local variables so the streaming metrics start a fresh accumulation
			sess.run(eval_reset)

			################################################################################
			# Evaluate the model with test_dataset
			current_time = datetime.datetime.now()  # measure training time for each epoch
			# initiate the batch extraction using tf.data.Dataset
			sess.run(test_init_op)

			while (True):
				try:
					# extract batch and evaluate
					update = sess.run({'all_inputs': all_inputs, 'all_outputs': end_point,
									   'losses': losses, 'eval_update': eval_update,
									   'summary': test_summary_op
									   },
									  feed_dict={
										  # inputs: images,
										  # 	true_labels: labels[:,int(-args.layer):],
										  model.resnet_training_flag: False,
										  model.vgg19_training_flag: False,
										  model.vgg_dropout: 1.0})

					summary_writer.add_summary(update['summary'], test_step_num)
					summary_writer.flush()
					test_step_num += 1

				except tf.errors.OutOfRangeError:
					break

			print('Test time: {}'.format(datetime.datetime.now() - current_time))

			test_metrics = sess.run(eval_value)

			print({name: test_metrics[name]['stage_0'] for name in test_metrics.keys()})

			# convert the metric dict into rows and append them to the CSV log
			test_metrics_arr = [[name] + [test_metrics[name][stage] for stage in test_metrics[name].keys()]
							   for name in test_metrics.keys()]
			metric_names = list(test_metrics.keys())
			stages = ['Metrics'] + list(test_metrics[metric_names[0]].keys())
			test_metrics_arr = [stages] + test_metrics_arr

			with open(folder_log + '/test_evaluation.csv', "a") as output:
				writer = csv.writer(output, lineterminator='\n')
				writer.writerows([['Epoch %s' % epoch]])
				writer.writerows(test_metrics_arr)

			# reset all local variables so the streaming metrics start a fresh accumulation
			sess.run(eval_reset)

			if test_metrics['Accuracy_top1']['stage_%s'%(len(grasp_loader.label_order)-1)] > best_acc:
				best_acc = test_metrics['Accuracy_top1']['stage_%s'%(len(grasp_loader.label_order)-1)]
				# save trained model
				checkpoint_file = os.path.join(folder_log, 'BestGraspResnet%s.ckpt' % PARAMS.resnet_version)
				saver.save(sess, checkpoint_file)
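
A note on the loop structure above: each split is drained with an initializable tf.data iterator until it raises tf.errors.OutOfRangeError. A minimal, self-contained sketch of that pattern (plain tf.data with a placeholder dataset, independent of GraspLoader and the model):

# Sketch: the initializable-iterator pattern behind initialization_dataset()
import tensorflow as tf

dataset = tf.data.Dataset.range(10).batch(4)   # placeholder data
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for epoch in range(2):
        sess.run(iterator.initializer)  # same role as training_init_op above
        while True:
            try:
                print(sess.run(next_element))
            except tf.errors.OutOfRangeError:
                break  # split exhausted for this epoch
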
Example #2
def train():

    images = np.random.randn(100, 224, 224, 3)
    labels = np.random.randint(0, 10, 100).reshape((100, 1))
    classes = [10]

    batch_image = tf.constant(images, tf.float32)
    batch_label = tf.constant(labels, tf.int64)

    # Define Model
    model = taxonomy_model.taxonomy_model(
        inputs=batch_image,
        true_labels=batch_label,
        input_size=PARAMS.image_size,
        batch_size=PARAMS.batch_size,
        taxonomy_nums=1,
        taxonomy_classes=classes,
        resnet_version=PARAMS.resnet_version,
        resnet_pretrained_path=PARAMS.resnet_path,
        resnet_exclude=PARAMS.resnet_exclude,
        trainable_scopes=PARAMS.trainable_scopes,
        extra_global_feature=True,
        taxonomy_loss=True,
        learning_rate=PARAMS.learning_rate,
        num_samples=100,
        beta=PARAMS.beta,
        taxonomy_weights=[1.0, 1.0, 1.0, 1.0],
        all_label=None,
        all_value=None,
        batch_weight_range=[1.0, 1.0],
        is_mode='train')
    all_inputs, end_point, losses, eval_value, eval_update, eval_reset = \
     model.build_model()

    train_summary_op = model.get_summary_op()

    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.gpu_options.visible_device_list = '1'

    # Create a saver to save and restore all the variables.
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:

        now = datetime.datetime.now()

        # remember to initialize both global and local variables for training and evaluation
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])

        # restore the pretrained ResNet weights before training starts
        model.resnet_restore(sess)

        total_step_num = 0
        best_acc = 0.0

        for epoch in range(PARAMS.epochs):

            current_time = datetime.datetime.now()  # measure training time for each epoch

            while (True):
                try:
                    step_time = datetime.datetime.now()

                    # imgs, lbls = sess.run([batch_image, batch_label])

                    # extract batch and training
                    update = sess.run(
                        {
                            'all_inputs': all_inputs,
                            'all_outputs': end_point,
                            'update_op': model.update_flag,
                            'losses': losses,
                            'eval_update': eval_update,
                            'summary': train_summary_op
                        },
                        feed_dict={
                            # inputs: images,
                            # 	true_labels: labels[:,int(-args.layer):],
                            model.resnet_training_flag: True,
                            model.vgg19_training_flag: False,
                            model.vgg_dropout: 0.5
                        })

                    # print(len(update['all_inputs']), len(update['all_outputs']))
                    print('losses:', update['losses'], 'accuracy:',
                          update['eval_update']['Accuracy_top1'])
                    # print('losses:', update['losses'])

                    total_step_num += 1
                    print('Step training time: {}'.format(
                        datetime.datetime.now() - step_time))

                except tf.errors.OutOfRangeError:
                    break

            print('Epoch %d done. ' % (epoch + 1))
            print('Training time: {}'.format(datetime.datetime.now() -
                                             current_time))

            # reset all local variables so the streaming metrics start a fresh accumulation
            sess.run(eval_reset)
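
The eval_update / eval_value / eval_reset triple used in these examples follows the standard TF1 streaming-metrics pattern: the running counters live in local variables, so re-initializing the local variables starts a fresh accumulation. A minimal sketch with plain tf.metrics (independent of the taxonomy_model API; the placeholder names here are illustrative):

# Sketch: streaming accuracy with reset, as used via eval_update/eval_reset
import tensorflow as tf

labels = tf.placeholder(tf.int64, [None])
predictions = tf.placeholder(tf.int64, [None])

# acc_value reads the running accuracy; acc_update folds a new batch into it
acc_value, acc_update = tf.metrics.accuracy(labels, predictions)

# the metric's counters are local variables, so this op resets them
# (the role eval_reset plays in the examples)
reset_op = tf.variables_initializer(tf.local_variables())

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(acc_update, feed_dict={labels: [0, 1, 1], predictions: [0, 1, 0]})
    print(sess.run(acc_value))  # 2 of 3 correct so far
    sess.run(reset_op)          # start a new evaluation pass
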
Example #3
def train():
    grasp_loader = Grasp_csv_Loader.csv_loader(
        data_path=PARAMS.csv_path,
        csv_filename='test_annotated_data.csv',
        save_folder=PARAMS.save_folder,
        is_saved=True,
        resize_image_size=PARAMS.image_size,
        train_list=[0],
        val_list=[1],
        test_list=[1],
        label_order=label_order,
        batch_size=PARAMS.batch_size,
        max_hue_delta=PARAMS.max_hue_delta,
        saturation_range=PARAMS.saturation_range,
        max_bright_delta=PARAMS.max_bright_delta,
        max_contrast_delta=PARAMS.max_contrast_delta,
        is_training=True)

    print('dataset: ', len(grasp_loader.train_meaningful_jpg_names),
          len(grasp_loader.val_meaningful_jpg_names),
          len(grasp_loader.test_meaningful_jpg_names))

    next_element, training_init_op, validation_init_op, test_init_op = \
     grasp_loader.initialization_dataset()

    batch_image, batch_label = next_element

    # Define Model
    model = taxonomy_model.taxonomy_model(
        inputs=batch_image,
        true_labels=batch_label,
        input_size=PARAMS.image_size,
        batch_size=PARAMS.batch_size,
        taxonomy_nums=len(grasp_loader.classes_numbers),
        taxonomy_classes=grasp_loader.classes_numbers,
        resnet_version=PARAMS.resnet_version,
        resnet_pretrained_path=PARAMS.resnet_path,
        resnet_exclude=PARAMS.resnet_exclude,
        trainable_scopes=PARAMS.trainable_scopes,
        extra_global_feature=True,
        taxonomy_loss=True,
        learning_rate=PARAMS.learning_rate,
        num_samples=len(grasp_loader.train_meaningful_jpg_names),
        beta=PARAMS.beta,
        taxonomy_weights=[1.0, 1.0, 1.0, 1.0],
        all_label=None,
        all_value=None,
        batch_weight_range=[1.0, 1.0],
        is_mode='train')
    all_inputs, end_point, losses, eval_value, eval_update, eval_reset = \
     model.build_model()

    train_summary_op = model.get_summary_op()

    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.gpu_options.visible_device_list = '1'  #PARAMS.gpu_num

    # Create a saver to save and restore all the variables.
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:

        now = datetime.datetime.now()
        # folder_log = './' + 'train_%s_%s_%s_%s_%s' % (now.year, now.month, now.day, now.hour, now.minute)
        # # folder_log = '.\\' + 'train_%s_%s_%s_%s_%s' % (now.year, now.month, now.day, now.hour, now.minute)
        # print('folder_log: ', folder_log)
        # if not os.path.exists(folder_log):
        # 	os.makedirs(folder_log + '/code')
        #
        # # For windows
        # # os.system('copy .\\*.py %s' % (folder_log))
        # # For Linux
        # os.system('cp ./*.py %s/code' % (folder_log))

        # remember to initialize both global and local variables for training and evaluation
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])

        # restore the pretrained ResNet weights before training starts
        model.resnet_restore(sess)

        # summary_writer = tf.summary.FileWriter('%s' % (folder_log), sess.graph)
        total_step_num = 0
        best_acc = 0.0

        for epoch in range(PARAMS.epochs):

            current_time = datetime.datetime.now()  # measure training time for each epoch
            # initiate the batch extraction using tf.data.Dataset
            sess.run(training_init_op)

            while (True):
                try:
                    step_time = datetime.datetime.now()

                    imgs, lbls = sess.run([batch_image, batch_label])

                    # # extract batch and training
                    # update = sess.run(
                    # 	{'all_inputs': all_inputs, 'all_outputs': end_point, 'update_op': model.update_flag,
                    # 	 'losses': losses, 'eval_update': eval_update,
                    # 	 'summary': train_summary_op},
                    # 	feed_dict={
                    # 		# inputs: images,
                    # 		# 	true_labels: labels[:,int(-args.layer):],
                    # 		model.resnet_training_flag: True,
                    # 		model.vgg19_training_flag: False,
                    # 		model.vgg_dropout: 0.5})
                    #
                    # # print(len(update['all_inputs']), len(update['all_outputs']))
                    # print('losses:', update['losses'], 'accuracy:', update['eval_update']['Accuracy_top1'])
                    # # print('losses:', update['losses'])

                    # summary_writer.add_summary(update['summary'], total_step_num)
                    # summary_writer.flush()
                    total_step_num += 1
                    print('Step training time: {}'.format(
                        datetime.datetime.now() - step_time))

                except tf.errors.OutOfRangeError:
                    break

            print('Epoch %d done. ' % (epoch + 1))
            print('Training time: {}'.format(datetime.datetime.now() -
                                             current_time))

            # reset all local variables so the streaming metrics start a fresh accumulation
            sess.run(eval_reset)

            # Validate the model with val_dataset
            # initiate the batch extraction using tf.data.Dataset
            sess.run(validation_init_op)

            while (True):
                try:
                    # extract batch and evaluate
                    update = sess.run(
                        {
                            'all_inputs': all_inputs,
                            'all_outputs': end_point,
                            'losses': losses,
                            'eval_update': eval_update
                        },
                        feed_dict={
                            # inputs: images,
                            # 	true_labels: labels[:,int(-args.layer):],
                            model.resnet_training_flag: False,
                            model.vgg19_training_flag: False,
                            model.vgg_dropout: 1.0
                        })

                    print('losses:', update['losses'], 'accuracy:',
                          update['eval_update']['Accuracy_top1'])

                except tf.errors.OutOfRangeError:
                    break

            val_metrics = sess.run(eval_value)

            # convert the metric dict into rows for printing / optional CSV logging
            val_metrics_arr = [[
                val_metrics[name][stage] for stage in val_metrics[name].keys()
            ] for name in val_metrics.keys()]
            print('results:\n', val_metrics_arr)

            # with open(folder_log + '/val/Epoch_%s.csv' % (epoch), "a") as output:
            # 	writer = csv.writer(output, lineterminator='\n')
            # 	writer.writerows(val_metrics_arr)

            # reset all local variables so the streaming metrics start a fresh accumulation
            sess.run(eval_reset)
Example #4
def test(folder_log, modelname):

    model_fullpath = os.path.join(folder_log, modelname)
    if not os.path.isfile(model_fullpath + '.meta'):
        raise FileNotFoundError('No file at: ' + model_fullpath)

    grasp_loader = GraspLoader.csv_loader(
        data_path=PARAMS.csv_path,
        csv_filename=PARAMS.csv_filename,
        save_folder=PARAMS.save_folder,
        is_saved=True,
        resize_image_size=PARAMS.image_size,
        train_list=PARAMS.train_list,
        val_list=PARAMS.val_list,
        test_list=PARAMS.test_list,
        label_order=label_order,
        batch_size=PARAMS.batch_size,
        max_hue_delta=PARAMS.max_hue_delta,
        saturation_range=PARAMS.saturation_range,
        max_bright_delta=PARAMS.max_bright_delta,
        max_contrast_delta=PARAMS.max_contrast_delta,
        is_training=False)
    grasp_loader.do_preprocessing = False  # for evaluating the training set

    next_element, training_init_op, validation_init_op, test_init_op = \
     grasp_loader.initialization_dataset()

    batch_image, batch_label = next_element

    # Define Model
    model = taxonomy_model.taxonomy_model(
        inputs=batch_image,
        true_labels=batch_label,
        input_size=PARAMS.image_size,
        batch_size=PARAMS.batch_size,
        taxonomy_nums=len(grasp_loader.label_order),  # len(grasp_loader.classes_numbers)
        taxonomy_classes=grasp_loader.classes_numbers,
        resnet_version=PARAMS.resnet_version,
        resnet_pretrained_path=PARAMS.resnet_path,
        resnet_exclude=PARAMS.resnet_exclude,
        trainable_scopes=PARAMS.trainable_scopes,
        extra_global_feature=True,
        taxonomy_loss=True,
        learning_rate=PARAMS.learning_rate,
        num_samples=len(grasp_loader.train_meaningful_jpg_names),
        beta=PARAMS.beta,
        taxonomy_weights=[1.0, 1.0, 1.0, 1.0],
        all_label=None,
        all_value=None,
        batch_weight_range=[1.0, 1.0],
        is_mode='train')
    all_inputs, end_point, losses, eval_value, eval_update, eval_reset = \
     model.build_model()

    test_predictions = model.get_prediction_topk(k=1)

    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.gpu_options.visible_device_list = '1'  #PARAMS.gpu_num

    # Create a saver to save and restore all the variables.
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:

        # load the trained weights from the checkpoint
        saver.restore(sess, model_fullpath)

        # initialize the local variables used by the streaming evaluation metrics
        sess.run([tf.local_variables_initializer()])

        # resnet_restore is not needed here: all weights come from the restored checkpoint
        # model.resnet_restore(sess)

        total_step_num = 0
        predicted_filename = folder_log + '/prediction_%s.csv' % (modelname)

        with open(predicted_filename, "a") as output:
            writer = csv.writer(output, lineterminator='\n')
            writer.writerows([['id', 'predicted']])

        current_time = datetime.datetime.now()  # measure evaluation time

        # Test the model with test_dataset
        # initiate the batch extraction using tf.data.Dataset
        sess.run(test_init_op)  #training_init_op

        while (True):
            try:
                # extract batch and evaluate
                update = sess.run(
                    {
                        'all_inputs': all_inputs,
                        'all_outputs': end_point,
                        'labels': batch_label,
                        'topk': test_predictions,
                        'eval_update': eval_update
                    },
                    feed_dict={
                        # inputs: images,
                        # 	true_labels: labels[:,int(-args.layer):],
                        model.resnet_training_flag: False,
                        model.vgg19_training_flag: False,
                        model.vgg_dropout: 1.0
                    })

                print('Accuracy_top1:', update['eval_update']['Accuracy_top1'],
                      'Accuracy_top3:', update['eval_update']['Accuracy_top3'])

                result_array = [[update['labels'][i], update['topk'][i]]
                                for i in range(len(update['labels']))]
                with open(predicted_filename, "a") as output:
                    writer = csv.writer(output, lineterminator='\n')
                    writer.writerows(result_array)

            except tf.errors.OutOfRangeError:
                break

        test_metrics = sess.run(eval_value)

        print({
            name: test_metrics[name]['stage_0']
            for name in test_metrics.keys()
        })
        print('Evaluation time: ', datetime.datetime.now() - current_time)

        # convert the metric dict into rows and append them to the CSV log
        test_metrics_arr = [
            [name] +
            [test_metrics[name][stage] for stage in test_metrics[name].keys()]
            for name in test_metrics.keys()
        ]
        metric_names = list(test_metrics.keys())
        stages = ['Metrics'] + list(test_metrics[metric_names[0]].keys())
        test_metrics_arr = [stages] + test_metrics_arr

        with open(folder_log + '/test_evaluation.csv', "a") as output:
            writer = csv.writer(output, lineterminator='\n')
            writer.writerows([[modelname]])
            writer.writerows(test_metrics_arr)

        # reset all local variables so the streaming metrics start a fresh accumulation
        sess.run(eval_reset)
def train():
    grasp_loader = Grasp_csv_Loader_v2.csv_loader(
        data_path=PARAMS.csv_path,
        csv_filename=PARAMS.csv_filename,
        save_folder=PARAMS.save_folder,
        is_saved=True,
        resize_image_size=PARAMS.image_size,
        train_subject_list=PARAMS.train_list,
        val_subject_list=PARAMS.val_list,
        test_subject_list=PARAMS.test_list,
        divide_by_ratio=True,
        is_divided_saved=True,
        divided_npz_name='divided_dataset.npz',
        label_order=label_order,
        batch_size=PARAMS.batch_size,
        max_hue_delta=PARAMS.max_hue_delta,
        saturation_range=PARAMS.saturation_range,
        max_bright_delta=PARAMS.max_bright_delta,
        max_contrast_delta=PARAMS.max_contrast_delta,
        is_training=True)

    next_element, training_init_op, validation_init_op, test_init_op = \
     grasp_loader.initialization_dataset()

    batch_image, batch_label = next_element

    # Define Model
    model = taxonomy_model.taxonomy_model(
        inputs=batch_image,
        true_labels=batch_label,
        input_size=PARAMS.image_size,
        batch_size=PARAMS.batch_size,
        taxonomy_nums=len(grasp_loader.classes_numbers),
        taxonomy_classes=grasp_loader.classes_numbers,
        resnet_version=PARAMS.resnet_version,
        resnet_pretrained_path=PARAMS.resnet_path,
        resnet_exclude=PARAMS.resnet_exclude,
        trainable_scopes=PARAMS.trainable_scopes,
        extra_global_feature=True,
        taxonomy_loss=True,
        learning_rate=PARAMS.learning_rate,
        num_samples=len(grasp_loader.train_info),
        beta=PARAMS.beta,
        taxonomy_weights=[1.0, 1.0, 1.0, 1.0, 1.0],
        all_label=None,
        all_value=None,
        batch_weight_range=[1.0, 1.0],
        optimizer=PARAMS.optimizer,
        is_mode='train')
    all_inputs, end_point, losses, eval_value, eval_update, eval_reset = \
     model.build_model()

    train_summary_op = model.get_summary_op()

    test_summary_op = model.get_summary_op_test()

    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.gpu_options.visible_device_list = PARAMS.gpu_num

    # Create a saver to save and restore all the variables.
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:

        now = datetime.datetime.now()
        folder_log = './' + 'train_%s_%s_%s_%s_%s' % (
            now.year, now.month, now.day, now.hour, now.minute)
        # folder_log = '.\\' + 'train_%s_%s_%s_%s_%s' % (now.year, now.month, now.day, now.hour, now.minute)
        print('folder_log: ', folder_log)
        if not os.path.exists(folder_log):
            os.makedirs(folder_log)

        # For windows
        # os.system('copy .\\*.py %s' % (folder_log))
        # For Linux
        os.system('cp ./*.py %s' % (folder_log))

        # remember to initialize both global and local variables for training and evaluation
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])

        # restore the pretrained ResNet weights before training starts
        model.resnet_restore(sess)

        summary_writer = tf.summary.FileWriter('%s' % (folder_log), sess.graph)
        total_step_num = 0

        for epoch in range(PARAMS.epochs):

            current_time = datetime.datetime.now()  # measure training time for each epoch
            # initiate the batch extraction using tf.data.Dataset
            sess.run(training_init_op)

            while (True):
                try:
                    # extract batch and training
                    each_current_time = datetime.datetime.now()
                    update = sess.run(
                        {
                            'all_inputs': all_inputs,
                            'all_outputs': end_point,
                            'update_op': model.update_flag,
                            'losses': losses,
                            'eval_update': eval_update,
                            'summary': train_summary_op
                        },
                        feed_dict={
                            # inputs: images,
                            # 	true_labels: labels[:,int(-args.layer):],
                            model.resnet_training_flag: False,
                            model.vgg19_training_flag: True,
                            model.vgg_dropout: 0.5
                        })
                    # print(len(update['all_inputs']), len(update['all_outputs']))
                    print('losses:', update['losses'],
                          'accuracy (top1, top3):',
                          update['eval_update']['Accuracy_top1'],
                          update['eval_update']['Accuracy_top3'])
                    # print('losses:', update['losses'])
                    print('Each iteration time: {}'.format(
                        datetime.datetime.now() - each_current_time))

                    summary_writer.add_summary(update['summary'],
                                               total_step_num)
                    summary_writer.flush()
                    total_step_num += 1

                except tf.errors.OutOfRangeError:
                    break

            print('Epoch %d done. ' % (epoch + 1))
            print('Training time: {}'.format(datetime.datetime.now() -
                                             current_time))

            checkpoint_file = os.path.join(folder_log, 'Grasp.ckpt')
            saver.save(sess, checkpoint_file, global_step=epoch)

            # reset all local variables so the streaming metrics start a fresh accumulation
            sess.run(eval_reset)

            if (epoch % PARAMS.epoch_decay) == 0:

                current_time = datetime.datetime.now()  # measure validation time for this epoch
                # initiate the batch extraction using tf.data.Dataset
                sess.run(validation_init_op)

                while (True):
                    try:
                        # extract batch and evaluate
                        update = sess.run(
                            {
                                'all_inputs': all_inputs,
                                'all_outputs': end_point,
                                'losses': losses,
                                'eval_update': eval_update,
                                'summary': test_summary_op
                            },
                            feed_dict={
                                # inputs: images,
                                # 	true_labels: labels[:,int(-args.layer):],
                                model.resnet_training_flag: False,
                                model.vgg19_training_flag: False,
                                model.vgg_dropout: 1.0
                            })
                        # print(len(update['all_inputs']), len(update['all_outputs']))
                        print('losses:', update['losses'],
                              'accuracy (top1, top3):',
                              update['eval_update']['Accuracy_top1'],
                              update['eval_update']['Accuracy_top3'])
                        # print('losses:', update['losses'])
                        print('Each iteration time: {}'.format(
                            datetime.datetime.now() - current_time))

                        summary_writer.add_summary(update['summary'], epoch)
                        summary_writer.flush()

                    except tf.errors.OutOfRangeError:
                        break

                print('Validation time: {}'.format(datetime.datetime.now() -
                                                   current_time))

            # reset all local variables so the streaming metrics start a fresh accumulation
            sess.run(eval_reset)