示例#1
0
def train(dir_name="lane_images"):
    """Load training images, fit a ``model.SmallModel`` on them and save it.

    Args:
        dir_name: Directory handed to ``image_load.read_training`` as the
            source of the training images.

    The images are read in colour (``grayscale=False``) at
    ``model.standard_format``; the network input shape is that format with a
    trailing ``3`` appended (presumably the colour-channel axis -- confirm
    against ``model.SmallModel``).
    """
    print("Loading images...")
    target_size = model.standard_format
    raw_images, raw_labels = image_load.read_training(
        format_size=target_size,
        dir_name=dir_name,
        grayscale=False,
    )

    # Convert the raw images/labels into the arrays the network consumes.
    features = image_load.format_X(raw_images, target_size)
    targets = image_load.format_Y(raw_labels)

    print("Training model...")
    network = model.SmallModel(input_shape=target_size + (3,))
    network.fit(features, targets)
    network.summary()

    print("Saving...")
    save_model(network)
示例#2
0
def main():
    """Train/evaluate a student model, or evaluate a teacher checkpoint.

    Command-line arguments are parsed with ``get_parser()`` and applied via
    ``setup(args)``; the dataset is built from the same ``args``.

    * ``args.model_type == "student"``: optionally restores a teacher from
      ``args.load_teacher_checkpoint_dir`` (when
      ``args.load_teacher_from_checkpoint`` is set), trains the student with
      that teacher (or ``None``), then evaluates the student both through
      ``inference`` and through the checkpoint stored in
      ``args.checkpoint_dir``.
    * any other ``model_type``: restores the teacher from
      ``args.checkpoint_dir`` and evaluates it. Teacher training is disabled
      by default (see the commented line below).

    Every session that is started is closed before returning.
    """
    parser = get_parser()
    args = parser.parse_args()
    setup(args)
    dataset = data.Dataset(args)

    tf.reset_default_graph()
    if args.model_type == "student":
        teacher_model = None
        if args.load_teacher_from_checkpoint:
            teacher_model = model.BigModel(args, "teacher")
            teacher_model.start_session()
            teacher_model.load_model_from_file(args.load_teacher_checkpoint_dir)
            print("Verify Teacher State before Training Student")
            teacher_model.run_inference(dataset)
        student_model = model.SmallModel(args, "student")
        student_model.start_session()
        student_model.train(dataset, teacher_model)
        # NOTE(review): the meaning of the literals 10 and 0.5 is defined by
        # model.inference, which is not visible here -- confirm there.
        acc = student_model.inference(dataset, 10, 0.5, args.checkpoint_dir)
        print("student model acc with variation", acc)

        # Testing student model on the best model based on validation set
        # NOTE(review): start_session() is invoked a second time on the same
        # model; presumably it resets state before the checkpoint is
        # restored -- kept as in the original, confirm it is intended.
        student_model.start_session()
        student_model.load_model_from_file(args.checkpoint_dir)
        student_model.run_inference(dataset)

        if args.load_teacher_from_checkpoint:
            print("Verify Teacher State After Training student Model")
            teacher_model.run_inference(dataset)
            teacher_model.close_session()
        student_model.close_session()
    else:
        teacher_model = model.BigModel(args, "teacher")
        teacher_model.start_session()

        #Uncomment the below line if training the teacher model
        #teacher_model.train(dataset)

        # Testing teacher model on the best model based on validation set
        teacher_model.load_model_from_file(args.checkpoint_dir)
        teacher_model.run_inference(dataset)
        teacher_model.inference(dataset, 10, 0.5, args.checkpoint_dir)
        # Release the session started above; the student branch already
        # closes its sessions, this branch previously leaked it.
        teacher_model.close_session()
示例#3
0
def main():
    """Run either the teacher or the student pipeline, chosen by ``args``.

    Arguments come from ``get_parser()`` and are applied with ``setup``.

    * ``args.model_type == 'student'``: builds a ``model.SmallModel``, trains
      it, restores the checkpoint in ``args.checkpoint_dir`` and runs
      inference.
    * otherwise: builds a ``model.BigModel``, trains it, calls
      ``save_as_npy(dataset, args.batch_size)``, restores the checkpoint in
      ``args.checkpoint_dir`` and runs inference.

    NOTE: no live teacher model is loaded in the student path (that code was
    disabled); the student's ``train`` receives the string ``'teacher'`` in
    place of a model object. Presumably ``train`` uses it to locate the
    output written by ``save_as_npy`` -- verify against
    ``model.SmallModel.train``.
    """
    args = get_parser().parse_args()
    setup(args)
    dataset = data.Dataset(args)

    tf.reset_default_graph()

    # Teacher path first, so the student path below needs no else-branch.
    if args.model_type != 'student':
        print('run teacher')
        teacher = model.BigModel(args, "teacher")
        teacher.start_session()
        teacher.train(dataset)
        teacher.save_as_npy(dataset, args.batch_size)

        # Evaluate the checkpoint selected on the validation set.
        teacher.load_model_from_file(args.checkpoint_dir)
        teacher.run_inference(dataset)
        teacher.close_session()
        return

    student = model.SmallModel(args, 'student')
    student.start_session()
    student.train(dataset, 'teacher')

    # Evaluate the stored checkpoint rather than the last training state.
    student.load_model_from_file(args.checkpoint_dir)
    student.run_inference(dataset)
    student.close_session()
示例#4
0
def main():
    """Train and validate a teacher or student model on image data.

    Arguments come from ``get_parser()`` and are applied with ``setup``.
    Training images are read from ``args.train_data_dir``; validation paths
    and their per-pair "is same" flags are read from ``args.unit_test_dir``.

    * ``args.model_type == "student"``: optionally restores a teacher from
      ``args.load_teacher_checkpoint_dir``, trains the student (passing the
      teacher or ``None``), then evaluates the student checkpoint stored in
      ``args.checkpoint_dir``.
    * otherwise: trains the teacher and evaluates its checkpoint from
      ``args.checkpoint_dir``.

    Raises:
        AssertionError: if no training images were found.
    """
    args = get_parser().parse_args()
    setup(args)

    # --- training data ---------------------------------------------------
    train_set = utils.get_dataset(args.train_data_dir)
    nrof_classes = len(train_set)

    # --- validation data -------------------------------------------------
    print('unit test directory: %s' % args.unit_test_dir)
    val_dir = os.path.expanduser(args.unit_test_dir)
    unit_test_paths, unit_actual_issame = utils.get_val_paths(val_dir)
    nrof_test_img = len(unit_test_paths)

    # Expand the per-pair flag to a per-image label: entry i of
    # unit_actual_issame is written to image slots 2*i and 2*i + 1.
    per_image_flags = np.zeros(nrof_test_img)
    for pair_idx, same in enumerate(unit_actual_issame):
        first = 2 * pair_idx
        per_image_flags[first] = same
        per_image_flags[first + 1] = same
    unit_issame_label = np.asarray(per_image_flags, dtype=np.int32)

    # Flatten the training set into parallel path / label lists.
    image_list, label_list = utils.get_image_paths_and_labels(train_set)
    assert len(image_list) > 0, 'The dataset should not be empty'

    print('Total number of train classes: %d' % nrof_classes)
    print('Total number of train examples: %d' % len(image_list))
    print("number of validation examples: %d" % nrof_test_img)

    train_dataset = data_loader.DataLoader(
        image_list, label_list, [160, 160], nrof_classes)
    validation_dataset = data_loader.DataLoader(
        unit_test_paths, unit_issame_label, [160, 160])

    tf.reset_default_graph()

    # Teacher-only run: train, evaluate the stored checkpoint, and finish.
    if args.model_type != "student":
        teacher = model.BigModel(args, "teacher", nrof_classes, nrof_test_img)
        teacher.start_session()
        teacher.train(train_dataset, validation_dataset, unit_actual_issame)

        teacher.load_model_from_file(args.checkpoint_dir)
        teacher.run_inference(validation_dataset, unit_actual_issame)
        teacher.close_session()
        return

    # Student run, optionally guided by a teacher restored from disk.
    teacher = None
    if args.load_teacher_from_checkpoint:
        teacher = model.BigModel(args, "teacher", nrof_classes, nrof_test_img)
        teacher.start_session()
        teacher.load_model_from_file(args.load_teacher_checkpoint_dir)
        print("Verify Teacher State before Training Student")
        teacher.run_inference(validation_dataset, unit_actual_issame)

    student = model.SmallModel(args, "student", nrof_classes, nrof_test_img)
    student.start_session()
    student.train(train_dataset, validation_dataset, unit_actual_issame,
                  teacher)

    # Evaluate the checkpoint selected on the validation set.
    student.load_model_from_file(args.checkpoint_dir)
    student.run_inference(validation_dataset, unit_actual_issame)

    if args.load_teacher_from_checkpoint:
        print("Verify Teacher State After Training student Model")
        teacher.run_inference(validation_dataset, unit_actual_issame)
        teacher.close_session()
    student.close_session()