示例#1
0
def main():
    """Build the train/eval programs, initialise them and start training.

    Reads ``config``, ``train_program``, ``startup_program``, ``place`` and
    ``train_alg_type`` from the enclosing scope (presumably module-level
    globals set up before this is called -- confirm against the caller).
    """
    # Training graph: the build result is consumed positionally.
    train_outputs = program.build(
        config, train_program, startup_program, mode='train')
    train_loader, train_fetch_names, train_fetch_varnames, opt_loss_name = (
        train_outputs[0], train_outputs[1], train_outputs[2],
        train_outputs[3])
    model_average = train_outputs[-1]

    # Evaluation graph: built into a fresh program, then cloned for testing.
    eval_graph = fluid.Program()
    eval_outputs = program.build(
        config, eval_graph, startup_program, mode='eval')
    eval_fetch_names, eval_fetch_varnames = eval_outputs[1], eval_outputs[2]
    eval_graph = eval_graph.clone(for_test=True)

    # Data sources: the training reader is attached to the loader returned by
    # the build step; the evaluation reader is handed over as-is.
    train_loader.set_sample_list_generator(
        reader_main(config=config, mode="train"), places=place)
    eval_reader = reader_main(config=config, mode="eval")

    # Run the startup program once before anything else is executed.
    exe = fluid.Executor(place)
    exe.run(startup_program)

    # Multi-device variant of the training program.
    compiled_train = program.create_multi_devices_program(
        train_program, opt_loss_name)

    # Optional structure dump; skipped for attention-based recognition.
    if config['Global']['debug']:
        is_attention_rec = (train_alg_type == 'rec' and
                            'attention' in config['Global']['loss_type'])
        if not is_attention_rec:
            summary(train_program)
        else:
            logger.warning('Does not suport dump attention...')

    init_model(config, train_program, exe)

    train_info_dict = {
        'compile_program': compiled_train,
        'train_program': train_program,
        'reader': train_loader,
        'fetch_name_list': train_fetch_names,
        'fetch_varname_list': train_fetch_varnames,
        'model_average': model_average,
    }
    eval_info_dict = {
        'program': eval_graph,
        'reader': eval_reader,
        'fetch_name_list': eval_fetch_names,
        'fetch_varname_list': eval_fetch_varnames,
    }

    # Pick the train/eval driver for the algorithm family; anything that is
    # neither 'det' nor 'rec' goes to the classification driver.
    if train_alg_type == 'det':
        run_train_eval = program.train_eval_det_run
    elif train_alg_type == 'rec':
        run_train_eval = program.train_eval_rec_run
    else:
        run_train_eval = program.train_eval_cls_run
    run_train_eval(config, exe, train_info_dict, eval_info_dict)
示例#2
0
def train(train_dataset, test_dataset):
    """Train the ``net_conf`` network on CPU, then print its test accuracy.

    The network is built into the default main program. A test-mode clone is
    taken *before* the optimizer is attached, and evaluation runs on that
    clone so that scoring the test set never updates the weights.

    Args:
        train_dataset: Training sample source handed to
            ``paddle.reader.shuffle`` (presumably a reader creator -- confirm
            against the caller).
        test_dataset: Evaluation sample source, handled the same way.

    Returns:
        None. Progress, the program summary and the final accuracy are
        printed.
    """
    startup_program = fluid.default_startup_program()
    main_program = fluid.default_main_program()

    train_reader = paddle.batch(
        paddle.reader.shuffle(train_dataset, buf_size=500),
        batch_size=BATCH_SIZE)
    test_reader = paddle.batch(
        paddle.reader.shuffle(test_dataset, buf_size=500),
        batch_size=BATCH_SIZE)

    img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')

    _, avg_loss, acc = net_conf(img, label, NUM_CLASSES)

    # Clone after the forward network exists but before minimize() appends
    # the backward/optimizer ops. Cloning any earlier yields an empty program.
    inference_program = main_program.clone(for_test=True)

    optimizer = fluid.optimizer.RMSProp(learning_rate=0.001)
    optimizer.minimize(avg_loss)

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    print('Summary')
    summary(main_program)

    feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
    exe.run(startup_program)

    # train -- `step` counts batches across all epochs.
    step = 0
    for epoch_id in range(EPOCHS):
        print("Epoch %d" % (epoch_id))
        for data in train_reader():
            metrics = exe.run(main_program,
                              feed=feeder.feed(data),
                              fetch_list=[avg_loss, acc])
            if step % 100 == 0:
                print("Pass %d, Cost %f" % (step, metrics[0]))
            step += 1

    # test -- run the test-mode clone so no parameter update happens here.
    total_acc = 0
    num_batches = 0
    for data in test_reader():
        metrics = exe.run(inference_program,
                          feed=feeder.feed(data),
                          fetch_list=[avg_loss, acc])
        total_acc += metrics[1]
        num_batches += 1

    if num_batches == 0:
        # Nothing was evaluated; avoid dividing by zero below.
        print("Acc: undefined (test reader produced no batches)")
        return

    # Unweighted mean of the per-batch accuracies.
    print("Acc: %f" % (total_acc / num_batches))
示例#3
0
def main():
    """Load the configuration, build train/eval programs and start training.

    The algorithm named in ``config['Global']['algorithm']`` decides whether
    the detection or the recognition train/eval driver is used.
    """
    config = program.load_config(FLAGS.config)
    program.merge_config(FLAGS.opt)
    logger.info(config)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    program.check_gpu(use_gpu)

    det_algs = ['EAST', 'DB']
    rec_algs = ['Rosetta', 'CRNN', 'STARNet', 'RARE']

    alg = config['Global']['algorithm']
    assert alg in det_algs + rec_algs
    is_det = alg in det_algs
    if alg in rec_algs:
        # Recognition algorithms get a character-ops helper stored in the
        # global config section.
        config['Global']['char_ops'] = CharacterOps(config['Global'])

    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    startup_program = fluid.Program()
    train_program = fluid.Program()

    # Training graph: the build result is consumed positionally.
    train_build_outputs = program.build(
        config, train_program, startup_program, mode='train')
    train_loader = train_build_outputs[0]
    train_fetch_name_list = train_build_outputs[1]
    train_fetch_varname_list = train_build_outputs[2]
    train_opt_loss_name = train_build_outputs[3]

    # Evaluation graph: built into its own program, then cloned for testing.
    eval_program = fluid.Program()
    eval_build_outputs = program.build(
        config, eval_program, startup_program, mode='eval')
    eval_fetch_name_list = eval_build_outputs[1]
    eval_fetch_varname_list = eval_build_outputs[2]
    eval_program = eval_program.clone(for_test=True)

    train_reader = reader_main(config=config, mode="train")
    train_loader.set_sample_list_generator(train_reader, places=place)

    eval_reader = reader_main(config=config, mode="eval")

    exe = fluid.Executor(place)
    exe.run(startup_program)

    # compile program for multi-devices
    train_compile_program = program.create_multi_devices_program(
        train_program, train_opt_loss_name)

    # dump model structure
    if config['Global']['debug']:
        # 'loss_type' is only consulted for non-detection algorithms, so a
        # detection run does not depend on that key being present.
        # NOTE(review): presumably detection configs omit 'loss_type' --
        # confirm against the shipped config files.
        if not is_det and 'attention' in config['Global']['loss_type']:
            logger.warning('Does not suport dump attention...')
        else:
            summary(train_program)

    init_model(config, train_program, exe)

    train_info_dict = {
        'compile_program': train_compile_program,
        'train_program': train_program,
        'reader': train_loader,
        'fetch_name_list': train_fetch_name_list,
        'fetch_varname_list': train_fetch_varname_list,
    }

    eval_info_dict = {
        'program': eval_program,
        'reader': eval_reader,
        'fetch_name_list': eval_fetch_name_list,
        'fetch_varname_list': eval_fetch_varname_list,
    }

    if is_det:
        program.train_eval_det_run(config, exe, train_info_dict,
                                   eval_info_dict)
    else:
        program.train_eval_rec_run(config, exe, train_info_dict,
                                   eval_info_dict)