Beispiel #1
0
def _build_lr_scheduler(args):
    """Return the learning-rate schedule selected by ``args.scheduler``.

    'cosine', 'noam' and 'dev_perf' yield a paddle scheduler object;
    'constant' yields the plain ``args.learning_rate`` value.

    Raises:
        ValueError: if ``args.scheduler`` is not one of the four known names.
    """
    if args.scheduler == 'cosine':
        return paddle.optimizer.lr.CosineAnnealingDecay(
            learning_rate=args.learning_rate,
            T_max=args.max_step,
            eta_min=args.eta_min)
    if args.scheduler == 'noam':
        return paddle.optimizer.lr.NoamDecay(
            d_model=args.d_model,
            warmup_steps=args.warmup_steps,
            learning_rate=args.learning_rate)
    if args.scheduler == 'dev_perf':
        # fluid api
        return paddle.fluid.dygraph.ReduceLROnPlateau(
            learning_rate=args.learning_rate,
            decay_rate=args.decay_rate,
            patience=args.patience,
            min_lr=args.lr_min)
    if args.scheduler == 'constant':
        return args.learning_rate
    raise ValueError(
        "Unsupported scheduler %r; expected one of 'cosine', 'noam', "
        "'dev_perf', 'constant'." % (args.scheduler, ))


def _build_optimizer(args, scheduler, model):
    """Return the optimizer selected by ``args.optim`` (case-insensitive).

    Every optimizer is created with global-norm gradient clipping at
    ``args.clip`` and with ``scheduler`` as its learning rate.

    Raises:
        ValueError: if ``args.optim`` is not 'momentum', 'adam' or 'adagrad'.
    """
    clip = paddle.nn.ClipGradByGlobalNorm(args.clip)
    optim_name = args.optim.lower()
    if optim_name == 'momentum':
        return paddle.optimizer.Momentum(
            learning_rate=scheduler,
            parameters=model.parameters(),
            momentum=args.mom,
            grad_clip=clip)
    if optim_name == 'adam':
        return paddle.optimizer.Adam(
            learning_rate=scheduler,
            parameters=model.parameters(),
            beta1=args.beta1,
            beta2=args.beta2,
            # NOTE(review): eval() executes arbitrary code; this is only
            # acceptable while args.eps comes from a trusted config.
            epsilon=eval(args.eps),
            grad_clip=clip)
    if optim_name == 'adagrad':
        return paddle.optimizer.Adagrad(
            learning_rate=scheduler,
            parameters=model.parameters(),
            grad_clip=clip)
    raise ValueError(
        "Unsupported optimizer %r; expected one of 'momentum', 'adam', "
        "'adagrad'." % (args.optim, ))


def _reset_model_lengths(model, tgt_len, ext_len, mem_len):
    """Call ``reset_length`` on the model, going through ``_layers`` when the
    world size is not 1 (i.e. when the model has been wrapped for parallel
    training)."""
    target = model if dist.get_world_size() == 1 else model._layers
    target.reset_length(tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len)


def _metric_suffix(args, loss):
    """Return the ', bpc: ...' (enwik8/text8) or ', ppl: ...' log suffix."""
    if args.dataset in ["enwik8", "text8"]:
        return ", bpc: %f" % (loss / np.log(2))
    return ", ppl: %f" % (np.exp(loss))


def _run_validation(args, model, eval_loader):
    """Return the length-weighted mean loss over ``eval_loader``.

    At most ``args.max_eval_steps`` batches are evaluated when that value is
    positive. Memories returned by the model are fed into the next batch.
    """
    total_len, total_loss = 0, 0.
    eval_mems = tuple()
    with paddle.no_grad():
        for i, (src, target, seq_len) in enumerate(eval_loader):
            if args.max_eval_steps > 0 and i >= args.max_eval_steps:
                break
            ret = model(src, target, *eval_mems)
            loss, eval_mems = ret[0], ret[1:]
            seq_len = seq_len.numpy()
            total_loss += seq_len * loss.numpy()
            total_len += seq_len
    return total_loss / total_len


def _save_training_state(model, optimizer, model_dir):
    """Write model parameters and optimizer state into ``model_dir``."""
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(model_dir, exist_ok=True)
    paddle.save(model.state_dict(),
                os.path.join(model_dir, "mem_transformer.pdparams"))
    paddle.save(optimizer.state_dict(),
                os.path.join(model_dir, "mem_transformer.pdopt"))


def _update_learning_rate(args, scheduler, optimizer, step_idx):
    """Apply linear warmup or advance the scheduler after a training step."""
    if args.scheduler in ['cosine', 'dev_perf']:
        if step_idx < args.warmup_steps:
            scheduler.base_lr = (args.learning_rate * step_idx /
                                 args.warmup_steps)
        elif args.scheduler == 'cosine':
            # 'dev_perf' is stepped after validation, not per batch.
            scheduler.step()
    elif args.scheduler == 'constant':
        if step_idx < args.warmup_steps:
            optimizer.set_lr(args.learning_rate * step_idx /
                             args.warmup_steps)
    elif args.scheduler == 'noam':
        scheduler.step()


def do_train(args):
    """Train a ``MemTransformerLM`` according to ``args``.

    Builds the data loaders, model, learning-rate schedule and optimizer,
    optionally restores a checkpoint or pre-trained weights, then runs the
    training loop. Every ``args.save_step`` steps the model is validated and,
    on rank 0 when ``args.save_model`` is set, checkpointed. Training stops
    once ``args.max_step`` steps have been taken or ``args.epoch`` epochs are
    exhausted; a final "step_final" checkpoint is then written on rank 0.

    Args:
        args: configuration namespace; the attributes read are those
            referenced in this function and its helpers.

    Raises:
        ValueError: for an unsupported dataset (with ``args.adaptive``),
            scheduler or optimizer name.
    """
    if args.use_gpu:
        rank = dist.get_rank()
        trainer_count = dist.get_world_size()
    else:
        rank = 0
        trainer_count = 1
        paddle.set_device("cpu")

    if trainer_count > 1:
        dist.init_parallel_env()

    # NOTE(review): eval() executes arbitrary code; this is only acceptable
    # while args.random_seed comes from a trusted config.
    random_seed = eval(str(args.random_seed))
    if random_seed is not None:
        paddle.seed(random_seed)

    vocab = get_lm_vocab(args)
    train_loader = get_lm_data_loader(args, vocab, "train")
    eval_loader = get_lm_data_loader(args, vocab, "valid")

    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        if args.dataset == 'wt103':
            cutoffs = [20000, 40000, 200000]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [60000, 100000, 640000]
            tie_projs += [False] * len(cutoffs)
        else:
            # Explicit raise instead of assert so the check survives -O.
            raise ValueError(
                "args.adaptive requires dataset 'wt103' or 'lm1b', got %r" %
                (args.dataset, ))

    mem_transformer = MemTransformerLM(args.ntokens,
                                       args.n_layer,
                                       args.n_head,
                                       args.d_model,
                                       args.d_head,
                                       args.d_inner_hid,
                                       args.dropout,
                                       args.attn_dropout,
                                       tie_weight=args.tie_weight,
                                       d_embed=args.d_model,
                                       div_val=args.div_val,
                                       tie_projs=tie_projs,
                                       normalize_before=args.normalize_before,
                                       tgt_len=args.tgt_len,
                                       ext_len=args.ext_len,
                                       mem_len=args.mem_len,
                                       cutoffs=cutoffs,
                                       same_length=args.same_length,
                                       attn_type=args.attn_type,
                                       clamp_len=args.clamp_len,
                                       sample_softmax=args.sample_softmax)

    scheduler = _build_lr_scheduler(args)
    optimizer = _build_optimizer(args, scheduler, mem_transformer)

    # Init from some checkpoint, to resume the previous training
    if args.init_from_checkpoint:
        model_dict = paddle.load(
            os.path.join(args.init_from_checkpoint,
                         "mem_transformer.pdparams"))
        opt_dict = paddle.load(
            os.path.join(args.init_from_checkpoint, "mem_transformer.pdopt"))
        mem_transformer.set_state_dict(model_dict)
        optimizer.set_state_dict(opt_dict)
        print("loaded from checkpoint.")
    # Init from some pretrain models, to better solve the current task
    if args.init_from_pretrain_model:
        model_dict = paddle.load(
            os.path.join(args.init_from_pretrain_model,
                         "mem_transformer.pdparams"))
        mem_transformer.set_state_dict(model_dict)
        print("loaded from pre-trained model.")

    if trainer_count > 1:
        mem_transformer = paddle.DataParallel(mem_transformer)

    step_idx = 0
    train_loss = 0.0

    log_start_time = time.time()

    for pass_id in range(args.epoch):
        batch_id = 0

        # Memories are carried across batches but restart with every epoch.
        mems = tuple()
        for input_data in train_loader:
            (src, target, seq_len) = input_data
            ret = mem_transformer(src, target, *mems)
            loss = ret[0]
            mems = ret[1:]
            train_loss += loss.numpy()

            loss.backward()
            optimizer.step()
            optimizer.clear_grad()

            if step_idx > 0 and step_idx % args.print_step == 0 and rank == 0:
                cur_loss = train_loss / args.print_step
                elapsed = time.time() - log_start_time
                if args.scheduler == "constant":
                    lr = optimizer.get_lr()
                else:
                    lr = scheduler.get_lr()
                logger_info = "step_idx: %d, epoch: %d, batch: %d, learning rate: %.8f, " \
                              "speed: %f ms/batch, loss: %f" % \
                              (step_idx, pass_id, batch_id, lr,
                               elapsed * 1000.0 / args.print_step, cur_loss)
                logger_info = logger_info + _metric_suffix(args, cur_loss)

                logger.info(logger_info)
                train_loss = 0.0
                log_start_time = time.time()

            if step_idx % args.save_step == 0 and step_idx != 0:
                # Do validation.
                mem_transformer.eval()

                # Evaluation uses eval_tgt_len; the tokens freed up relative
                # to tgt_len are given to ext_len when there is no memory,
                # otherwise to mem_len.
                len_shift = args.tgt_len - args.eval_tgt_len
                if args.mem_len == 0:
                    _reset_model_lengths(mem_transformer,
                                         tgt_len=args.eval_tgt_len,
                                         ext_len=args.ext_len + len_shift,
                                         mem_len=args.mem_len)
                else:
                    _reset_model_lengths(mem_transformer,
                                         tgt_len=args.eval_tgt_len,
                                         ext_len=args.ext_len,
                                         mem_len=args.mem_len + len_shift)

                eval_loss = _run_validation(args, mem_transformer,
                                            eval_loader)

                logger_info = "Validation, step_idx: %d, validation loss: %f" % \
                            (step_idx, eval_loss)
                logger_info = logger_info + _metric_suffix(args, eval_loss)
                logger.info(logger_info)

                if args.save_model and rank == 0:
                    _save_training_state(
                        mem_transformer, optimizer,
                        os.path.join(
                            args.save_model,
                            "step_" + str(step_idx) + "_" + str(eval_loss)))

                if args.scheduler == 'dev_perf':
                    scheduler.step(eval_loss)

                # Restore the training-time lengths before resuming.
                _reset_model_lengths(mem_transformer,
                                     tgt_len=args.tgt_len,
                                     ext_len=args.ext_len,
                                     mem_len=args.mem_len)

                mem_transformer.train()

            step_idx += 1
            batch_id += 1
            _update_learning_rate(args, scheduler, optimizer, step_idx)

            # Stop as soon as the step budget is used up instead of running
            # the current epoch to its end.
            if step_idx >= args.max_step:
                break
        if step_idx >= args.max_step:
            break

    if args.save_model and rank == 0:
        _save_training_state(mem_transformer, optimizer,
                             os.path.join(args.save_model, "step_final"))
Beispiel #2
0
def _copy_torchscript_params(model, checkpoint, dtype):
    """Copy projection, bias and output-weight tensors from the checkpoint
    state into ``model``, converting each to ``dtype``.

    Tied projections/weights are taken from the ``word_emb.*`` entries of the
    checkpoint state, untied ones from the ``crit.*`` entries.
    """
    state = checkpoint['model_state']
    config = checkpoint['model_config']

    tie_projs = config['tie_projs']
    tie_weight = config['tie_weight']
    div_val = config['div_val']
    d_model = config['d_model']
    d_embed = config['d_embed']

    if div_val != 1 or d_model != d_embed:
        for i in range(len(model.word_emb.emb_projs)):
            model.word_emb.emb_projs[i] = state[
                f'word_emb.emb_projs.{i}'].to(dtype)

    for i in range(len(model.crit.out_projs)):
        if model.crit.out_projs[i] is None:
            continue
        if tie_projs[i]:
            # With div_val == 1 every tied projection reads embedding
            # projection 0.
            src = 0 if div_val == 1 else i
            model.crit.out_projs[i] = state[
                f'word_emb.emb_projs.{src}'].to(dtype)
        else:
            model.crit.out_projs[i] = state[f'crit.out_projs.{i}'].to(dtype)

    for i in range(len(model.crit.out_layers_biases)):
        model.crit.out_layers_biases[i] = state[
            f'crit.out_layers_biases.{i}'].to(dtype)

    for i in range(len(model.crit.out_layers_weights)):
        if tie_weight:
            key = f'word_emb.emb_layers.{i}.weight'
        else:
            key = f'crit.out_layers_weights.{i}'
        model.crit.out_layers_weights[i] = state[key].to(dtype)


def main():
    """Evaluate a ``MemTransformerLM`` and report loss, perplexity and timing.

    Steps: parse arguments, set up the device, distributed state and logging,
    load the checkpoint (unless ``args.manual_config`` is given), build the
    data iterator (from ``args.manual`` text or a corpus split), build or
    load the model, run ``evaluate`` and log the results. The process exits
    with status 1 when ``benchmark`` reports failure.

    Raises:
        RuntimeError: if no checkpoint path can be derived, if ``args.manual``
            is used without a loaded checkpoint, or for an unknown split.
    """
    args = parse_args()
    if args.affinity != 'disabled':
        nproc_per_node = torch.cuda.device_count()
        affinity = utils.gpu_affinity.set_affinity(args.local_rank,
                                                   nproc_per_node,
                                                   args.affinity)
        print(f'{args.local_rank}: thread affinity: {affinity}')

    if args.type == 'pytorch':
        from mem_transformer import MemTransformerLM
    else:
        from inference.mem_transformer_jit import MemTransformerLM

    torch.cuda.set_device(args.local_rank)
    l2_promote()
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir, debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'eval_log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = 'eval_log.log'

    dllog_file = args.dllog_file
    log_file = os.path.join(args.work_dir, log_file)
    dllog_file = os.path.join(args.work_dir, dllog_file)
    if args.debug:
        # In debug mode nothing is written to the work directory.
        log_file = os.devnull
        dllog_file = os.devnull

    utils.exp_utils.setup_logging(
        log_all_ranks=args.log_all_ranks,
        filename=log_file,
        filemode='a',
    )
    utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)

    if args.profile:
        try:
            pyprof.init(enable_function_stack=True)
        except NameError:
            # pyprof is an optional import; its name is undefined if missing.
            warnings.warn('Called pyprof.init() but pyprof is not available')

    logging.info(args)
    dllogger.log(step='PARAMETER', data=vars(args))

    if not args.no_env:
        log_env_info()

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.model:
        model_path = args.model
    elif args.work_dir:
        model_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    else:
        raise RuntimeError(
            'Specify path to checkpoint using --model or --work_dir')

    if not args.manual_config:
        checkpoint = load_checkpoint(model_path)
        vocab_type = checkpoint['args'].vocab
    else:
        checkpoint = None
        vocab_type = args.manual_vocab

    if args.manual:
        if checkpoint is None:
            # The vocabulary for manual text lives in the checkpoint, which
            # is not loaded when a manual model config is supplied.
            raise RuntimeError(
                'Manual input text needs the vocabulary stored in a '
                'checkpoint, but no checkpoint is loaded when a manual '
                'config is given')
        vocab = checkpoint['vocab']

        if hasattr(vocab, 'sym2idx') and not hasattr(vocab, 'unk_idx'):
            vocab.unk_idx = vocab.sym2idx['<unk>']

        text = " ".join(args.manual)
        tokenized = tokenize_raw(text)
        symbols = vocab.tokenize(tokenized, add_eos=True)
        tensor = vocab.convert_to_tensor(symbols)

        data_iter = data_utils.LMOrderedIterator(tensor,
                                                 bsz=args.batch_size,
                                                 bptt=args.tgt_len,
                                                 device=device,
                                                 ext_len=args.ext_len,
                                                 warmup=False)
    else:
        # Load dataset
        corpus = get_lm_corpus(args.data, args.dataset, vocab_type)

        if args.split == 'valid' or args.split == 'test':
            data_iter = corpus.get_iterator(args.split,
                                            args.batch_size,
                                            args.tgt_len,
                                            device=device,
                                            mem_len=args.mem_len,
                                            ext_len=args.ext_len)
        else:
            raise RuntimeError('Unknown split')

    if args.fp16:
        dtype = torch.float16
        math_str = 'fp16'
    else:
        dtype = torch.float32
        math_str = 'fp32'

    # Command-line values that override whatever the stored/manual model
    # configuration contains.
    config_overrides = {
        'tgt_len': args.tgt_len,
        'ext_len': args.ext_len,
        'mem_len': args.mem_len,
        'clamp_len': args.clamp_len,
        'same_length': args.same_length,
        'dtype': dtype,
    }

    if args.load_torchscript:
        model = torch.jit.load(args.load_torchscript)
    elif not args.manual_config:
        for key, value in config_overrides.items():
            checkpoint['model_config'][key] = value

        model = MemTransformerLM(**checkpoint['model_config'])
        if args.type == 'pytorch':
            model.load_state_dict(checkpoint['model_state'])
        elif args.type == 'torchscript':
            model.load_state_dict(checkpoint['model_state'], strict=False)
    else:
        for key, value in config_overrides.items():
            args.manual_config[key] = value

        model = MemTransformerLM(**args.manual_config)

    model = model.eval()
    model = model.to(device)
    model = model.to(dtype)

    if args.type == 'torchscript' and not args.manual_config:
        _copy_torchscript_params(model, checkpoint, dtype)
        model = torch.jit.script(model)

    if args.type != 'pytorch':
        compile_model(model, device, args)

    if args.type == 'torchscript' and args.save_torchscript:
        torch.jit.save(model, args.save_torchscript)

    logging.info(f'Evaluating with: math {math_str} type {args.type} '
                 f'bsz {args.batch_size} tgt_len {args.tgt_len} '
                 f'ext_len {args.ext_len} mem_len {args.mem_len} '
                 f'clamp_len {args.clamp_len}')

    meters = {}
    warmup = args.mem_len // args.tgt_len + 2
    meters['eval_throughput'] = AverageMeter(warmup=warmup,
                                             keep=args.save_data)
    meters['eval_latency'] = AverageMeter(warmup=warmup, keep=args.save_data)

    with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
        loss = evaluate(data_iter, model, meters, args.log_interval,
                        args.max_size, args.repeat)
    perplexity = math.exp(loss)
    log_str = format_log(loss, args.split, args)

    summary = {
        'eval_loss': loss,
        'eval_ppl': perplexity,
    }

    logging.info('=' * 100)
    logging.info(log_str)
    logging.info('=' * 100)

    if args.save_data:
        latency_data = np.array(meters['eval_latency'].vals)
        throughput_data = np.array(meters['eval_throughput'].vals)
        precision = 'fp16' if args.fp16 else 'fp32'
        data_fname = f'eval_data_{args.batch_size}_{precision}_{args.type}'
        data_path = os.path.join(args.work_dir, data_fname)
        data = {
            'args': args,
            'throughput': throughput_data,
            'latency': latency_data,
        }
        with open(data_path, 'wb') as f:
            pickle.dump(data, f)

        # Latencies are reported in milliseconds; compute each percentile
        # once and reuse it for both the log and the summary.
        latency_percentiles = [(p, 1000.0 * np.percentile(latency_data, p))
                               for p in args.percentiles]

        logging.info(f'Throughput Avg: {throughput_data.mean():.2f} tok/s')
        logging.info(f'Latency Avg: {1000.0 * latency_data.mean():.2f} ms')
        for p, latency_ms in latency_percentiles:
            logging.info(f'Latency {p}%: {latency_ms:.2f} ms')

        logging.info('=' * 100)

        summary.update({
            'eval_throughput': throughput_data.mean(),
            'eval_avg_latency': 1000 * latency_data.mean(),
        })
        for p, latency_ms in latency_percentiles:
            summary[f'eval_{p}%_latency'] = latency_ms

    dllogger.log(step=tuple(), data=summary)

    passed = benchmark(
        target_perplexity=args.target_perplexity,
        test_perplexity=perplexity,
        target_throughput=args.target_throughput,
        test_throughput=meters['eval_throughput'].avg,
    )
    if not passed:
        sys.exit(1)
Beispiel #3
0
def main():
    """Evaluate a checkpointed ``MemTransformerLM`` and benchmark the result.

    Parses the arguments, prepares device, distributed state and logging,
    loads the checkpoint, builds a data iterator (from ``args.manual`` text
    or from the 'valid'/'test' corpus split), creates or loads the model,
    runs ``evaluate`` and logs the outcome. Timing data is pickled into the
    work directory when ``args.save_data`` is set. Exits with status 1 when
    ``benchmark`` does not pass.
    """
    args = parse_args()

    # The model class comes from a different module for non-pytorch types.
    if args.type != 'pytorch':
        from inference.mem_transformer_base_jit import MemTransformerLM
    else:
        from mem_transformer import MemTransformerLM

    torch.cuda.set_device(args.local_rank)
    device_name = 'cuda' if args.cuda else 'cpu'
    device = torch.device(device_name)
    utils.distributed.init_distributed(args.cuda)

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir, debug=args.debug)

    # Setup logging
    log_name = (f'log_rank_{utils.distributed.get_rank()}.log'
                if args.log_all_ranks else 'log.log')
    log_file = os.path.join(args.work_dir, log_name)
    if args.debug:
        # Debug runs discard the log instead of writing to the work dir.
        log_file = os.devnull

    utils.exp_utils.setup_logging(
        log_all_ranks=args.log_all_ranks,
        filename=log_file,
        filemode='a',
    )
    logging.info(args)

    # An explicit model path wins; otherwise fall back to the work dir.
    model_path = args.model
    if not model_path:
        if not args.work_dir:
            raise RuntimeError(
                'Specify path to checkpoint using --model or --work_dir')
        model_path = os.path.join(args.work_dir, 'checkpoint_best.pt')

    checkpoint = load_checkpoint(model_path)

    if args.manual:
        # Manually supplied text is always evaluated with batch size 1.
        args.batch_size = 1
        vocab = checkpoint['vocab']

        missing_unk = (hasattr(vocab, 'sym2idx')
                       and not hasattr(vocab, 'unk_idx'))
        if missing_unk:
            vocab.unk_idx = vocab.sym2idx['<unk>']

        raw_text = " ".join(args.manual)
        symbols = vocab.tokenize(tokenize_raw(raw_text), add_eos=True)
        encoded = vocab.convert_to_tensor(symbols)

        data_iter = data_utils.LMOrderedIterator(encoded,
                                                 bsz=args.batch_size,
                                                 bptt=args.tgt_len,
                                                 device=device,
                                                 ext_len=args.ext_len)
    else:
        # Load dataset
        corpus = get_lm_corpus(args.data, args.dataset,
                               checkpoint['args'].vocab)

        if args.split not in ('valid', 'test'):
            raise RuntimeError('Unknown split')
        data_iter = corpus.get_iterator(args.split,
                                        args.batch_size,
                                        args.tgt_len,
                                        device=device,
                                        ext_len=args.ext_len)

    dtype, math_str = ((torch.float16, 'fp16') if args.fp16 else
                       (torch.float32, 'fp32'))

    if not args.load_torchscript:
        # Command-line lengths and dtype override the stored model config.
        model_config = checkpoint['model_config']
        for option in ('tgt_len', 'ext_len', 'mem_len', 'clamp_len',
                       'same_length'):
            model_config[option] = getattr(args, option)
        model_config['dtype'] = dtype

        model = MemTransformerLM(**model_config)
        model.load_state_dict(checkpoint['model_state'])
    else:
        model = torch.jit.load(args.load_torchscript)

    model = model.eval().to(device)

    # Go through float first, then down to half when fp16 is requested.
    model = model.float()
    if args.fp16:
        model = model.half()

    if args.type != 'pytorch':
        compile_model(model, device, args)

    if args.type == 'torchscript' and args.save_torchscript:
        torch.jit.save(model, args.save_torchscript)

    logging.info(f'Evaluating with: math {math_str} type {args.type} '
                 f'bsz {args.batch_size} tgt_len {args.tgt_len} '
                 f'ext_len {args.ext_len} mem_len {args.mem_len} '
                 f'clamp_len {args.clamp_len}')

    warmup = args.mem_len // args.tgt_len + 1
    meters = {
        'eval_throughput': AverageMeter(warmup=warmup, keep=args.save_data),
        'eval_latency': AverageMeter(warmup=warmup, keep=args.save_data),
    }

    loss = evaluate(data_iter, model, meters, args.max_size, args.repeat)
    perplexity = math.exp(loss)
    log_str = format_log(loss, args.split, args)

    separator = '=' * 100
    logging.info(separator)
    logging.info(log_str)
    logging.info(separator)

    if args.save_data:
        latency_data = np.array(meters['eval_latency'].vals)
        throughput_data = np.array(meters['eval_throughput'].vals)
        precision = 'fp16' if args.fp16 else 'fp32'
        data_path = os.path.join(
            args.work_dir,
            f'eval_data_{args.batch_size}_{precision}_{args.type}')
        payload = {
            'args': args,
            'throughput': throughput_data,
            'latency': latency_data,
        }
        with open(data_path, 'wb') as out_file:
            pickle.dump(payload, out_file)

        logging.info(f'Throughput Avg: {throughput_data.mean():.2f} tok/s')
        logging.info(f'Latency Avg: {1000.0 * latency_data.mean():.2f} ms')
        for p in args.percentiles:
            logging.info(
                f'Latency {p}%: {1000.0 * np.percentile(latency_data, p):.2f} ms'
            )

        logging.info(separator)

    passed = benchmark(
        target_perplexity=args.target_perplexity,
        test_perplexity=perplexity,
        target_throughput=args.target_throughput,
        test_throughput=meters['eval_throughput'].avg,
    )
    if passed:
        return
    sys.exit(1)