Exemplo n.º 1
0
def do_train(args):
    """Train a ``MemTransformerLM`` language model as configured by ``args``.

    Builds the vocabulary, the train/valid data loaders, the model, the
    learning-rate scheduler and the optimizer. It optionally restores a
    checkpoint and/or pre-trained weights, then trains for at most
    ``args.epoch`` epochs and at most ``args.max_step`` optimizer steps.

    Every ``args.print_step`` steps rank 0 logs the average training loss.
    Every ``args.save_step`` steps the model is evaluated on the validation
    split. When ``args.save_model`` is set, rank 0 then saves a checkpoint.
    A final checkpoint is written to ``<args.save_model>/step_final``.

    Args:
        args: configuration namespace; every attribute read below must exist.

    Raises:
        ValueError: if ``args.scheduler`` or ``args.optim`` is not supported.
    """
    if args.use_gpu:
        rank = dist.get_rank()
        trainer_count = dist.get_world_size()
    else:
        rank = 0
        trainer_count = 1
        paddle.set_device("cpu")

    if trainer_count > 1:
        dist.init_parallel_env()

    # NOTE(review): eval() executes arbitrary code from the configuration;
    # only acceptable while the config source is trusted.
    random_seed = eval(str(args.random_seed))
    if random_seed is not None:
        paddle.seed(random_seed)

    vocab = get_lm_vocab(args)
    train_loader = get_lm_data_loader(args, vocab, "train")
    eval_loader = get_lm_data_loader(args, vocab, "valid")

    # Adaptive-softmax cluster boundaries and per-cluster projection tying.
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [20000, 40000, 200000]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [60000, 100000, 640000]
            tie_projs += [False] * len(cutoffs)

    mem_transformer = MemTransformerLM(args.ntokens,
                                       args.n_layer,
                                       args.n_head,
                                       args.d_model,
                                       args.d_head,
                                       args.d_inner_hid,
                                       args.dropout,
                                       args.attn_dropout,
                                       tie_weight=args.tie_weight,
                                       d_embed=args.d_model,
                                       div_val=args.div_val,
                                       tie_projs=tie_projs,
                                       normalize_before=args.normalize_before,
                                       tgt_len=args.tgt_len,
                                       ext_len=args.ext_len,
                                       mem_len=args.mem_len,
                                       cutoffs=cutoffs,
                                       same_length=args.same_length,
                                       attn_type=args.attn_type,
                                       clamp_len=args.clamp_len,
                                       sample_softmax=args.sample_softmax)

    if args.scheduler == 'cosine':
        scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
            learning_rate=args.learning_rate,
            T_max=args.max_step,
            eta_min=args.eta_min)
    elif args.scheduler == 'noam':
        scheduler = paddle.optimizer.lr.NoamDecay(
            d_model=args.d_model,
            warmup_steps=args.warmup_steps,
            learning_rate=args.learning_rate)
    elif args.scheduler == 'dev_perf':
        # fluid api
        scheduler = paddle.fluid.dygraph.ReduceLROnPlateau(
            learning_rate=args.learning_rate,
            decay_rate=args.decay_rate,
            patience=args.patience,
            min_lr=args.lr_min)
    elif args.scheduler == 'constant':
        scheduler = args.learning_rate
    else:
        # Previously this left ``scheduler`` unbound and failed later with
        # a NameError; fail early with a clear message instead.
        raise ValueError("Unsupported scheduler: %s" % args.scheduler)

    clip = paddle.nn.ClipGradByGlobalNorm(args.clip)
    optim_name = args.optim.lower()
    if optim_name == 'momentum':
        optimizer = paddle.optimizer.Momentum(
            learning_rate=scheduler,
            parameters=mem_transformer.parameters(),
            momentum=args.mom,
            grad_clip=clip)
    elif optim_name == 'adam':
        optimizer = paddle.optimizer.Adam(
            learning_rate=scheduler,
            parameters=mem_transformer.parameters(),
            beta1=args.beta1,
            beta2=args.beta2,
            # NOTE(review): eval() on a config value; see note above.
            epsilon=eval(args.eps),
            grad_clip=clip)
    elif optim_name == 'adagrad':
        optimizer = paddle.optimizer.Adagrad(
            learning_rate=scheduler,
            parameters=mem_transformer.parameters(),
            grad_clip=clip)
    else:
        raise ValueError("Unsupported optimizer: %s" % args.optim)

    # Init from some checkpoint, to resume the previous training
    if args.init_from_checkpoint:
        model_dict = paddle.load(
            os.path.join(args.init_from_checkpoint,
                         "mem_transformer.pdparams"))
        opt_dict = paddle.load(
            os.path.join(args.init_from_checkpoint, "mem_transformer.pdopt"))
        mem_transformer.set_state_dict(model_dict)
        optimizer.set_state_dict(opt_dict)
        print("loaded from checkpoint.")
    # Init from some pretrain models, to better solve the current task
    if args.init_from_pretrain_model:
        model_dict = paddle.load(
            os.path.join(args.init_from_pretrain_model,
                         "mem_transformer.pdparams"))
        mem_transformer.set_state_dict(model_dict)
        print("loaded from pre-trained model.")

    # Keep a handle on the bare model. ``reset_length`` is called on it
    # directly; once wrapped in DataParallel it would have to be reached
    # through ``_layers``. This also keeps the choice consistent with the
    # ``trainer_count`` test used for wrapping.
    base_model = mem_transformer
    if trainer_count > 1:
        mem_transformer = paddle.DataParallel(mem_transformer)

    step_idx = 0
    train_loss = 0.0
    # Number of losses accumulated in ``train_loss`` since the last log
    # line. The first window spans steps 0..print_step, i.e. print_step + 1
    # steps, so dividing by ``args.print_step`` would mis-average it.
    logged_steps = 0

    log_start_time = time.time()

    for pass_id in range(args.epoch):
        batch_id = 0

        # Recurrent memories are carried across batches within an epoch.
        mems = tuple()
        for src, target, seq_len in train_loader:
            ret = mem_transformer(src, target, *mems)
            loss = ret[0]
            mems = ret[1:]
            train_loss += loss.numpy()
            logged_steps += 1

            loss.backward()
            optimizer.step()
            optimizer.clear_grad()

            if step_idx > 0 and step_idx % args.print_step == 0 and rank == 0:
                cur_loss = train_loss / logged_steps
                elapsed = time.time() - log_start_time
                if args.scheduler == "constant":
                    lr = optimizer.get_lr()
                else:
                    lr = scheduler.get_lr()
                logger_info = "step_idx: %d, epoch: %d, batch: %d, learning rate: %.8f, " \
                              "speed: %f ms/batch, loss: %f" % \
                              (step_idx, pass_id, batch_id, lr,
                               elapsed * 1000.0 / logged_steps, cur_loss)
                if args.dataset in ["enwik8", "text8"]:
                    logger_info = logger_info + ", bpc: %f" % (cur_loss /
                                                               np.log(2))
                else:
                    logger_info = logger_info + ", ppl: %f" % (
                        np.exp(cur_loss))

                logger.info(logger_info)
                train_loss = 0.0
                logged_steps = 0
                log_start_time = time.time()

            if step_idx % args.save_step == 0 and step_idx != 0:
                # Do validation.
                mem_transformer.eval()

                # Evaluate with ``eval_tgt_len`` tokens per segment. The
                # difference from ``tgt_len`` goes into the extended context
                # when there is no memory, otherwise into the memory length.
                if args.mem_len == 0:
                    base_model.reset_length(
                        tgt_len=args.eval_tgt_len,
                        ext_len=args.ext_len + args.tgt_len -
                        args.eval_tgt_len,
                        mem_len=args.mem_len)
                else:
                    base_model.reset_length(
                        tgt_len=args.eval_tgt_len,
                        ext_len=args.ext_len,
                        mem_len=args.mem_len + args.tgt_len -
                        args.eval_tgt_len)

                total_len, total_loss = 0, 0.

                eval_mems = tuple()
                with paddle.no_grad():
                    for i, (src, target, seq_len) in enumerate(eval_loader):
                        if args.max_eval_steps > 0 and i >= args.max_eval_steps:
                            break
                        ret = mem_transformer(src, target, *eval_mems)
                        loss, eval_mems = ret[0], ret[1:]
                        seq_len = seq_len.numpy()
                        # Weight each batch loss by its sequence length.
                        total_loss += seq_len * loss.numpy()
                        total_len += seq_len
                    eval_loss = total_loss / total_len

                logger_info = "Validation, step_idx: %d, validation loss: %f" % \
                            (step_idx, eval_loss)
                if args.dataset in ['enwik8', 'text8']:
                    logger_info = logger_info + ", bpc: %f" % (eval_loss /
                                                               np.log(2))
                else:
                    logger_info = logger_info + ", ppl: %f" % (
                        np.exp(eval_loss))
                logger.info(logger_info)

                if args.save_model and rank == 0:
                    model_dir = os.path.join(
                        args.save_model,
                        "step_" + str(step_idx) + "_" + str(eval_loss))
                    os.makedirs(model_dir, exist_ok=True)
                    paddle.save(
                        mem_transformer.state_dict(),
                        os.path.join(model_dir, "mem_transformer.pdparams"))
                    paddle.save(
                        optimizer.state_dict(),
                        os.path.join(model_dir, "mem_transformer.pdopt"))

                if args.scheduler == 'dev_perf':
                    scheduler.step(eval_loss)

                # Restore the training-time segment/memory lengths.
                base_model.reset_length(tgt_len=args.tgt_len,
                                        ext_len=args.ext_len,
                                        mem_len=args.mem_len)

                mem_transformer.train()

            step_idx += 1
            batch_id += 1
            if args.scheduler in ['cosine', 'dev_perf']:
                # Linear warmup: overwrite the scheduler's base lr until
                # ``warmup_steps``; afterwards cosine steps normally.
                if step_idx < args.warmup_steps:
                    curr_lr = args.learning_rate * step_idx / args.warmup_steps
                    scheduler.base_lr = curr_lr
                else:
                    if args.scheduler == 'cosine':
                        scheduler.step()
            elif args.scheduler == 'constant':
                if step_idx < args.warmup_steps:
                    curr_lr = args.learning_rate * step_idx / args.warmup_steps
                    optimizer.set_lr(curr_lr)
            elif args.scheduler == 'noam':
                scheduler.step()

            # Stop as soon as ``max_step`` optimizer steps are done, instead
            # of finishing the rest of the epoch. Running on would also step
            # the cosine schedule past ``T_max``.
            if step_idx >= args.max_step:
                break
        if step_idx >= args.max_step:
            break

    if args.save_model and rank == 0:
        model_dir = os.path.join(args.save_model, "step_final")
        os.makedirs(model_dir, exist_ok=True)
        paddle.save(mem_transformer.state_dict(),
                    os.path.join(model_dir, "mem_transformer.pdparams"))
        paddle.save(optimizer.state_dict(),
                    os.path.join(model_dir, "mem_transformer.pdopt"))
Exemplo n.º 2
0
def do_eval(args):
    """Evaluate a trained ``MemTransformerLM`` on the valid and/or test split.

    Loads parameters from ``<args.init_from_params>/mem_transformer.pdparams``.
    It evaluates the split(s) selected by ``args.mode`` (``'all'``,
    ``'valid'`` or ``'test'``) and logs the ``seq_len``-weighted mean loss.
    The log also includes bpc for enwik8/text8 and perplexity for other
    datasets.

    Args:
        args: configuration namespace; every attribute read below must exist.

    Raises:
        AssertionError: if ``args.ext_len`` is negative or
            ``args.init_from_params`` is not set.
        ValueError: if ``args.mode`` is not one of the supported values.
    """
    assert args.ext_len >= 0, 'Extended context length must be no less than 0'
    # Check this before building loaders and the model so a missing setting
    # fails fast.
    assert args.init_from_params, (
        "Please set init_from_params to load the infer model.")
    if args.mode not in ('all', 'valid', 'test'):
        raise ValueError("Unsupported mode: %s" % args.mode)

    def _evaluate(model, loader):
        """Return the ``seq_len``-weighted mean loss of ``model`` over ``loader``."""
        total_len, total_loss = 0, 0.

        eval_mems = tuple()
        # No gradients are needed for evaluation.
        with paddle.no_grad():
            for i, (src, target, seq_len) in enumerate(loader):
                if args.max_eval_steps > 0 and i >= args.max_eval_steps:
                    break
                ret = model(src, target, *eval_mems)
                loss, eval_mems = ret[0], ret[1:]
                # Convert to numpy, as the validation loop in do_train does,
                # so the accumulators stay numpy values rather than tensors.
                seq_len = seq_len.numpy()
                total_loss += seq_len * loss.numpy()
                total_len += seq_len
        return total_loss / total_len

    def _logger(loss):
        """Format ``loss`` with bpc (enwik8/text8) or perplexity."""
        if args.dataset in ['enwik8', 'text8']:
            logger_info = "loss: %f, bpc: %f" % \
                          (loss, loss / np.log(2))
        else:
            logger_info = "loss: %f, ppl: %.2f" % \
                          (loss, np.exp(loss))
        return logger_info

    if not args.use_gpu:
        paddle.set_device("cpu")

    vocab = get_lm_vocab(args)
    eval_loader = get_lm_data_loader(args, vocab, "valid")
    test_loader = get_lm_data_loader(args, vocab, "test")

    # Adaptive-softmax cluster boundaries and per-cluster projection tying.
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [20000, 40000, 200000]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [60000, 100000, 640000]
            tie_projs += [False] * len(cutoffs)

    mem_transformer = MemTransformerLM(args.ntokens,
                                       args.n_layer,
                                       args.n_head,
                                       args.d_model,
                                       args.d_head,
                                       args.d_inner_hid,
                                       args.dropout,
                                       args.attn_dropout,
                                       tie_weight=args.tie_weight,
                                       d_embed=args.d_model,
                                       div_val=args.div_val,
                                       tie_projs=tie_projs,
                                       normalize_before=args.normalize_before,
                                       tgt_len=args.tgt_len,
                                       ext_len=args.ext_len,
                                       mem_len=args.mem_len,
                                       cutoffs=cutoffs,
                                       same_length=args.same_length,
                                       attn_type=args.attn_type,
                                       clamp_len=args.clamp_len,
                                       sample_softmax=args.sample_softmax)

    model_dict = paddle.load(
        os.path.join(args.init_from_params, "mem_transformer.pdparams"))
    # Same loading API as do_train (``load_dict`` is the older alias).
    mem_transformer.set_state_dict(model_dict)
    # Switch to inference mode; otherwise dropout stays active and the
    # reported losses are noisy and biased.
    mem_transformer.eval()

    logger.info(
        "Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}".
        format(args.eval_batch_size, args.tgt_len, args.ext_len, args.mem_len,
               args.clamp_len))

    mem_transformer.reset_length(args.tgt_len, args.ext_len, args.mem_len)

    test_loss = None
    valid_loss = None
    if args.mode == 'all':
        test_loss = _evaluate(mem_transformer, test_loader)
        valid_loss = _evaluate(mem_transformer, eval_loader)
    elif args.mode == 'valid':
        valid_loss = _evaluate(mem_transformer, eval_loader)
    elif args.mode == 'test':
        test_loss = _evaluate(mem_transformer, test_loader)

    logger_info = ''
    if valid_loss is not None:
        logger_info = logger_info + "validation loss: " + _logger(
            valid_loss) + " | "
    if test_loss is not None:
        logger_info = logger_info + "test loss: " + _logger(test_loss) + " | "
    logger.info(logger_info)