def train(dir_name="lane_images"):
    """Train the SmallModel lane-detection CNN on images under *dir_name* and save it.

    Images are loaded in RGB (grayscale=False) and resized to
    ``model.standard_format`` before being formatted into (X, Y) arrays.
    """
    print("Loading images...")
    imgs, lbls = image_load.read_training(
        format_size=model.standard_format,
        dir_name=dir_name,
        grayscale=False,
    )
    features = image_load.format_X(imgs, model.standard_format)
    targets = image_load.format_Y(lbls)

    print("Training model...")
    # Three channels appended to the spatial format because inputs are RGB.
    net = model.SmallModel(input_shape=model.standard_format + (3,))
    net.fit(features, targets)
    net.summary()

    print("Saving...")
    save_model(net)
def main():
    """Entry point: train/evaluate a distilled student or a teacher model.

    Student path: optionally loads a teacher checkpoint, trains the student
    against it, evaluates the student (with variation and on the best
    validation checkpoint), and re-verifies the teacher afterwards.
    Teacher path: loads the best checkpoint and runs inference only
    (training line left commented out, as in the original).
    """
    parser = get_parser()
    args = parser.parse_args()
    setup(args)
    dataset = data.Dataset(args)
    tf.reset_default_graph()
    if args.model_type == "student":
        teacher_model = None
        if args.load_teacher_from_checkpoint:
            teacher_model = model.BigModel(args, "teacher")
            teacher_model.start_session()
            teacher_model.load_model_from_file(args.load_teacher_checkpoint_dir)
            print("Verify Teacher State before Training Student")
            teacher_model.run_inference(dataset)
        student_model = model.SmallModel(args, "student")
        student_model.start_session()
        student_model.train(dataset, teacher_model)
        acc = student_model.inference(dataset, 10, 0.5, args.checkpoint_dir)
        print("student model acc with variation", acc)
        # Testing student model on the best model based on validation set
        student_model.start_session()
        student_model.load_model_from_file(args.checkpoint_dir)
        student_model.run_inference(dataset)
        if args.load_teacher_from_checkpoint:
            print("Verify Teacher State After Training student Model")
            teacher_model.run_inference(dataset)
            teacher_model.close_session()
        student_model.close_session()
    else:
        teacher_model = model.BigModel(args, "teacher")
        teacher_model.start_session()
        # Uncomment the below line if training the teacher model
        # teacher_model.train(dataset)
        # Testing teacher model on the best model based on validation set
        teacher_model.load_model_from_file(args.checkpoint_dir)
        teacher_model.run_inference(dataset)
        teacher_model.inference(dataset, 10, 0.5, args.checkpoint_dir)
        # Fix: the teacher session was never released on this path
        # (the student path and the sibling variant both close theirs).
        teacher_model.close_session()
def main():
    """Entry point: train a student model (no live teacher) or a teacher model.

    NOTE(review): on the student path the teacher-checkpoint loading is
    disabled (commented out upstream) and the literal string 'teacher' is
    passed to ``student_model.train`` where a teacher model object (or
    None) would normally go — confirm this is intentional.
    """
    arg_parser = get_parser()
    args = arg_parser.parse_args()
    setup(args)
    dataset = data.Dataset(args)
    tf.reset_default_graph()

    if args.model_type != 'student':
        # Teacher path: train, export activations, then evaluate the best
        # checkpoint selected on the validation set.
        print('run teacher')
        teacher_model = model.BigModel(args, "teacher")
        teacher_model.start_session()
        teacher_model.train(dataset)
        teacher_model.save_as_npy(dataset, args.batch_size)
        teacher_model.load_model_from_file(args.checkpoint_dir)
        teacher_model.run_inference(dataset)
        teacher_model.close_session()
        return

    # Student path: trains without a live teacher session, then evaluates
    # the best checkpoint on the validation set.
    student_model = model.SmallModel(args, 'student')
    student_model.start_session()
    student_model.train(dataset, 'teacher')
    student_model.load_model_from_file(args.checkpoint_dir)
    student_model.run_inference(dataset)
    student_model.close_session()
def main():
    """Entry point: train/evaluate a face-recognition teacher or distilled student.

    Reads the training class directories and the pairwise validation set,
    expands per-pair same/different flags into per-image labels, builds the
    data loaders, then trains either the teacher (BigModel) or the student
    (SmallModel, optionally distilled from a teacher checkpoint), verifying
    the teacher's state before and after student training.
    """
    parser = get_parser()
    args = parser.parse_args()
    setup(args)

    # Read train data.
    train_set = utils.get_dataset(args.train_data_dir)
    nrof_classes = len(train_set)

    # Read validation data: image paths come in verification pairs with one
    # same/different flag per pair.
    print('unit test directory: %s' % args.unit_test_dir)
    unit_test_paths, unit_actual_issame = utils.get_val_paths(
        os.path.expanduser(args.unit_test_dir))
    nrof_test_img = len(unit_test_paths)

    # Expand per-pair flags to per-image labels: images 2i and 2i+1 share
    # the i-th pair's flag. Vectorized replacement for the original index
    # loop; any trailing entries stay 0, exactly as before.
    unit_issame_label = np.zeros(nrof_test_img, dtype=np.int32)
    nrof_pairs = len(unit_actual_issame)
    pair_flags = np.asarray(unit_actual_issame, dtype=np.int32)
    unit_issame_label[0:2 * nrof_pairs:2] = pair_flags
    unit_issame_label[1:2 * nrof_pairs:2] = pair_flags

    # Get a list of image paths and their labels.
    image_list, label_list = utils.get_image_paths_and_labels(train_set)
    assert len(image_list) > 0, 'The dataset should not be empty'

    print('Total number of train classes: %d' % nrof_classes)
    print('Total number of train examples: %d' % len(image_list))
    print("number of validation examples: %d" % nrof_test_img)

    train_dataset = data_loader.DataLoader(image_list, label_list,
                                           [160, 160], nrof_classes)
    validation_dataset = data_loader.DataLoader(unit_test_paths,
                                                unit_issame_label,
                                                [160, 160])

    tf.reset_default_graph()
    if args.model_type == "student":
        teacher_model = None
        if args.load_teacher_from_checkpoint:
            teacher_model = model.BigModel(args, "teacher", nrof_classes,
                                           nrof_test_img)
            teacher_model.start_session()
            teacher_model.load_model_from_file(
                args.load_teacher_checkpoint_dir)
            print("Verify Teacher State before Training Student")
            teacher_model.run_inference(validation_dataset,
                                        unit_actual_issame)
        student_model = model.SmallModel(args, "student", nrof_classes,
                                         nrof_test_img)
        student_model.start_session()
        student_model.train(train_dataset, validation_dataset,
                            unit_actual_issame, teacher_model)
        # Testing student model on the best model based on validation set
        student_model.load_model_from_file(args.checkpoint_dir)
        student_model.run_inference(validation_dataset, unit_actual_issame)
        if args.load_teacher_from_checkpoint:
            print("Verify Teacher State After Training student Model")
            teacher_model.run_inference(validation_dataset,
                                        unit_actual_issame)
            teacher_model.close_session()
        student_model.close_session()
    else:
        teacher_model = model.BigModel(args, "teacher", nrof_classes,
                                       nrof_test_img)
        teacher_model.start_session()
        teacher_model.train(train_dataset, validation_dataset,
                            unit_actual_issame)
        # Testing teacher model on the best model based on validation set
        teacher_model.load_model_from_file(args.checkpoint_dir)
        teacher_model.run_inference(validation_dataset, unit_actual_issame)
        teacher_model.close_session()