Example #1
    def run_offline_infer(self):
        logger.info("Run Offline Infer Begin")
        place = paddle.CPUPlace()
        self.exe = paddle.static.Executor(place)

        self.exe.run(paddle.static.default_startup_program())
        fleet.init_worker()

        init_model_path = config.get("runner.init_model_path")
        model_mode = config.get("runner.model_mode", 0)
        if fleet.is_first_worker():
            fleet.load_model(init_model_path, mode=model_mode)
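        # All workers wait here until the first worker has finished loading the model.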
        fleet.barrier_worker()

        logger.info("Prepare Dataset Begin.")
        prepare_data_start_time = time.time()
        dataset = self.wait_and_prepare_dataset()
        prepare_data_end_time = time.time()
        logger.info("Prepare Dataset Done, using time {} second.".format(
            prepare_data_end_time - prepare_data_start_time))

        infer_start_time = time.time()
        self.dataset_offline_infer(dataset)
        infer_end_time = time.time()
        logger.info("Infer Dataset Done, using time {} second.".format(
            infer_end_time - infer_start_time))
Example #2
def test_barrier_worker():
    """test_barrier_worker"""
    assert fleet.barrier_worker() is None
    print("{} ... ok".format(sys._getframe().f_code.co_name))
Example #3
def do_train(args):
    # Initialize the paddle and paddle fleet execution environment
    paddle.enable_static()
    fleet.init(is_collective=True)

    # Set the random seeds for this worker
    random.seed(args.seed)
    np.random.seed(args.seed)
    paddle.seed(args.seed)
    get_rng_state_tracker().add('global_seed', args.seed)
    get_rng_state_tracker().add('local_seed',
                                args.seed + fleet.worker_index() + 2021)

    assert args.device in [
        "cpu", "gpu", "xpu"
    ], "Invalid device! Available device should be cpu, gpu, or xpu."
    place = paddle.set_device(args.device)

    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()
    assert args.dp_degree * args.sharding_degree * args.mp_degree * args.pp_degree == worker_num, \
        "The product of the parallelism degrees should equal worker_num."

    topo = Topology(device_rank=worker_index,
                    world_size=worker_num,
                    dp_degree=args.dp_degree,
                    pp_degree=args.pp_degree,
                    sharding_degree=args.sharding_degree,
                    mp_degree=args.mp_degree)

    logger.info("The topo of hybrid parallelism:\n{}".format(topo))

    dist_strategy = dist_optimizer(args, topo)

    # Create the log writer; training results are shown on the last card of the pipeline.
    if topo.is_last:
        log_writer_path = os.path.join(
            args.output_dir, "train_log",
            "{}_globalbsz_{}_amp_{}_recompute_{}_card_{}".format(
                args.model_name_or_path, args.global_batch_size, args.use_amp,
                args.use_recompute, worker_index).lower())
        # if os.path.exists(log_writer_path):
        #     shutil.rmtree(log_writer_path)
        log_writer = LogWriter(log_writer_path)

    # Define the input data in the static mode
    base_class, model_class, criterion_class, tokenizer_class = MODEL_CLASSES[
        args.model_type]
    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    # load config in checkpoint
    global_step = 0
    consumed_samples = 0
    checkpoint_dir = os.path.join(args.output_dir, "model_last")
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(os.path.join(checkpoint_dir, "./config.yml")):
            with open(os.path.join(checkpoint_dir, "./config.yml"), "r") as f:
                step_config = yaml.load(f, Loader=yaml.FullLoader)
                assert step_config[
                    "global_batch_size"] == args.global_batch_size, "Please ensure checkpoint global batch size is the same. Folder: {}".format(
                        checkpoint_dir)
                consumed_samples = step_config["consumed_samples"]
                global_step = step_config["global_step"]

    data_file = get_train_data_file(args)
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()
    with paddle.static.program_guard(main_program, startup_program):
        data_holders = create_data_holder(args)
        # 0. input_ids,
        # 1. segment_ids,
        # 2. input_mask,
        # 3. masked_lm_positions,
        # 4. masked_lm_labels,
        # 5. next_sentence_labels

        [
            input_ids, segment_ids, input_mask, masked_lm_positions,
            masked_lm_labels, next_sentence_labels
        ] = data_holders

        tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

        train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
            args,
            data_file,
            tokenizer,
            data_world_size=topo.data_info.size,
            data_world_rank=topo.data_info.rank,
            max_seq_len=args.max_seq_len,
            places=paddle.static.cuda_places(),
            data_holders=data_holders,
            current_step=global_step)
        fleet.init(is_collective=True)

        if args.model_name_or_path in pretrained_models_list:
            model_config = model_class.pretrained_init_configuration[
                args.model_name_or_path]
            if model_config["vocab_size"] % 8 != 0:
                model_config["vocab_size"] += 8 - (model_config["vocab_size"] %
                                                   8)
            model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
            model_config[
                "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
            model = model_class(base_class(**model_config))
        else:
            model, _ = model_class.from_pretrained(
                args.model_name_or_path,
                hidden_dropout_prob=args.hidden_dropout_prob,
                attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            )

        # Create the model for pretraining
        prediction_scores, seq_relationship_score = model(
            input_ids=input_ids,
            token_type_ids=segment_ids,
            position_ids=None,
            attention_mask=input_mask,
            masked_positions=masked_lm_positions)

        criterion = criterion_class(with_nsp_loss=args.binary_head)
        if args.binary_head:
            lm_loss, sop_loss = criterion(prediction_scores,
                                          seq_relationship_score,
                                          masked_lm_labels,
                                          next_sentence_labels)
            loss = lm_loss + sop_loss
        else:
            loss = criterion(prediction_scores, seq_relationship_score,
                             masked_lm_labels)

        # Create the learning rate scheduler and the optimizer
        if args.decay_steps is None:
            args.decay_steps = args.max_steps

        # lr_scheduler = CosineAnnealingWithWarmupDecay(
        #     max_lr=args.max_lr,
        #     min_lr=args.min_lr,
        #     warmup_step=args.warmup_rate * args.max_steps,
        #     decay_step=args.decay_steps, last_epoch=global_step)

        lr_scheduler = LinearDecayWithWarmup(args.max_lr,
                                             args.max_steps,
                                             args.warmup_rate,
                                             last_epoch=global_step)

        clip = None
        if args.grad_clip > 0:
            clip = paddle.fluid.clip.GradientClipByGlobalNorm(
                clip_norm=args.grad_clip)

        decay_param = [
            p.name for n, p in model.named_parameters()
            if not any(nd in n for nd in ["bias", "norm"])
        ]
        logger.info("Using paddle.optimizer.AdamW.")
        optimizer = paddle.optimizer.AdamW(
            learning_rate=lr_scheduler,
            beta1=args.adam_beta1,
            beta2=args.adam_beta2,
            epsilon=args.adam_epsilon,
            grad_clip=clip,
            weight_decay=args.weight_decay,
            apply_decay_param_fun=lambda x: x in decay_param)
        # alias
        optimizer.apply_optimize = optimizer._apply_optimize

        # if args.use_recompute:
        #     dist_strategy.recompute = True
        #     dist_strategy.recompute_configs = {
        #         "checkpoints": model.bert.checkpoints
        #     }

        # Use the fleet api to compile the distributed optimizer
        optimizer = fleet.distributed_optimizer(optimizer,
                                                strategy=dist_strategy)

        optimizer.minimize(loss)
        logger.info(f'final strategy: {fleet._final_strategy()}')
        logger.info("The training meta optimizer is/are %s" %
                    fleet._get_applied_meta_list())

    program_desc_dir = os.path.join(args.output_dir, "program_desc")
    if not os.path.isdir(program_desc_dir):
        os.mkdir(program_desc_dir)

    with open(program_desc_dir + "/main_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(main_program))

    with open(program_desc_dir + "/startup_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(startup_program))

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
    exe.run(startup_program)

    test_program = main_program.clone(for_test=True)

    if args.model_name_or_path not in pretrained_models_list:
        logger.info("Try to load checkpoint from %s " %
                    args.model_name_or_path)
        dygraph_path = os.path.join(args.model_name_or_path,
                                    "model_state.pdparams")
        static_path = os.path.join(args.model_name_or_path, "static_vars")

        flag_loaded = False
        if os.path.exists(static_path):
            if args.mp_degree > 1:
                logger.warning("MP should init with dygraph params")
            else:
                logger.info("Loading parameters from %s" % static_path)
                paddle.static.load(main_program, static_path, exe)
                flag_loaded = True

        if not flag_loaded and os.path.exists(dygraph_path):
            if args.sharding_degree > 1:
                logger.warning("Sharding should init with static vars")
            else:
                logger.info("Loading parameters from %s" % dygraph_path)
                init_static_with_params(
                    model, paddle.load(dygraph_path, return_numpy=True), topo,
                    main_program)
                flag_loaded = True

        if not flag_loaded:
            logger.error("No checkpoint load.")

    # load checkpoint vars
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(os.path.join(checkpoint_dir, "./config.yml")):
            paddle.static.load(main_program,
                               os.path.join(checkpoint_dir, "static_vars"),
                               exe)

    fetch_loss_vars = collections.OrderedDict()
    fetch_other_vars = collections.OrderedDict()
    fetch_loss_vars["loss"] = loss
    if args.binary_head:
        fetch_loss_vars["lm_loss"] = lm_loss
        fetch_loss_vars["sop_loss"] = sop_loss

    fetch_other_vars["learning_rate"] = main_program.global_block(
    ).vars["learning_rate_0"]

    additional_vars = collections.OrderedDict()
    if args.use_amp:
        for key in ["loss_scaling", "num_good_steps", "num_bad_steps"]:
            additional_vars[key] = main_program.global_block().vars[key + "_0"]

    tic_train = time.time()
    while True:
        fetchs = []
        fetchs_keys = []
        if topo.is_last:
            fetchs = list(fetch_loss_vars.values()) + list(
                fetch_other_vars.values()) + list(additional_vars.values())
            fetchs_keys = list(fetch_loss_vars.keys()) + list(
                fetch_other_vars.keys()) + list(additional_vars.keys())

        # Bug fix: instantiate valid_data_loader/test_data_loader once here; otherwise
        # enumerating them repeatedly would start a new random dataloader each time.
        valid_data_loader = valid_data_loader()
        test_data_loader = test_data_loader()

        for step, batch in enumerate(train_data_loader()):
            ret = exe.run(main_program,
                          feed=batch,
                          fetch_list=fetchs,
                          use_program_cache=True)
            # Only count a global step once every accumulate_steps micro-batches
            if (step + 1) % args.accumulate_steps != 0:
                continue
            global_step += 1
            # In the 2.0 API, lr_scheduler.step() must be called to update the learning rate
            lr_scheduler.step()

            if global_step % args.logging_freq == 0:
                if topo.is_last:
                    res = collections.defaultdict(float)
                    for k, v in zip(fetchs_keys, ret):
                        res[k] = v[0]

                    speed = args.logging_freq / (time.time() - tic_train)

                    loss_info = "loss: %.6f, lm_loss: %.6f, sop_loss: %.6f"

                    loss_info = ", ".join([
                        "{}: {:.6f}".format(k, res[k])
                        for k in fetch_loss_vars.keys()
                    ])

                    common_loginfo = "global step %d, %s, speed: %.2f steps/s, ips: %.2f seqs/s, learning rate: %.5e" % (
                        global_step, loss_info, speed,
                        speed * args.global_batch_size, res["learning_rate"])
                    additional_loginfo = ", ".join([
                        "{}: {}".format(k, res[k])
                        for k in additional_vars.keys()
                    ])
                    if additional_loginfo:
                        common_loginfo += ", " + additional_loginfo
                    logger.info(common_loginfo)
                    for k, v in res.items():
                        log_writer.add_scalar(k, v, global_step)

                tic_train = time.time()

            #if args.check_accuracy:
            #    if global_step >= args.max_steps:
            #        return
            #    else:
            #        continue

            if global_step % args.eval_freq == 0:
                # TODO, check the input data of validation
                eval_fetch = collections.OrderedDict()
                if topo.is_last:
                    eval_fetch["loss"] = loss
                    if args.binary_head:
                        eval_fetch["lm_loss"] = lm_loss
                        eval_fetch["sop_loss"] = sop_loss

                run_evaluate(valid_data_loader, exe, test_program,
                             args.eval_iters, log_writer, global_step, args,
                             topo.is_last, eval_fetch, "valid")
                tic_train = time.time()

            if global_step % args.save_steps == 0 or global_step >= args.max_steps:
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                logger.debug("saving models to {}".format(output_dir))
                save_persistables(exe, os.path.join(output_dir, "static_vars"),
                                  main_program)
                if global_step == args.save_steps:
                    model.init_config["init_args"][0].init_config.pop(
                        "topo", None)
                model.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
                tic_train = time.time()

            if global_step % args.checkpoint_steps == 0:
                output_dir = os.path.join(args.output_dir, "model_last")
                if worker_index == 0:
                    if not os.path.exists(output_dir):
                        os.mkdir(output_dir)
                    output_dir_bak = os.path.join(args.output_dir,
                                                  "model_last_bak")
                    if os.path.exists(output_dir):
                        if os.path.exists(output_dir_bak):
                            shutil.rmtree(output_dir_bak)
                        shutil.move(output_dir, output_dir_bak)
                        os.mkdir(output_dir)

                    step_config = {
                        "model_name": args.model_name_or_path,
                        "global_step": global_step,
                        "global_batch_size": args.global_batch_size,
                        "consumed_samples":
                        global_step * args.global_batch_size,
                    }

                    with open(os.path.join(output_dir, "config.yml"),
                              "w") as f:
                        yaml.dump(step_config,
                                  f,
                                  encoding='utf-8',
                                  allow_unicode=True)

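                # Block all workers until worker 0 has finished writing the checkpoint config.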
                fleet.barrier_worker()

                logger.debug("saving models to {}".format(output_dir))
                if args.sharding_degree <= 1:
                    # Save on the first worker by default.
                    if worker_index == 0:
                        paddle.static.save(
                            main_program,
                            os.path.join(output_dir, "static_vars"))
                else:
                    # Use save_persistables under sharding, although it is slower
                    save_persistables(exe,
                                      os.path.join(output_dir, "static_vars"),
                                      main_program)

            if global_step >= args.max_steps:
                eval_fetch = collections.OrderedDict()
                if topo.is_last:
                    eval_fetch["loss"] = loss
                    if args.binary_head:
                        eval_fetch["lm_loss"] = lm_loss
                        eval_fetch["sop_loss"] = sop_loss

                run_evaluate(test_data_loader, exe, test_program,
                             args.test_iters, log_writer, global_step, args,
                             topo.is_last, eval_fetch, "test")
                del train_data_loader
                return
Example #4
def get_global_auc(scope=fluid.global_scope(),
                   stat_pos="_generated_var_2",
                   stat_neg="_generated_var_3"):
    """
        Get global auc of all distributed workers.

        Args:
            scope(Scope): Scope object, default is fluid.global_scope()
            stat_pos(str): name of auc pos bucket Variable
            stat_neg(str): name of auc neg bucket Variable

        Returns:
            auc_value(float), total_ins_num(int)

        """
    if scope.find_var(stat_pos) is None or scope.find_var(stat_neg) is None:
        logger.info("not found auc bucket")
        return None
    fleet.barrier_worker()
    # auc pos bucket
    pos = np.array(scope.find_var(stat_pos).get_tensor())
    # auc pos bucket shape
    old_pos_shape = np.array(pos.shape)
    # reshape to one dim
    pos = pos.reshape(-1)
    #global_pos = np.copy(pos) * 0
    # mpi allreduce
    global_pos = fleet.util.all_reduce(pos)
    # reshape to its original shape
    global_pos = global_pos.reshape(old_pos_shape)
    # print('debug global auc global_pos: ', global_pos)

    # auc neg bucket
    neg = np.array(scope.find_var(stat_neg).get_tensor())
    old_neg_shape = np.array(neg.shape)
    neg = neg.reshape(-1)
    #global_neg = np.copy(neg) * 0
    global_neg = fleet.util.all_reduce(neg)
    global_neg = global_neg.reshape(old_neg_shape)
    # print('debug global auc global_neg: ', global_neg)

    # calculate auc
    num_bucket = len(global_pos[0])
    area = 0.0
    pos = 0.0
    neg = 0.0
    new_pos = 0.0
    new_neg = 0.0
    total_ins_num = 0
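    # Walk the buckets from the highest prediction threshold to the lowest and
    # accumulate the ROC area with the trapezoidal rule; pos/neg hold the
    # cumulative positive/negative counts seen so far.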
    for i in range(num_bucket):
        index = num_bucket - 1 - i
        new_pos = pos + global_pos[0][index]
        total_ins_num += global_pos[0][index]
        new_neg = neg + global_neg[0][index]
        total_ins_num += global_neg[0][index]
        area += (new_neg - neg) * (pos + new_pos) / 2
        pos = new_pos
        neg = new_neg

    if pos * neg == 0 or total_ins_num == 0:
        auc_value = 0.5
    else:
        auc_value = area / (pos * neg)

    fleet.barrier_worker()
    return auc_value
Example #5
def get_global_metrics(scope=fluid.global_scope(),
                       stat_pos_name="_generated_var_2",
                       stat_neg_name="_generated_var_3",
                       sqrerr_name="sqrerr",
                       abserr_name="abserr",
                       prob_name="prob",
                       q_name="q",
                       pos_ins_num_name="pos",
                       total_ins_num_name="total"):
    if scope.find_var(stat_pos_name) is None or \
            scope.find_var(stat_neg_name) is None:
        logger.info("not found auc bucket")
        return [None] * 9
    elif scope.find_var(sqrerr_name) is None:
        logger.info("not found sqrerr_name=%s" % sqrerr_name)
        return [None] * 9
    elif scope.find_var(abserr_name) is None:
        logger.info("not found abserr_name=%s" % abserr_name)
        return [None] * 9
    elif scope.find_var(prob_name) is None:
        logger.info("not found prob_name=%s" % prob_name)
        return [None] * 9
    elif scope.find_var(q_name) is None:
        logger.info("not found q_name=%s" % q_name)
        return [None] * 9
    elif scope.find_var(pos_ins_num_name) is None:
        logger.info("not found pos_ins_num_name=%s" % pos_ins_num_name)
        return [None] * 9
    elif scope.find_var(total_ins_num_name) is None:
        logger.info("not found total_ins_num_name=%s" % \
                               total_ins_num_name)
        return [None] * 9

    # barrier worker to ensure all workers finished training
    fleet.barrier_worker()

    # get auc
    auc = get_global_auc(scope, stat_pos_name, stat_neg_name)
    pos = np.array(scope.find_var(stat_pos_name).get_tensor())
    # auc pos bucket shape
    old_pos_shape = np.array(pos.shape)
    # reshape to one dim
    pos = pos.reshape(-1)
    global_pos = np.copy(pos) * 0
    # mpi allreduce
    # fleet._role_maker._all_reduce(pos, global_pos)
    global_pos = fleet.util.all_reduce(pos)
    # reshape to its original shape
    global_pos = global_pos.reshape(old_pos_shape)
    # auc neg bucket
    neg = np.array(scope.find_var(stat_neg_name).get_tensor())
    old_neg_shape = np.array(neg.shape)
    neg = neg.reshape(-1)
    global_neg = np.copy(neg) * 0
    # fleet._role_maker._all_reduce(neg, global_neg)
    global_neg = fleet.util.all_reduce(neg)
    global_neg = global_neg.reshape(old_neg_shape)

    num_bucket = len(global_pos[0])

    def get_metric(name):
        metric = np.array(scope.find_var(name).get_tensor())
        old_metric_shape = np.array(metric.shape)
        metric = metric.reshape(-1)
        # print(name, 'ori value:', metric)
        global_metric = np.copy(metric) * 0
        # fleet._role_maker._all_reduce(metric, global_metric)
        global_metric = fleet.util.all_reduce(metric)
        global_metric = global_metric.reshape(old_metric_shape)
        # print(name, global_metric)
        return global_metric[0]

    global_sqrerr = get_metric(sqrerr_name)
    global_abserr = get_metric(abserr_name)
    global_prob = get_metric(prob_name)
    global_q_value = get_metric(q_name)
    # Note: the instance count derived from the AUC buckets is not the actual
    # value, so fetch it from the metric op instead.
    pos_ins_num = get_metric(pos_ins_num_name)
    total_ins_num = get_metric(total_ins_num_name)
    neg_ins_num = total_ins_num - pos_ins_num

    mae = global_abserr / total_ins_num
    rmse = math.sqrt(global_sqrerr / total_ins_num)
    return_actual_ctr = pos_ins_num / total_ins_num
    predicted_ctr = global_prob / total_ins_num
    mean_predict_qvalue = global_q_value / total_ins_num
    copc = 0.0
    if abs(predicted_ctr) > 1e-6:
        copc = return_actual_ctr / predicted_ctr

    # calculate bucket error
    last_ctr = -1.0
    impression_sum = 0.0
    ctr_sum = 0.0
    click_sum = 0.0
    error_sum = 0.0
    error_count = 0.0
    click = 0.0
    show = 0.0
    ctr = 0.0
    adjust_ctr = 0.0
    relative_error = 0.0
    actual_ctr = 0.0
    relative_ctr_error = 0.0
    k_max_span = 0.01
    k_relative_error_bound = 0.05
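    # Group adjacent prediction buckets into spans no wider than k_max_span in
    # predicted CTR; once a span's estimated relative error falls below
    # k_relative_error_bound, add its impression-weighted relative difference
    # between actual and predicted CTR to the bucket error.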
    for i in range(num_bucket):
        click = global_pos[0][i]
        show = global_pos[0][i] + global_neg[0][i]
        ctr = float(i) / num_bucket
        if abs(ctr - last_ctr) > k_max_span:
            last_ctr = ctr
            impression_sum = 0.0
            ctr_sum = 0.0
            click_sum = 0.0
        impression_sum += show
        ctr_sum += ctr * show
        click_sum += click
        if impression_sum == 0:
            continue
        adjust_ctr = ctr_sum / impression_sum
        if adjust_ctr == 0:
            continue
        relative_error = \
            math.sqrt((1 - adjust_ctr) / (adjust_ctr * impression_sum))
        if relative_error < k_relative_error_bound:
            actual_ctr = click_sum / impression_sum
            relative_ctr_error = abs(actual_ctr / adjust_ctr - 1)
            error_sum += relative_ctr_error * impression_sum
            error_count += impression_sum
            last_ctr = -1

    bucket_error = error_sum / error_count if error_count > 0 else 0.0

    return [
        auc, bucket_error, mae, rmse, return_actual_ctr, predicted_ctr, copc,
        mean_predict_qvalue,
        int(total_ins_num)
    ]
Example #6
    def test_barrier_worker(self):
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        if fleet.is_worker():
            fleet.barrier_worker()
Example #7
    def run_worker(self):
        logger.info("Run Worker Begin")
        use_cuda = int(config.get("runner.use_gpu"))
        use_auc = config.get("runner.use_auc", False)
        place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
        self.exe = paddle.static.Executor(place)

        with open(
                "./{}_worker_main_program.prototxt".format(
                    fleet.worker_index()), 'w+') as f:
            f.write(str(paddle.static.default_main_program()))
        with open(
                "./{}_worker_startup_program.prototxt".format(
                    fleet.worker_index()), 'w+') as f:
            f.write(str(paddle.static.default_startup_program()))

        self.exe.run(paddle.static.default_startup_program())
        fleet.init_worker()

        save_model_path = self.config.get("runner.model_save_path")
        if save_model_path and (not os.path.exists(save_model_path)):
            os.makedirs(save_model_path)

        reader_type = self.config.get("runner.reader_type", None)
        epochs = int(self.config.get("runner.epochs"))
        sync_mode = self.config.get("runner.sync_mode")

        gpus_env = os.getenv("FLAGS_selected_gpus")
        self.PSGPU = paddle.fluid.core.PSGPU()
        gpuslot = [int(i) for i in range(1, self.model.sparse_inputs_slots)]
        self.PSGPU.set_slot_vector(gpuslot)
        self.PSGPU.init_gpu_ps([int(s) for s in gpus_env.split(",")])
        opt_info = paddle.fluid.default_main_program()._fleet_opt
        if use_auc is True:
            opt_info['stat_var_names'] = [
                self.model.stat_pos.name, self.model.stat_neg.name
            ]
        else:
            opt_info['stat_var_names'] = []

        for epoch in range(epochs):
            epoch_start_time = time.time()

            if sync_mode == "heter":
                self.heter_train_loop(epoch)
            elif sync_mode == "gpubox":
                self.reader._set_use_ps_gpu(1)
                self.dataset_train_loop(epoch)
            elif reader_type == "QueueDataset":
                self.dataset_train_loop(epoch)
            elif reader_type == "DataLoader":
                self.dataloader_train_loop(epoch)
            elif reader_type is None or reader_type == "RecDataset":
                self.recdataset_train_loop(epoch)

            epoch_time = time.time() - epoch_start_time
            epoch_speed = self.example_nums / epoch_time
            if use_auc is True:
                global_auc = auc(self.model.stat_pos, self.model.stat_neg,
                                 paddle.fluid.global_scope(), fleet.util)
                self.train_result_dict["auc"].append(global_auc)
                fleet_util.set_zero(self.model.stat_pos.name,
                                    paddle.fluid.global_scope())
                fleet_util.set_zero(self.model.stat_neg.name,
                                    paddle.fluid.global_scope())
                fleet_util.set_zero(self.model.batch_stat_pos.name,
                                    paddle.fluid.global_scope())
                fleet_util.set_zero(self.model.batch_stat_neg.name,
                                    paddle.fluid.global_scope())
                logger.info(
                    "Epoch: {}, using time {} second, ips {} {}/sec. auc: {}".
                    format(epoch, epoch_time, epoch_speed, self.count_method,
                           global_auc))
            else:
                logger.info(
                    "Epoch: {}, using time {} second, ips {} {}/sec.".format(
                        epoch, epoch_time, epoch_speed, self.count_method))
            self.train_result_dict["speed"].append(epoch_speed)

            model_dir = "{}/{}".format(save_model_path, epoch)
            if (fleet.is_first_worker() and save_model_path
                    and is_distributed_env()):
                fleet.save_inference_model(
                    self.exe, model_dir,
                    [feed.name
                     for feed in self.input_data], self.inference_target_var)
            fleet.barrier_worker()
            self.reader.release_memory()
            self.PSGPU.end_pass()
            logger.info("finish {} epoch training....".format(epoch))
        self.PSGPU.finalize()

    def run_online_worker(self):
        logger.info("Run Online Worker Begin")
        use_cuda = int(config.get("runner.use_gpu"))
        place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
        self.exe = paddle.static.Executor(place)

        with open("./{}_worker_main_program.prototxt".format(
                fleet.worker_index()), 'w+') as f:
            f.write(str(paddle.static.default_main_program()))
        with open("./{}_worker_startup_program.prototxt".format(
                fleet.worker_index()), 'w+') as f:
            f.write(str(paddle.static.default_startup_program()))

        self.exe.run(paddle.static.default_startup_program())
        fleet.init_worker()

        self.online_intervals = get_online_pass_interval(
            self.split_interval, self.split_per_pass, False)
        if is_local(self.save_model_path) and self.save_model_path and (
                not os.path.exists(self.save_model_path)):
            os.makedirs(self.save_model_path)

        last_day, last_pass, last_path, xbox_base_key = get_last_save_model(
            self.save_model_path, self.hadoop_client)
        logger.info(
            "get_last_save_model last_day = {}, last_pass = {}, last_path = {}, xbox_base_key = {}".
            format(last_day, last_pass, last_path, xbox_base_key))
        if last_day != -1 and fleet.is_first_worker():
            load_model(last_path, 0, self.hadoop_client)
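        # Make every worker wait until the first worker has loaded the last saved model.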
        fleet.barrier_worker()

        day = self.start_day
        infer_first = True
        while int(day) <= int(self.end_day):
            logger.info("training a new day {}, end_day = {}".format(
                day, self.end_day))
            if last_day != -1 and int(day) < last_day:
                day = get_next_day(day)
                continue
            # base_model_saved = False
            for pass_id in range(1, 1 + len(self.online_intervals)):
                print(last_day, day, last_pass, pass_id)
                if (last_day != -1 and int(day) == last_day) and (
                        last_pass != -1 and int(pass_id) <= last_pass):
                    continue
                if self.save_first_base and fleet.is_first_worker():
                    self.save_first_base = False
                    last_base_day, last_base_path, tmp_xbox_base_key = \
                        get_last_save_xbox_base(self.save_model_path, self.hadoop_client)
                    logger.info(
                        "get_last_save_xbox_base, last_base_day = {}, last_base_path = {}, tmp_xbox_base_key = {}".
                        format(last_base_day, last_base_path,
                               tmp_xbox_base_key))
                    if int(day) > last_base_day:
                        xbox_base_key = int(time.time())
                        save_xbox_model(self.save_model_path, day, -1,
                                        self.exe, self.inference_feed_vars,
                                        self.inference_target_var,
                                        self.hadoop_client)
                        write_xbox_donefile(
                            output_path=self.save_model_path,
                            day=day,
                            pass_id=-1,
                            xbox_base_key=xbox_base_key,
                            client=self.hadoop_client)
                    elif int(day) == last_base_day:
                        xbox_base_key = tmp_xbox_base_key
                fleet.barrier_worker()

                logger.info("training a new day = {} new pass = {}".format(
                    day, pass_id))
                logger.info("Day:{}, Pass: {}, Prepare Dataset Begin.".format(
                    day, pass_id))
                begin_train = time.time()
                begin = time.time()
                dataset = self.wait_and_prepare_dataset(day, pass_id)
                end = time.time()
                read_data_cost = (end - begin) / 60.0
                logger.info("Prepare Dataset Done, using time {} mins.".format(
                    read_data_cost))

                infer_cost = 0
                infer_metric_cost = 0
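                # Each pass first evaluates the current model on the new data and then
                # trains on it; inference is skipped only for the first processed pass.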
                if infer_first:
                    infer_first = False
                else:
                    logger.info("Day:{}, Pass: {}, Infering Dataset Begin.".
                                format(day, pass_id))
                    begin = time.time()
                    self.dataset_infer_loop(dataset, day, pass_id)
                    end = time.time()
                    infer_cost = (end - begin) / 60.0
                    logger.info("Infering Dataset Done, using time {} mins.".
                                format(infer_cost))
                    begin = time.time()
                    metric_str = get_global_metrics_str(fluid.global_scope(),
                                                        self.metric_list, "")
                    logger.info("Day:{}, Pass: {}, Infer Global Metric: {}".
                                format(day, pass_id, metric_str))
                    clear_metrics(fluid.global_scope(), self.metric_list,
                                  self.metric_types)
                    end = time.time()
                    infer_metric_cost = (end - begin) / 60.0

                logger.info("Day:{}, Pass: {}, Training Dataset Begin.".format(
                    day, pass_id))
                begin = time.time()
                self.dataset_train_loop(dataset, day, pass_id,
                                        self.need_train_dump)
                end = time.time()
                avg_cost = get_avg_cost_mins(end - begin)
                get_max_cost_mins(end - begin)
                get_min_cost_mins(end - begin)
                train_cost = avg_cost
                logger.info("Training Dataset Done, using time {} mins.".
                            format(train_cost))

                begin = time.time()
                dataset.release_memory()
                end = time.time()
                release_cost = (end - begin) / 60.0

                begin = time.time()
                metric_str = get_global_metrics_str(fluid.global_scope(),
                                                    self.metric_list, "")
                logger.info("Day:{}, Pass: {}, Train Global Metric: {}".format(
                    day, pass_id, metric_str))
                clear_metrics(fluid.global_scope(), self.metric_list,
                              self.metric_types)
                end = time.time()
                metric_cost = (end - begin) / 60
                end_train = time.time()
                total_cost = (end_train - begin_train) / 60
                other_cost = total_cost - read_data_cost - train_cost - release_cost - metric_cost - infer_cost - infer_metric_cost
                log_str = "finished train epoch %d time cost:%s min job time cost" \
                            ":[read_data:%s min][train: %s min][metric: %s min][release: %s min]" \
                            "[infer:%s min][infer_metric: %s min][other:%s min]" \
                              % (pass_id, total_cost, read_data_cost, train_cost, metric_cost, release_cost, infer_cost, infer_metric_cost, other_cost)
                logger.info(log_str)

                if self.need_infer_dump:
                    prepare_data_start_time = time.time()
                    dump_dataset = self.wait_and_prepare_infer_dataset(day,
                                                                       pass_id)
                    prepare_data_end_time = time.time()
                    logger.info(
                        "Prepare Infer Dump Dataset Done, using time {} second.".
                        format(prepare_data_end_time -
                               prepare_data_start_time))

                    dump_start_time = time.time()
                    self.dataset_infer_loop(dump_dataset, day, pass_id, True)
                    dump_end_time = time.time()
                    logger.info(
                        "Infer Dump Dataset Done, using time {} second.".
                        format(dump_end_time - dump_start_time))

                    dump_dataset.release_memory()

                if fleet.is_first_worker():
                    if pass_id % self.checkpoint_per_pass == 0:
                        save_model(self.exe, self.save_model_path, day,
                                   pass_id)
                        write_model_donefile(
                            output_path=self.save_model_path,
                            day=day,
                            pass_id=pass_id,
                            xbox_base_key=xbox_base_key,
                            client=self.hadoop_client)
                    if pass_id % self.save_delta_frequency == 0:
                        last_xbox_day, last_xbox_pass, last_xbox_path, _ = get_last_save_xbox(
                            self.save_model_path, self.hadoop_client)
                        if int(day) < last_xbox_day or int(
                                day) == last_xbox_day and int(
                                    pass_id) <= last_xbox_pass:
                            log_str = "delta model exists"
                            logger.info(log_str)
                        else:
                            save_xbox_model(self.save_model_path, day, pass_id,
                                            self.exe, self.inference_feed_vars,
                                            self.inference_target_var,
                                            self.hadoop_client)  # 1 delta
                            write_xbox_donefile(
                                output_path=self.save_model_path,
                                day=day,
                                pass_id=pass_id,
                                xbox_base_key=xbox_base_key,
                                client=self.hadoop_client,
                                hadoop_fs_name=self.hadoop_fs_name,
                                monitor_data=metric_str)
                fleet.barrier_worker()

            logger.info("shrink table")
            begin = time.time()
            fleet.shrink()
            end = time.time()
            logger.info("shrink table done, cost %s min" % (
                (end - begin) / 60.0))

            if fleet.is_first_worker():
                last_base_day, last_base_path, last_base_key = get_last_save_xbox_base(
                    self.save_model_path, self.hadoop_client)
                logger.info(
                    "one epoch finishes, get_last_save_xbox, last_base_day = {}, last_base_path = {}, last_base_key = {}".
                    format(last_base_day, last_base_path, last_base_key))
                next_day = get_next_day(day)
                if int(next_day) <= last_base_day:
                    logger.info("batch model/base xbox model exists")
                else:
                    xbox_base_key = int(time.time())
                    save_xbox_model(self.save_model_path, next_day, -1,
                                    self.exe, self.inference_feed_vars,
                                    self.inference_target_var,
                                    self.hadoop_client)
                    write_xbox_donefile(
                        output_path=self.save_model_path,
                        day=next_day,
                        pass_id=-1,
                        xbox_base_key=xbox_base_key,
                        client=self.hadoop_client,
                        hadoop_fs_name=self.hadoop_fs_name,
                        monitor_data=metric_str)
                    save_batch_model(self.exe, self.save_model_path, next_day)
                    write_model_donefile(
                        output_path=self.save_model_path,
                        day=next_day,
                        pass_id=-1,
                        xbox_base_key=xbox_base_key,
                        client=self.hadoop_client)
            fleet.barrier_worker()
            day = get_next_day(day)