Ejemplo n.º 1
0
def train():
    """Train the model selected by ``cfg.train_model`` and save checkpoints.

    Builds the network, minimizes its average cost with Adam, and then, for
    each of ``cfg.epoches`` epochs, shuffles the file list and trains from
    the dataset returned by ``get_dataset``. Every ``cfg.save_interval``
    epochs the default main program is saved with ``fluid.save`` under
    ``<cfg.save_path>/<model.name>/<model.name>_epoch_<N>/checkpoint``.

    Raises:
        ValueError: If ``cfg.train_model`` names an unsupported model.
    """
    if cfg.train_model == 'dnn':
        model = DNN()
    else:
        # Fail early with a clear message; otherwise ``model`` would be
        # unbound and the next line would die with an UnboundLocalError.
        raise ValueError(
            "unsupported cfg.train_model: {!r} (expected 'dnn')".format(
                cfg.train_model))

    inputs = model.input_data()
    avg_cost, auc_var = model.net(inputs)

    optimizer = fluid.optimizer.Adam(cfg.learning_rate)
    optimizer.minimize(avg_cost)

    place = fluid.CUDAPlace(0) if cfg.use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    # Run the startup program once before any training step.
    exe.run(fluid.default_startup_program())

    dataset, file_list = get_dataset(inputs)

    logger.info("Training Begin")
    for epoch in range(cfg.epoches):
        # Re-shuffle the input files so each epoch sees a different order.
        random.shuffle(file_list)
        dataset.set_filelist(file_list)

        start_time = time.time()
        exe.train_from_dataset(
            program=fluid.default_main_program(),
            dataset=dataset,
            fetch_list=[avg_cost, auc_var],
            fetch_info=['Epoch {} cost: '.format(epoch + 1), ' - auc: '],
            print_period=cfg.log_interval,
            debug=False)
        end_time = time.time()
        logger.info("epoch %d finished, use time = %ds \n" %
                    ((epoch + 1), end_time - start_time))

        # Epochs are reported and checkpointed 1-based.
        if (epoch + 1) % cfg.save_interval == 0:
            model_path = os.path.join(str(cfg.save_path), model.name,
                                      model.name + "_epoch_" + str(epoch + 1))
            if not os.path.isdir(model_path):
                os.makedirs(model_path)
            logger.info("saving model to %s \n" % (model_path))
            fluid.save(fluid.default_main_program(),
                       os.path.join(model_path, "checkpoint"))
    logger.info("Done.")
Ejemplo n.º 2
0
def evaluate():
    """Evaluate a saved DNN checkpoint on the files in ``cfg.evaluate_file_path``.

    Rebuilds the network in a fresh program, loads the checkpoint for epoch
    ``cfg.test_epoch``, resets the AUC state variables, and runs every test
    batch. The result is logged and written to
    ``<cfg.log_dir>/<model.name>_infer_result.log``.

    Returns:
        dict: ``{'loss': mean of per-batch ``loss / cfg.batch_size``,
        'auc': AUC value fetched on the last batch}``. With no test batches
        the loss is NaN and the AUC is 0.
    """
    place = fluid.CUDAPlace(0) if cfg.use_cuda else fluid.CPUPlace()
    inference_scope = fluid.Scope()
    test_files = [
        os.path.join(cfg.evaluate_file_path, x)
        for x in os.listdir(cfg.evaluate_file_path)
    ]
    dataset = CriteoDataset()
    test_reader = paddle.batch(dataset.test(test_files),
                               batch_size=cfg.batch_size)

    startup_program = fluid.framework.Program()
    test_program = fluid.framework.Program()
    model = DNN()
    # NOTE(review): train() in this source saves checkpoints under
    # <save_path>/<model.name>/<model.name>_epoch_<N>/checkpoint, while this
    # path has no intermediate <model.name> directory -- confirm that
    # cfg.save_path already accounts for it.
    model_path = os.path.join(cfg.save_path,
                              model.name + "_epoch_" + str(cfg.test_epoch),
                              "checkpoint")

    with fluid.framework.program_guard(test_program, startup_program):
        with fluid.unique_name.guard():
            inputs = model.input_data()
            loss, auc_var = model.net(inputs)

            exe = fluid.Executor(place)
            feeder = fluid.DataFeeder(feed_list=inputs, place=place)

            auc_states_names = [
                '_generated_var_0', '_generated_var_1', '_generated_var_2',
                '_generated_var_3'
            ]

            infer_auc = 0
            batch_losses = []
            # Load, reset and run inside the same scope. Without the guard,
            # the checkpoint and exe.run() used the global scope while
            # set_zero() wrote to inference_scope, so the reset never reached
            # the variables the executor reads.
            # (set_zero is defined elsewhere; assumed to zero the named
            # variable in the given scope -- TODO confirm.)
            with fluid.scope_guard(inference_scope):
                fluid.load(fluid.default_main_program(), model_path, exe)

                for var in auc_states_names:
                    set_zero(var, scope=inference_scope, place=place)

                for batch_id, data_test in enumerate(test_reader()):
                    loss_val, auc_val = exe.run(test_program,
                                                feed=feeder.feed(data_test),
                                                fetch_list=[loss, auc_var])
                    # Keep only the most recent AUC fetch as the final value.
                    infer_auc = auc_val
                    batch_losses.append(loss_val / cfg.batch_size)
                    if batch_id % cfg.log_interval == 0:
                        logger.info(
                            "TEST --> batch: {} loss: {} auc: {}".format(
                                batch_id, loss_val / cfg.batch_size, auc_val))

            # np.mean([]) would emit a RuntimeWarning; report NaN explicitly
            # when the reader produced no batches.
            if batch_losses:
                infer_loss = np.mean(batch_losses)
            else:
                infer_loss = float('nan')
            infer_result = {}
            infer_result['loss'] = infer_loss
            infer_result['auc'] = infer_auc
            if not os.path.isdir(cfg.log_dir):
                os.makedirs(cfg.log_dir)
            log_path = os.path.join(cfg.log_dir,
                                    model.name + '_infer_result.log')

            logger.info(str(infer_result))
            with open(log_path, 'w+') as f:
                f.write(str(infer_result))
            logger.info("Done.")
    return infer_result