Code example #1
File: train.py  Project: shark803/textsum
def train(args):
    """
    train models
    :return:
    """

    trainer_count = fluid.dygraph.parallel.Env().nranks
    if not args.use_cpu:
        place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
            if args.use_data_parallel else fluid.CUDAPlace(0)
    else:
        place = fluid.cpu_places()[0]
    with fluid.dygraph.guard(place):
        if args.use_data_parallel:
            strategy = fluid.dygraph.parallel.prepare_context()

        # define model
        transformer = TransFormer(
            'transformer', ModelHyperParams.src_vocab_size,
            ModelHyperParams.trg_vocab_size, ModelHyperParams.max_length + 1,
            ModelHyperParams.n_layer, ModelHyperParams.n_head,
            ModelHyperParams.d_key, ModelHyperParams.d_value,
            ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
            ModelHyperParams.prepostprocess_dropout,
            ModelHyperParams.attention_dropout, ModelHyperParams.relu_dropout,
            ModelHyperParams.preprocess_cmd, ModelHyperParams.postprocess_cmd,
            ModelHyperParams.weight_sharing, TrainTaskConfig.label_smooth_eps)
        # define optimizer
        optimizer = fluid.optimizer.Adam(learning_rate=NoamDecay(
            ModelHyperParams.d_model, TrainTaskConfig.warmup_steps,
            TrainTaskConfig.learning_rate),
                                         beta1=TrainTaskConfig.beta1,
                                         beta2=TrainTaskConfig.beta2,
                                         epsilon=TrainTaskConfig.eps)
        # load checkpoint
        if args.restore:
            model_dict, _ = fluid.load_dygraph(args.model_file)
            transformer.load_dict(model_dict)
            print("checkpoint loaded")
        if args.use_data_parallel:
            transformer = fluid.dygraph.parallel.DataParallel(
                transformer, strategy)

        # define data generator for training and validation
        train_gen_fn, train_total = transformer_reader('train')
        train_reader = paddle.batch(train_gen_fn,
                                    batch_size=TrainTaskConfig.batch_size)
        # wmt16.train(
        # ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size),
        # batch_size=TrainTaskConfig.batch_size)
        if args.use_data_parallel:
            train_reader = fluid.contrib.reader.distributed_batch_reader(
                train_reader)
        val_gen_fn, val_total = transformer_reader('val')
        val_reader = paddle.batch(val_gen_fn,
                                  batch_size=TrainTaskConfig.batch_size)
        # wmt16.test(ModelHyperParams.src_vocab_size,
        #            ModelHyperParams.trg_vocab_size),

        # loop for training iterations
        for i in range(TrainTaskConfig.pass_num):
            start = time.time()
            dy_step = 0
            sum_cost = 0
            transformer.train()
            for batch in train_reader():
                enc_inputs, dec_inputs, label, weights = prepare_train_input(
                    batch, ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
                    ModelHyperParams.n_head)

                dy_sum_cost, dy_avg_cost, dy_predict, dy_token_num = transformer(
                    enc_inputs, dec_inputs, label, weights)

                if args.use_data_parallel:
                    dy_avg_cost = transformer.scale_loss(dy_avg_cost)
                    dy_avg_cost.backward()
                    transformer.apply_collective_grads()
                else:
                    dy_avg_cost.backward()
                optimizer.minimize(dy_avg_cost)
                transformer.clear_gradients()

                dy_step = dy_step + 1
                if dy_step % 10 == 0:
                    print(
                        "\rpass: {}, batch: {}/{}, avg loss: {}, time: {}/{}".
                        format(i, dy_step, train_total // dy_step,
                               dy_avg_cost.numpy() * trainer_count,
                               time.time() - start, (time.time() - start) /
                               (dy_step + 1) * train_total),
                        end='')

            # switch to evaluation mode
            transformer.eval()
            sum_cost = 0
            token_num = 0
            for batch in val_reader():
                enc_inputs, dec_inputs, label, weights = prepare_train_input(
                    batch, ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
                    ModelHyperParams.n_head)

                dy_sum_cost, dy_avg_cost, dy_predict, dy_token_num = transformer(
                    enc_inputs, dec_inputs, label, weights)
                sum_cost += dy_sum_cost.numpy()
                token_num += dy_token_num.numpy()
            print("pass : {} finished, validation avg loss: {}".format(
                i, sum_cost / token_num))

        if fluid.dygraph.parallel.Env().dev_id == 0:
            fluid.save_dygraph(transformer.state_dict(), args.model_file)
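
The optimizer above schedules its learning rate with NoamDecay, the warm-up-then-inverse-square-root schedule from the original Transformer paper; the next example uses the same schedule. As a reference, here is a minimal pure-Python sketch of that schedule (an assumption about what Paddle's NoamDecay computes; the framework implementation may differ in details such as how the step counter starts).

def noam_lr(step, d_model, warmup_steps, learning_rate=1.0):
    # Standard Noam schedule:
    # lr = base * d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
    step = max(step, 1)  # guard against division by zero at step 0
    return learning_rate * (d_model ** -0.5) * min(step ** -0.5,
                                                   step * (warmup_steps ** -1.5))

# Example: with d_model=512 and warmup_steps=4000 the rate grows roughly linearly
# during warm-up, peaks around step 4000, then decays as 1/sqrt(step).
print(noam_lr(100, 512, 4000))
print(noam_lr(8000, 512, 4000))
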
Code example #2
def do_train(args):
    if args.use_cuda:
        trainer_count = fluid.dygraph.parallel.Env().nranks
        place = fluid.CUDAPlace(fluid.dygraph.parallel.Env(
        ).dev_id) if trainer_count > 1 else fluid.CUDAPlace(0)
    else:
        trainer_count = 1
        place = fluid.CPUPlace()

    # define the data generator
    processor = reader.DataProcessor(
        fpattern=args.training_file,
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        token_delimiter=args.token_delimiter,
        use_token_batch=args.use_token_batch,
        batch_size=args.batch_size,
        device_count=trainer_count,
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        shuffle=args.shuffle,
        shuffle_batch=args.shuffle_batch,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        max_length=args.max_length,
        n_head=args.n_head)
    batch_generator = processor.data_generator(phase="train")
    if args.validation_file:
        val_processor = reader.DataProcessor(
            fpattern=args.validation_file,
            src_vocab_fpath=args.src_vocab_fpath,
            trg_vocab_fpath=args.trg_vocab_fpath,
            token_delimiter=args.token_delimiter,
            use_token_batch=args.use_token_batch,
            batch_size=args.batch_size,
            device_count=trainer_count,
            pool_size=args.pool_size,
            sort_type=args.sort_type,
            shuffle=False,
            shuffle_batch=False,
            start_mark=args.special_token[0],
            end_mark=args.special_token[1],
            unk_mark=args.special_token[2],
            max_length=args.max_length,
            n_head=args.n_head)
        val_batch_generator = val_processor.data_generator(phase="train")
    if trainer_count > 1:  # for multi-process gpu training
        batch_generator = fluid.contrib.reader.distributed_batch_reader(
            batch_generator)
    args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
        args.unk_idx = processor.get_vocab_summary()

    with fluid.dygraph.guard(place):
        # set seed for CE
        random_seed = eval(str(args.random_seed))
        if random_seed is not None:
            fluid.default_main_program().random_seed = random_seed
            fluid.default_startup_program().random_seed = random_seed

        # define data loader
        train_loader = fluid.io.DataLoader.from_generator(capacity=10)
        train_loader.set_batch_generator(batch_generator, places=place)
        if args.validation_file:
            val_loader = fluid.io.DataLoader.from_generator(capacity=10)
            val_loader.set_batch_generator(val_batch_generator, places=place)

        # define model
        transformer = Transformer(
            args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
            args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
            args.d_inner_hid, args.prepostprocess_dropout,
            args.attention_dropout, args.relu_dropout, args.preprocess_cmd,
            args.postprocess_cmd, args.weight_sharing, args.bos_idx,
            args.eos_idx)

        # define loss
        criterion = CrossEntropyCriterion(args.label_smooth_eps)

        # define optimizer
        optimizer = fluid.optimizer.Adam(
            learning_rate=NoamDecay(args.d_model, args.warmup_steps,
                                    args.learning_rate),
            beta1=args.beta1,
            beta2=args.beta2,
            epsilon=float(args.eps),
            parameter_list=transformer.parameters())

        ## init from a checkpoint to resume the previous training
        if args.init_from_checkpoint:
            model_dict, opt_dict = fluid.load_dygraph(
                os.path.join(args.init_from_checkpoint, "transformer"))
            transformer.load_dict(model_dict)
            optimizer.set_dict(opt_dict)
        ## init from a pretrained model to better fit the current task
        if args.init_from_pretrain_model:
            model_dict, _ = fluid.load_dygraph(
                os.path.join(args.init_from_pretrain_model, "transformer"))
            transformer.load_dict(model_dict)

        if trainer_count > 1:
            strategy = fluid.dygraph.parallel.prepare_context()
            transformer = fluid.dygraph.parallel.DataParallel(transformer,
                                                              strategy)

        # the best cross-entropy value with label smoothing
        loss_normalizer = -(
            (1. - args.label_smooth_eps) * np.log(
                (1. - args.label_smooth_eps)) + args.label_smooth_eps *
            np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))

        ce_time = []
        ce_ppl = []
        step_idx = 0

        # train loop
        for pass_id in range(args.epoch):
            epoch_start = time.time()

            batch_id = 0
            batch_start = time.time()
            interval_word_num = 0.0
            for input_data in train_loader():
                if args.max_iter and step_idx == args.max_iter:  #NOTE: used for benchmark
                    return
                batch_reader_end = time.time()

                (src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
                 trg_slf_attn_bias, trg_src_attn_bias, lbl_word,
                 lbl_weight) = input_data

                logits = transformer(src_word, src_pos, src_slf_attn_bias,
                                     trg_word, trg_pos, trg_slf_attn_bias,
                                     trg_src_attn_bias)

                sum_cost, avg_cost, token_num = criterion(logits, lbl_word,
                                                          lbl_weight)

                if trainer_count > 1:
                    avg_cost = transformer.scale_loss(avg_cost)
                    avg_cost.backward()
                    transformer.apply_collective_grads()
                else:
                    avg_cost.backward()

                optimizer.minimize(avg_cost)
                transformer.clear_gradients()

                interval_word_num += np.prod(src_word.shape)
                if step_idx % args.print_step == 0:
                    total_avg_cost = avg_cost.numpy() * trainer_count

                    if step_idx == 0:
                        logger.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer,
                             np.exp([min(total_avg_cost, 100)])))
                    else:
                        train_avg_batch_cost = args.print_step / (
                            time.time() - batch_start)
                        word_speed = interval_word_num / (
                            time.time() - batch_start)
                        logger.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f, avg_speed: %.2f step/s, "
                            "words speed: %0.2f words/s" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer,
                             np.exp([min(total_avg_cost, 100)]),
                             train_avg_batch_cost, word_speed))
                    batch_start = time.time()
                    interval_word_num = 0.0

                if step_idx % args.save_step == 0 and step_idx != 0:
                    # validation
                    if args.validation_file:
                        transformer.eval()
                        total_sum_cost = 0
                        total_token_num = 0
                        for input_data in val_loader():
                            (src_word, src_pos, src_slf_attn_bias, trg_word,
                             trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
                             lbl_word, lbl_weight) = input_data
                            logits = transformer(
                                src_word, src_pos, src_slf_attn_bias, trg_word,
                                trg_pos, trg_slf_attn_bias, trg_src_attn_bias)
                            sum_cost, avg_cost, token_num = criterion(
                                logits, lbl_word, lbl_weight)
                            total_sum_cost += sum_cost.numpy()
                            total_token_num += token_num.numpy()
                            total_avg_cost = total_sum_cost / total_token_num
                        logger.info("validation, step_idx: %d, avg loss: %f, "
                                    "normalized loss: %f, ppl: %f" %
                                    (step_idx, total_avg_cost,
                                     total_avg_cost - loss_normalizer,
                                     np.exp([min(total_avg_cost, 100)])))
                        transformer.train()

                    if args.save_model and (
                            trainer_count == 1 or
                            fluid.dygraph.parallel.Env().dev_id == 0):
                        model_dir = os.path.join(args.save_model,
                                                 "step_" + str(step_idx))
                        if not os.path.exists(model_dir):
                            os.makedirs(model_dir)
                        fluid.save_dygraph(
                            transformer.state_dict(),
                            os.path.join(model_dir, "transformer"))
                        fluid.save_dygraph(
                            optimizer.state_dict(),
                            os.path.join(model_dir, "transformer"))

                batch_id += 1
                step_idx += 1

            train_epoch_cost = time.time() - epoch_start
            ce_time.append(train_epoch_cost)
            logger.info("train epoch: %d, epoch_cost: %.5f s" %
                        (pass_id, train_epoch_cost))

        if args.save_model:
            model_dir = os.path.join(args.save_model, "step_final")
            if not os.path.exists(model_dir):
                os.makedirs(model_dir)
            fluid.save_dygraph(transformer.state_dict(),
                               os.path.join(model_dir, "transformer"))
            fluid.save_dygraph(optimizer.state_dict(),
                               os.path.join(model_dir, "transformer"))

        if args.enable_ce:
            _ppl = 0
            _time = 0
            try:
                _time = ce_time[-1]
                _ppl = ce_ppl[-1]
            except:
                print("ce info error")
            print("kpis\ttrain_duration_card%s\t%s" % (trainer_count, _time))
            print("kpis\ttrain_ppl_card%s\t%f" % (trainer_count, _ppl))