Code example #1
                             args.dropout,
                             args.dropatt,
                             tie_weight=args.tied,
                             d_embed=args.d_embed,
                             div_val=args.div_val,
                             tie_projs=tie_projs,
                             pre_lnorm=args.pre_lnorm,
                             tgt_len=args.tgt_len,
                             ext_len=args.ext_len,
                             mem_len=args.mem_len,
                             cutoffs=cutoffs,
                             same_length=args.same_length,
                             attn_type=args.attn_type,
                             clamp_len=args.clamp_len,
                             sample_softmax=args.sample_softmax)
    model.apply(weights_init)
    model.word_emb.apply(
        weights_init
    )  # ensure embedding init is not overridden by out_layer in case of weight sharing
args.n_all_param = sum([p.nelement() for p in model.parameters()])
args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])

if args.fp16:
    model = model.half()

if args.multi_gpu:
    model = model.to(device)
    if args.gpu0_bsz >= 0:
        para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
                                          model,
                                          dim=1).to(device)
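
Several of the snippets on this page count total and non-embedding parameters with the same two sum() lines. A minimal, self-contained sketch of that idiom (the toy model below is purely illustrative and stands in for MemTransformerLM):

import torch.nn as nn

# Toy stand-in: index 0 plays the role of model.word_emb, index 1 of model.layers.
toy = nn.Sequential(nn.Embedding(1000, 16), nn.Linear(16, 1000))

n_all_param = sum(p.nelement() for p in toy.parameters())
n_nonemb_param = sum(p.nelement() for p in toy[1].parameters())
print(f'#params = {n_all_param}, #non emb params = {n_nonemb_param}')
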
Code example #2
def main():
    args = parse_args()
    if args.affinity != 'disabled':
        nproc_per_node = torch.cuda.device_count()
        affinity = utils.gpu_affinity.set_affinity(args.local_rank,
                                                   nproc_per_node,
                                                   args.affinity)
        print(f'{args.local_rank}: thread affinity: {affinity}')

    # Initialize device and distributed backend
    torch.cuda.set_device(args.local_rank)
    l2_promote()
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    args.work_dir = utils.exp_utils.build_work_dir_name(
        args.work_dir,
        args.dataset,
        args.append_dataset,
        args.append_time,
    )

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir,
                           scripts_to_save=['train.py', 'mem_transformer.py'],
                           debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'train_log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = args.txtlog_file
    dllog_file = args.dllog_file
    log_file = os.path.join(args.work_dir, log_file)
    dllog_file = os.path.join(args.work_dir, dllog_file)

    if args.debug:
        log_file = os.devnull
        dllog_file = os.devnull

    utils.exp_utils.setup_logging(
        log_all_ranks=args.log_all_ranks,
        filename=log_file,
    )
    utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)

    if args.local_batch_size is not None:
        world_size = utils.distributed.get_world_size()
        args.batch_size = world_size * args.local_batch_size
        logging.info(f'--local_batch_size was set, adjusting global batch size'
                     f' to {args.batch_size} (local_batch_size * world_size)')
        if args.batch_size % args.batch_chunk != 0:
            raise RuntimeError('Batch size needs to be divisible by '
                               'batch chunk')

    if args.profile:
        try:
            pyprof.init(enable_function_stack=True)
        except NameError:
            warnings.warn('Called pyprof.init() but pyprof is not available')

    logging.info(args)
    dllogger.log(step='PARAMETER', data=vars(args))

    logging.info(f'world size: {utils.distributed.get_world_size()}')

    if not args.no_env:
        log_env_info()

    register_ignoring_timeout_handler()

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    ###########################################################################
    # Load data
    ###########################################################################
    corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
    ntokens = len(corpus.vocab)
    vocab = corpus.vocab
    args.n_token = ntokens

    if args.mem_len == 0:
        eval_mem_len = 0
    else:
        eval_mem_len = args.mem_len + args.tgt_len - args.eval_tgt_len

    tr_iter = corpus.get_iterator('train',
                                  args.batch_size,
                                  args.tgt_len,
                                  device=device,
                                  ext_len=args.ext_len)
    va_iter = corpus.get_iterator('valid',
                                  args.eval_batch_size,
                                  args.eval_tgt_len,
                                  device=device,
                                  mem_len=eval_mem_len,
                                  ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test',
                                  args.eval_batch_size,
                                  args.eval_tgt_len,
                                  device=device,
                                  mem_len=eval_mem_len,
                                  ext_len=args.ext_len)

    # adaptive softmax / embedding
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [19997, 39997, 199997]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [59997, 99997, 639997]
            tie_projs += [False] * len(cutoffs)

    ###########################################################################
    # Build the model
    ###########################################################################
    model_config = {
        'n_token': ntokens,
        'n_layer': args.n_layer,
        'n_head': args.n_head,
        'd_model': args.d_model,
        'd_head': args.d_head,
        'd_inner': args.d_inner,
        'dropout': args.dropout,
        'dropatt': args.dropatt,
        'dtype': None,
        'tie_weight': args.tied,
        'd_embed': args.d_embed,
        'div_val': args.div_val,
        'tie_projs': tie_projs,
        'pre_lnorm': args.pre_lnorm,
        'tgt_len': args.tgt_len,
        'ext_len': args.ext_len,
        'mem_len': args.mem_len,
        'cutoffs': cutoffs,
        'same_length': args.same_length,
        'attn_type': args.attn_type,
        'clamp_len': args.clamp_len,
        'sample_softmax': args.sample_softmax,
    }

    model = MemTransformerLM(**model_config)

    model.apply(functools.partial(weights_init, args=args))
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    model.word_emb.apply(functools.partial(weights_init, args=args))

    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = sum(
        [p.nelement() for p in model.layers.parameters()])

    # optimizer
    if args.optim.lower() == 'sgd':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
            optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
        else:
            optimizer = optim.SGD(model.parameters(),
                                  lr=args.lr,
                                  momentum=args.mom)
            optimizer_sparse = None
    elif args.optim.lower() == 'adam':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
            optimizer = optim.Adam(dense_params,
                                   lr=args.lr,
                                   weight_decay=args.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(),
                                   lr=args.lr,
                                   weight_decay=args.weight_decay)
            optimizer_sparse = None
    elif args.optim.lower() == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        optimizer_sparse = None
    elif args.optim.lower() == 'lamb':
        optimizer = lamb.Lamb(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay)
        optimizer_sparse = None
    elif args.optim.lower() == 'jitlamb':
        optimizer = lamb.JITLamb(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
        optimizer_sparse = None

    model = model.to(device)

    scaler = None
    if args.fp16:
        if args.amp == 'pytorch':
            scaler = torch.cuda.amp.GradScaler()
        elif args.amp == 'apex':
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level=args.apex_amp_opt_level,
            )

    if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
        para_model = DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
            find_unused_parameters=True,
        )
    elif args.multi_gpu == 'dp':
        if args.gpu0_bsz >= 0:
            para_model = BalancedDataParallel(args.gpu0_bsz //
                                              args.batch_chunk,
                                              model,
                                              dim=1).to(device)
        else:
            para_model = nn.DataParallel(model, dim=1).to(device)
    else:
        para_model = model

    # scheduler
    if args.scheduler == 'cosine':
        if args.max_step_scheduler:
            max_step = args.max_step_scheduler
        else:
            max_step = args.max_step

        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         max_step -
                                                         args.warmup_step,
                                                         eta_min=args.eta_min)
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
                optimizer_sparse,
                max_step - args.warmup_step,
                eta_min=args.eta_min)
        else:
            scheduler_sparse = None
    elif args.scheduler == 'inv_sqrt':
        # originally used for Transformer (in Attention is all you need)
        def lr_lambda(step):
            # return a multiplier instead of a learning rate
            if step == 0 and args.warmup_step == 0:
                return 1.
            else:
                return 1. / (step ** 0.5) if step > args.warmup_step \
                    else step / (args.warmup_step ** 1.5)

        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.LambdaLR(optimizer_sparse,
                                                           lr_lambda=lr_lambda)
        else:
            scheduler_sparse = None
    elif args.scheduler == 'dev_perf':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=args.decay_rate,
            patience=args.patience,
            min_lr=args.lr_min,
        )
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer_sparse,
                factor=args.decay_rate,
                patience=args.patience,
                min_lr=args.lr_min,
            )
        else:
            scheduler_sparse = None
    elif args.scheduler == 'constant':
        pass

    logging.info('=' * 100)
    for k, v in args.__dict__.items():
        logging.info('    - {} : {}'.format(k, v))
    logging.info('=' * 100)
    logging.info('#params = {}'.format(args.n_all_param))
    logging.info('#non emb params = {}'.format(args.n_nonemb_param))

    train_step = 0
    start_epoch = 1
    last_batch = 0
    last_iter = 0
    best_val_loss = None

    if args.restart:
        try:
            checkpoint = load_checkpoint(args.restart)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            scheduler.load_state_dict(checkpoint['scheduler_state'])
            if args.fp16:
                if args.amp == 'pytorch':
                    scaler.load_state_dict(checkpoint['amp_state'])
                elif args.amp == 'apex':
                    amp.load_state_dict(checkpoint['amp_state'])
            train_step = checkpoint['train_step']
            start_epoch = checkpoint['epoch']
            last_batch = checkpoint['batch']
            last_iter = checkpoint['last_iter']
            best_val_loss = checkpoint['best_val_loss']

            if train_step >= args.max_step:
                logging.info(
                    f'Loaded checkpoint after {train_step} steps, but '
                    f'this run was scheduled for a total of '
                    f'{args.max_step} steps, exiting')
                sys.exit(1)

            model.apply(functools.partial(update_dropout, args=args))
            model.apply(functools.partial(update_dropatt, args=args))
        except FileNotFoundError:
            logging.info(f'Could not load checkpoint from {args.restart}, '
                         f'starting training from random init')

    meters = {}
    warmup = args.mem_len // args.tgt_len + 2
    meters['train_throughput'] = AverageMeter(warmup=warmup)
    ###########################################################################
    # Train
    ###########################################################################
    # Loop over epochs.
    # At any point you can hit Ctrl + C to break out of training early.
    start_time = time.time()
    with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
        with TimeoutHandler() as timeout_handler:
            try:
                for epoch in itertools.count(start=start_epoch):
                    if args.roll:
                        tr_iter.roll(seed=args.seed + epoch)
                    train_step, best_val_loss = train(
                        tr_iter, va_iter, model, para_model, model_config,
                        optimizer, optimizer_sparse, scheduler,
                        scheduler_sparse, scaler, vocab, epoch, last_batch,
                        last_iter, train_step, best_val_loss, meters,
                        timeout_handler, device, args)

                    last_batch = 0
                    last_iter = 0

                    if train_step == args.max_step:
                        logging.info('-' * 100)
                        logging.info('End of training')
                        break
            except KeyboardInterrupt:
                logging.info('-' * 100)
                logging.info('Exiting from training early')
    elapsed = time.time() - start_time

    ###########################################################################
    # Test
    ###########################################################################
    summary = {}
    test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    if not args.debug and not args.no_eval and os.path.exists(test_path):
        # Load the best saved model.
        checkpoint = load_checkpoint(test_path)
        model.load_state_dict(checkpoint['model_state'])

        # Run on test data.
        test_start_time = time.time()
        with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
            test_loss = evaluate(te_iter, model, args)
            test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')
        test_elapsed = time.time() - test_start_time

        logging.info('=' * 100)
        if args.dataset in ['enwik8', 'text8']:
            logging.info(
                '| End of training | test time: {:5.2f}s | test loss {:5.2f} | test bpc {:9.5f}'
                .format(test_elapsed, test_loss, test_loss / math.log(2)))
        else:
            logging.info(
                '| End of training | test time: {:5.2f}s | test loss {:5.2f} | test ppl {:9.3f}'
                .format(test_elapsed, test_loss, math.exp(test_loss)))
        logging.info('=' * 100)

        summary.update({
            'test_elapsed': test_elapsed,
            'test_loss': test_loss,
        })

        if args.dataset in ['enwik8', 'text8']:
            summary['test_bits_per_character'] = test_loss / math.log(2)
        else:
            summary['test_perplexity'] = math.exp(test_loss)

    logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
    logging.info(
        f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')

    if best_val_loss:
        val_perplexity = math.exp(best_val_loss)
    else:
        val_perplexity = None

    summary.update({
        'train_throughput': meters['train_throughput'].avg,
        'train_elapsed': elapsed / 60,
        'valid_loss': best_val_loss,
        'valid_perplexity': val_perplexity,
    })
    dllogger.log(step=tuple(), data=summary)

    passed = benchmark(target_perplexity=args.target_perplexity,
                       test_perplexity=val_perplexity,
                       target_throughput=args.target_throughput,
                       test_throughput=meters['train_throughput'].avg)
    if not passed:
        sys.exit(1)
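
The inv_sqrt branch above returns a multiplier rather than a learning rate: LambdaLR scales the base lr linearly during warmup and by 1/sqrt(step) afterwards. A runnable sketch of just that schedule, under assumed values (warmup_step=4, base lr=0.1):

import torch
from torch import optim

warmup_step = 4  # assumed value, for illustration only
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.SGD([param], lr=0.1)

def lr_lambda(step):
    # Return a multiplier for the base lr, exactly as in the snippet above.
    if step == 0 and warmup_step == 0:
        return 1.
    return 1. / (step ** 0.5) if step > warmup_step else step / (warmup_step ** 1.5)

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
for step in range(1, 9):
    optimizer.step()
    scheduler.step()
    print(step, optimizer.param_groups[0]['lr'])  # rises to 0.05 at step 4, then decays
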
Code example #3
                                 tie_projs=tie_projs,
                                 pre_lnorm=default_args.pre_lnorm,
                                 tgt_len=default_args.tgt_len,
                                 ext_len=default_args.ext_len,
                                 mem_len=default_args.mem_len,
                                 cutoffs=cutoffs,
                                 same_length=default_args.same_length,
                                 attn_type=default_args.attn_type,
                                 clamp_len=default_args.clamp_len,
                                 sample_softmax=default_args.sample_softmax)
        initialization_func = partial(weights_init,
                                      init="normal",
                                      init_range=0.1,
                                      init_std=0.02,
                                      proj_init_std=0.01)
        model.apply(initialization_func)

        try:
            tr_iter = corpus.get_iterator('train',
                                          batch_size,
                                          default_args.tgt_len,
                                          device=device,
                                          ext_len=default_args.ext_len)
            para_model = parallelize_model(model, default_args)
            optimizers = build_optimizer(para_model,
                                         default_args,
                                         reload=False)
            optimizer, optimizer_sparse = optimizers
            schedulers = build_scheduler(optimizers, default_args)
            scheduler, scheduler_sparse = schedulers
            if default_args.cuda and args.fp16:
Code example #4
def main():
    args = parse_args()

    # Initialize device and distributed backend
    torch.cuda.set_device(args.local_rank)
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    args.work_dir = utils.exp_utils.build_work_dir_name(
        args.work_dir,
        args.dataset,
        args.append_dataset,
        args.append_time,
    )

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir,
                           scripts_to_save=['train.py', 'mem_transformer.py'],
                           debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = 'log.log'
    log_file = os.path.join(args.work_dir, log_file)

    if args.debug:
        log_file = os.devnull

    utils.exp_utils.setup_logging(
        log_all_ranks=args.log_all_ranks,
        filename=log_file,
    )
    logging.info(args)

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed + utils.distributed.get_rank())
    torch.manual_seed(args.seed + utils.distributed.get_rank())

    ###########################################################################
    # Load data
    ###########################################################################
    corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
    ntokens = len(corpus.vocab)
    vocab = corpus.vocab
    args.n_token = ntokens

    tr_iter = corpus.get_iterator('train',
                                  args.batch_size,
                                  args.tgt_len,
                                  device=device,
                                  ext_len=args.ext_len)
    va_iter = corpus.get_iterator('valid',
                                  args.eval_batch_size,
                                  args.eval_tgt_len,
                                  device=device,
                                  ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test',
                                  args.eval_batch_size,
                                  args.eval_tgt_len,
                                  device=device,
                                  ext_len=args.ext_len)

    # adaptive softmax / embedding
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [19997, 39997, 199997]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [59997, 99997, 639997]
            tie_projs += [False] * len(cutoffs)

    ###########################################################################
    # Build the model
    ###########################################################################
    model_config = {
        'n_token': ntokens,
        'n_layer': args.n_layer,
        'n_head': args.n_head,
        'd_model': args.d_model,
        'd_head': args.d_head,
        'd_inner': args.d_inner,
        'dropout': args.dropout,
        'dropatt': args.dropatt,
        'dtype': None,
        'tie_weight': args.tied,
        'd_embed': args.d_embed,
        'div_val': args.div_val,
        'tie_projs': tie_projs,
        'pre_lnorm': args.pre_lnorm,
        'tgt_len': args.tgt_len,
        'ext_len': args.ext_len,
        'mem_len': args.mem_len,
        'cutoffs': cutoffs,
        'same_length': args.same_length,
        'attn_type': args.attn_type,
        'clamp_len': args.clamp_len,
        'sample_softmax': args.sample_softmax,
    }

    model = MemTransformerLM(**model_config)

    model.apply(functools.partial(weights_init, args=args))
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    model.word_emb.apply(functools.partial(weights_init, args=args))

    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = sum(
        [p.nelement() for p in model.layers.parameters()])

    # optimizer
    if args.optim.lower() == 'sgd':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
            optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
        else:
            optimizer = optim.SGD(model.parameters(),
                                  lr=args.lr,
                                  momentum=args.mom)
            optimizer_sparse = None
    elif args.optim.lower() == 'adam':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
            optimizer = optim.Adam(dense_params,
                                   lr=args.lr,
                                   weight_decay=args.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(),
                                   lr=args.lr,
                                   weight_decay=args.weight_decay)
            optimizer_sparse = None
    elif args.optim.lower() == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        optimizer_sparse = None
    elif args.optim.lower() == 'lamb':
        optimizer = lamb.Lamb(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay)
        optimizer_sparse = None
    elif args.optim.lower() == 'jitlamb':
        optimizer = lamb.JITLamb(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
        optimizer_sparse = None

    model = model.to(device)

    if args.fp16:
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level='O2',
        )

    if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
        para_model = DistributedDataParallel(
            model,
            delay_allreduce=True,
        )
    elif args.multi_gpu == 'dp':
        if args.gpu0_bsz >= 0:
            para_model = BalancedDataParallel(args.gpu0_bsz //
                                              args.batch_chunk,
                                              model,
                                              dim=1).to(device)
        else:
            para_model = nn.DataParallel(model, dim=1).to(device)
    else:
        para_model = model

    # scheduler
    if args.scheduler == 'cosine':
        if args.max_step_scheduler:
            max_step = args.max_step_scheduler
        else:
            max_step = args.max_step

        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         max_step,
                                                         eta_min=args.eta_min)
        if args.sample_softmax > 0:
            scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
                optimizer_sparse, max_step, eta_min=args.eta_min)
        else:
            scheduler_sparse = None
    elif args.scheduler == 'inv_sqrt':
        # originally used for Transformer (in Attention is all you need)
        def lr_lambda(step):
            # return a multiplier instead of a learning rate
            if step == 0 and args.warmup_step == 0:
                return 1.
            else:
                return 1. / (step ** 0.5) if step > args.warmup_step \
                    else step / (args.warmup_step ** 1.5)

        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
        scheduler_sparse = None
    elif args.scheduler == 'dev_perf':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=args.decay_rate,
            patience=args.patience,
            min_lr=args.lr_min,
        )
        if args.sample_softmax > 0:
            scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer_sparse,
                factor=args.decay_rate,
                patience=args.patience,
                min_lr=args.lr_min,
            )
        else:
            scheduler_sparse = None
    elif args.scheduler == 'constant':
        pass

    logging.info('=' * 100)
    for k, v in args.__dict__.items():
        logging.info('    - {} : {}'.format(k, v))
    logging.info('=' * 100)
    logging.info('#params = {}'.format(args.n_all_param))
    logging.info('#non emb params = {}'.format(args.n_nonemb_param))

    train_step = 0
    best_val_loss = None

    if args.restart:
        checkpoint = load_checkpoint(args.restart)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        scheduler.load_state_dict(checkpoint['scheduler_state'])
        if args.fp16:
            amp.load_state_dict(checkpoint['amp_state'])
        train_step = checkpoint['train_step']
        best_val_loss = checkpoint['best_val_loss']

        model.apply(functools.partial(update_dropout, args=args))
        model.apply(functools.partial(update_dropatt, args=args))

    meters = {}
    warmup = args.mem_len // args.tgt_len + 1
    meters['train_throughput'] = AverageMeter(warmup=warmup)
    ###########################################################################
    # Train
    ###########################################################################
    # Loop over epochs.
    # At any point you can hit Ctrl + C to break out of training early.
    start_time = time.time()
    try:
        for epoch in itertools.count(start=1):
            if args.roll:
                tr_iter.roll()
            train_step, best_val_loss = train(tr_iter, va_iter, model,
                                              para_model, model_config,
                                              optimizer, optimizer_sparse,
                                              scheduler, scheduler_sparse,
                                              vocab, epoch, train_step,
                                              best_val_loss, meters, args)

            if train_step == args.max_step:
                logging.info('-' * 100)
                logging.info('End of training')
                break
    except KeyboardInterrupt:
        logging.info('-' * 100)
        logging.info('Exiting from training early')
    elapsed = time.time() - start_time

    ###########################################################################
    # Test
    ###########################################################################
    test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    if not args.debug and os.path.exists(test_path):
        # Load the best saved model.
        checkpoint = load_checkpoint(test_path)
        model.load_state_dict(checkpoint['model_state'])

        # Run on test data.
        test_start_time = time.time()
        test_loss = evaluate(te_iter, model, args)
        test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')

        logging.info('=' * 100)
        if args.dataset in ['enwik8', 'text8']:
            logging.info(
                '| End of training | test time: {:5.2f}s | test loss {:5.2f} | test bpc {:9.5f}'
                .format(time.time() - test_start_time, test_loss,
                        test_loss / math.log(2)))
        else:
            logging.info(
                '| End of training | test time: {:5.2f}s | test loss {:5.2f} | test ppl {:9.3f}'
                .format(time.time() - test_start_time, test_loss,
                        math.exp(test_loss)))
        logging.info('=' * 100)

    logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
    logging.info(
        f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')

    if best_val_loss:
        val_perplexity = math.exp(best_val_loss)
    else:
        val_perplexity = None

    passed = benchmark(target_perplexity=args.target_perplexity,
                       test_perplexity=val_perplexity,
                       target_throughput=args.target_throughput,
                       test_throughput=meters['train_throughput'].avg)
    if not passed:
        sys.exit(1)
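
When sample_softmax > 0, the snippets above split parameters by shape: anything shaped like the embedding matrix is given to SparseAdam, the rest to Adam. A minimal sketch of that split; ToyLM and its sizes are assumptions, not part of the original code:

import torch.nn as nn
from torch import optim

class ToyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.word_emb = nn.Embedding(100, 8, sparse=True)  # produces sparse gradients
        self.proj = nn.Linear(8, 16)

model = ToyLM()
dense_params, sparse_params = [], []
for param in model.parameters():
    # Same heuristic as above: compare against the embedding weight's shape.
    if param.size() == model.word_emb.weight.size():
        sparse_params.append(param)
    else:
        dense_params.append(param)

optimizer_sparse = optim.SparseAdam(sparse_params, lr=1e-3)
optimizer = optim.Adam(dense_params, lr=1e-3)
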
Code example #5
def main():
    global global_token_count, event_writer, train_step, train_loss, last_log_step, \
        best_val_loss, epoch, model

    if args.local_rank > 0:
        pass  # skip shutdown when rank is explicitly set + not zero rank
    else:
        os.system('shutdown -c')

    if not args.local:
        logger.info(
            f'Distributed initializing process group with {args.dist_backend}, {args.dist_url}, {util.get_world_size()}'
        )
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=util.get_world_size())
        assert (util.get_world_size() == dist.get_world_size())
        logger.info(
            f"Distributed: success ({args.local_rank}/{dist.get_world_size()})"
        )

    model = MemTransformerLM(ntokens,
                             args.n_layer,
                             args.n_head,
                             args.d_model,
                             args.d_head,
                             args.d_inner,
                             args.dropout,
                             args.dropatt,
                             tie_weight=args.tied,
                             d_embed=args.d_embed,
                             div_val=args.div_val,
                             tie_projs=tie_projs,
                             pre_lnorm=args.pre_lnorm,
                             tgt_len=args.tgt_len,
                             ext_len=args.ext_len,
                             mem_len=args.mem_len,
                             cutoffs=cutoffs,
                             same_length=args.same_length,
                             attn_type=args.attn_type,
                             clamp_len=args.clamp_len,
                             sample_softmax=args.sample_softmax)

    # log model info
    n_all_param = sum([p.nelement() for p in model.parameters()])
    log_tb('sizes/params', n_all_param)
    n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    log_tb('sizes/non_emb_params', n_nonemb_param)
    logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # optimizer
    if args.optim.lower() == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.mom)
    elif args.optim.lower() == 'lamb':
        optimizer = Lamb(model.parameters(), lr=args.lr, weight_decay=args.wd)
    else:
        assert args.optim.lower() == 'adam'
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.wd)

    # scheduler
    if args.scheduler == 'cosine':
        # Divide by 1e6 for numerical stability.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         args.max_tokens //
                                                         1e6,
                                                         eta_min=args.eta_min)
    elif args.scheduler == 'finder':
        scheduler = LRFinder(optimizer,
                             args.max_tokens,
                             init_value=args.lr / 1e3)
    elif args.scheduler == 'constant':
        pass

    model.apply(weights_init)
    model.word_emb.apply(
        weights_init
    )  # ensure embedding init is not overridden by out_layer in case of weight sharing

    if args.checkpoint:
        if global_rank == 0:
            util.restore_from_checkpoint(model=model,
                                         checkpoint_fn=args.checkpoint)

    model = model.to(device)
    if args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={'init_scale': 2**16},
                                   verbose=False)

    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank)  #, find_unused_parameters=True)

    if global_rank == 0:
        event_writer = SummaryWriter(args.logdir)

    event_writer.add_text('args', str(args))

    # test checkpoint writing
    if args.checkpoint_each_epoch:
        logger.info(f'Saving checkpoint for epoch {epoch}')
        util.dist_save_checkpoint(model, optimizer, args.logdir, suffix=f'{0}')

    # Loop over epochs.
    train_step = 0
    train_loss = 0
    last_log_step = 0
    best_val_loss = None
    va_iter, te_iter = [
        corpus.get_dist_iterator(split,
                                 global_rank,
                                 max_rank,
                                 args.batch_size * 2,
                                 args.tgt_len,
                                 device=device,
                                 ext_len=args.ext_len)
        for split in ('valid', 'test')
    ]

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=1):
            train(va_iter, optimizer, scheduler)
    except KeyboardInterrupt:
        logger.info('-' * 100)
        logger.info('Exiting from training early')
    except StopIteration:
        pass

    # Eval one more time.
    evaluate_and_log(optimizer, va_iter, 'val', train_step=-1)

    # Load the best saved model.
    logger.info("Loading best checkpoint")
    model_file = os.path.join(args.logdir, 'model-best.pt')
    if os.path.exists(model_file):
        with open(model_file, 'rb') as model_f:
            with timeit('load'):
                if args.local:
                    model = torch.load(model_f)
                else:
                    model = torch.load(model_f,
                                       map_location=lambda storage, loc:
                                       storage.cuda(args.local_rank))
                    model = DistributedDataParallel(
                        model,
                        device_ids=[args.local_rank],
                        output_device=args.local_rank)
    else:
        logger.warn('no model file, using current model for loss')

    # Run on test data.
    evaluate_and_log(optimizer, te_iter, 'test', -1)
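
The example above logs parameter counts and args through log_tb and a SummaryWriter. A short sketch of the TensorBoard calls this presumably wraps; the run directory and values are placeholders:

from torch.utils.tensorboard import SummaryWriter  # requires the tensorboard package

event_writer = SummaryWriter('runs/example')        # stands in for args.logdir
event_writer.add_text('args', str({'lr': 1e-3}))    # mirrors event_writer.add_text above
event_writer.add_scalar('sizes/params', 1_000_000)  # what log_tb presumably does
event_writer.close()
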
Code example #6
    sample_softmax=args.sample_softmax,
)
if args.restart:
    print(f'Restarting training from {args.restart_dir}')
    with open(os.path.join(args.restart_dir, 'model.pt'), 'rb') as f:
        state = torch.load(f)
    if isinstance(state, MemTransformerLM):  # old format
        model = state
    else:
        model_params = state['model_params']
        model = MemTransformerLM(**model_params)
        model.load_state_dict(state['state_dict'])
        del state
    if not args.fp16:
        model = model.float()
    model.apply(update_dropout)
    model.apply(update_dropatt)
else:
    model = MemTransformerLM(**model_params)
    model.apply(weights_init)
    # ensure embedding init is not overridden by out_layer
    # in case of weight sharing
    model.word_emb.apply(weights_init)
args.n_all_param = sum([p.nelement() for p in model.parameters()])
args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])

if args.fp16:
    model = model.half()

if args.multi_gpu:
    model = model.to(device)
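
The restart branch above accepts two checkpoint layouts: an old format in which the whole module was pickled, and a newer one that stores constructor kwargs plus a state dict. A minimal sketch of that pattern with a plain nn.Linear standing in for MemTransformerLM:

import torch
import torch.nn as nn

model_params = {'in_features': 8, 'out_features': 4}  # illustrative kwargs
model = nn.Linear(**model_params)

# New-style checkpoint: constructor kwargs plus weights.
torch.save({'model_params': model_params,
            'state_dict': model.state_dict()}, 'model.pt')

# Reload, handling both layouts as in the snippet above.
state = torch.load('model.pt')
if isinstance(state, nn.Module):  # old format: the whole pickled module
    model = state
else:                             # new format: rebuild, then load weights
    model = nn.Linear(**state['model_params'])
    model.load_state_dict(state['state_dict'])
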
Code example #7
File: train.py  Project: manikbhandari/ner
                             args.dropatt,
                             tie_weight=args.tied,
                             d_embed=args.d_embed,
                             div_val=args.div_val,
                             tie_projs=tie_projs,
                             pre_lnorm=args.pre_lnorm,
                             tgt_len=args.tgt_len,
                             ext_len=args.ext_len,
                             mem_len=args.mem_len,
                             cutoffs=cutoffs,
                             same_length=args.same_length,
                             attn_type=args.attn_type,
                             clamp_len=args.clamp_len,
                             sample_softmax=args.sample_softmax)

    model.apply(weights_init)  # TransformerLM weights will get initialized.
    model.word_emb.apply(
        weights_init
    )  # TODO: Not able to understand this. ensure embedding init is not overridden by out_layer in case of weight sharing

args.n_all_param = sum([p.nelement() for p in model.parameters()
                        ])  # total number of parameters
args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()
                           ])  # non embedding parameters

if args.fp16: model = model.half()

if args.multi_gpu:
    model = model.to(device)
    if args.gpu0_bsz >= 0:
        para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
Code example #8
def train_ts(args):
    def build_scheduler(optimizers, args):
        optimizer, optimizer_sparse = optimizers
        scheduler_sparse = None

        if args.scheduler == "cosine":
            # here we do not set eta_min to lr_min to be backward compatible
            # because in previous versions eta_min is default to 0
            # rather than the default value of lr_min 1e-6
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer, args.max_step,
                eta_min=args.eta_min)  # should use eta_min arg

        elif args.scheduler == "inv_sqrt":
            # originally used for Transformer (in Attention is all you need)
            def lr_lambda(step):
                # return a multiplier instead of a learning rate
                if step == 0 and args.warmup_step == 0:
                    return 1.0
                else:
                    return (1.0 /
                            (step**0.5) if step > args.warmup_step else step /
                            (args.warmup_step**1.5))

            scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                    lr_lambda=lr_lambda)

        elif args.scheduler == "dev_perf":
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                factor=args.decay_rate,
                patience=args.patience,
                min_lr=args.lr_min,
            )

        elif args.scheduler == "constant":
            pass

        else:
            raise ValueError(f"scheduler type {args.scheduler} not recognized")

        return scheduler, scheduler_sparse

    ###############################################################################
    # Training code
    ###############################################################################
    def evaluate(eval_iter, model):

        # Turn on evaluation mode which disables dropout.
        model.eval()

        # debug
        # If the model does not use memory at all, make the ext_len longer.
        # Otherwise, make the mem_len longer and keep the ext_len the same.
        # if default_args.mem_len == 0:
        #     model.reset_length(default_args.eval_tgt_len,
        #                        default_args.ext_len + default_args.tgt_len -
        #                        default_args.eval_tgt_len, default_args.mem_len)
        # else:
        #     model.reset_length(default_args.eval_tgt_len,
        #                        default_args.ext_len, default_args.mem_len +
        #                       default_args.tgt_len - default_args.eval_tgt_len)

        # Evaluation
        total_len, total_loss = 0, 0.0
        with torch.no_grad():
            mems = tuple()
            for i, (data, target, seq_len) in enumerate(eval_iter):
                if i >= args.max_eval_steps > 0:
                    break
                ret = model(data, target, *mems)
                loss, mems = ret[0], ret[1:]
                loss = loss.mean()
                total_loss += seq_len * loss.float().item()
                total_len += seq_len

        # Switch back to the training mode
        # model.reset_length(default_args.tgt_len, default_args.ext_len,
        # default_args.mem_len)
        model.train()

        return total_loss / total_len

    # reverse distillation util
    def get_original_batches(model, tr_iter, integration_length):
        model.eval()
        if args.batch_chunk > 1:
            mems = [None for _ in range(args.batch_chunk)]
            first_logits = [[] for _ in range(args.batch_chunk)]
        else:
            mems = None
            first_logits = []
        train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
        with torch.no_grad():
            for batch, (data, target, seq_len) in enumerate(train_iter):
                if batch == integration_length:
                    break
                if args.batch_chunk > 1:
                    data_chunks = torch.chunk(data, args.batch_chunk, 1)
                    for i in range(args.batch_chunk):
                        data_i = data_chunks[i].contiguous()
                        logits, mems[i] = model._forward(data_i, mems=mems[i])
                        first_logits[i].append(logits.cpu())
                else:
                    logits, mems = model._forward(data, mems=mems)
                    first_logits.append(logits.cpu())
        return first_logits

    def build_optimizer(model, args, reload=False):
        optimizer_sparse = None
        if args.optim.lower() == "sgd":
            optimizer = optim.SGD(model.parameters(),
                                  lr=args.lr,
                                  momentum=args.mom)
        elif args.optim.lower() == "adam":
            optimizer = optim.Adam(model.parameters(), lr=args.lr)
        elif args.optim.lower() == "adagrad":
            optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        else:
            raise ValueError(f"optimizer type {args.optim} not recognized")

        if reload:
            if args.restart_from is not None:
                optim_name = f"optimizer_{args.restart_from}.pt"
            else:
                optim_name = "optimizer.pt"
            optim_file_name = os.path.join(args.restart_dir, optim_name)
            logging(f"reloading {optim_file_name}")
            if os.path.exists(os.path.join(args.restart_dir, optim_name)):
                with open(os.path.join(args.restart_dir, optim_name),
                          "rb") as optim_file:
                    opt_state_dict = torch.load(optim_file)
                    try:
                        optimizer.load_state_dict(opt_state_dict)
                    # in case the optimizer param groups aren't the same shape,
                    # merge them
                    except:
                        logging("merging optimizer param groups")
                        opt_state_dict["param_groups"][0]["params"] = [
                            param
                            for param_group in opt_state_dict["param_groups"]
                            for param in param_group["params"]
                        ]
                        opt_state_dict["param_groups"] = [
                            opt_state_dict["param_groups"][0]
                        ]
                        optimizer.load_state_dict(opt_state_dict)
            else:
                logging("Optimizer was not saved. Start from scratch.")

        return optimizer, optimizer_sparse

    def log_val(val_loss, step, compute):
        logging("-" * 100)
        log_str = ("| Eval {:3d} at step {:>8d} | time: {:5.2f}s "
                   "| valid loss {:5.2f}".format(
                       step // args.eval_interval,
                       step,
                       (time.time() - eval_start_time),
                       val_loss,
                   ))
        log_str += " | bpc {:9.5f}".format(val_loss / math.log(2))
        logging(log_str)
        logging("-" * 100)

    def epoch_loop(
        epoch,
        model,
        optimizers,
        schedulers,
    ):
        nonlocal train_step

        # Turn on training mode which enables dropout.
        if isinstance(model, nn.DataParallel):
            parent_model = model.module
        else:
            parent_model = model
        optimizer, optimizer_sparse = optimizers
        scheduler, scheduler_sparse = schedulers

        # global train_step, best_val_loss, eval_start_time, log_start_time
        train_losses = []
        model.train()
        if args.batch_chunk > 1:
            mems = [tuple() for _ in range(args.batch_chunk)]
        else:
            mems = tuple()
        train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter

        log_start_time = time.time()
        best_val_loss = float("Infinity")
        for batch, (data, target, seq_len) in enumerate(train_iter):
            model.zero_grad()
            if args.batch_chunk > 1:
                data_chunks = torch.chunk(data, args.batch_chunk, 1)
                target_chunks = torch.chunk(target, args.batch_chunk, 1)
                for i in range(args.batch_chunk):
                    data_i = data_chunks[i].contiguous()
                    target_i = target_chunks[i].contiguous()
                    ret = model(data_i, target_i, *mems[i])
                    loss, mems[i] = ret[0], ret[1:]
                    loss = loss.float().mean().type_as(loss) / args.batch_chunk
                    if args.fp16:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    train_losses.append(loss.float().item())
            else:
                ret = model(data, target, *mems)
                loss, mems = ret[0], ret[1:]
                loss = loss.float().mean().type_as(loss)
                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                train_losses.append(loss.float().item())

            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.clip)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            optimizer.step()
            parent_model.compute += openai_compute(
                non_emb_param_count(parent_model, nseries), data.numel(), 1)

            # step-wise learning rate annealing
            train_step += 1
            parent_model.training_steps += 1

            # check for yet-to-thaw parameters
            if getattr(parent_model, "freeze_countdown", 0) > 0:
                parent_model.freeze_countdown -= 1

                # if this is the last step
                if parent_model.freeze_countdown == 0:
                    for parameter in parent_model.parameters():
                        parameter.requires_grad = True
                    logging("thawing all parameters")

            if args.scheduler in ["cosine", "constant", "dev_perf"]:
                # linear warmup stage
                if train_step < args.warmup_step:
                    curr_lr = args.lr * train_step / args.warmup_step
                    optimizer.param_groups[0]["lr"] = curr_lr
                else:
                    if args.scheduler == "cosine":
                        scheduler.step(train_step)
            elif args.scheduler == "inv_sqrt":
                scheduler.step(train_step)

            if train_step % args.log_interval == 0:
                cur_loss = np.mean(train_losses)
                elapsed = time.time() - log_start_time
                log_str = ("| epoch {:3d} step {:>8d} "
                           "| {:>6d} batches "
                           "| lr {:.3g} "
                           "| ms/batch {:5.2f} "
                           "| loss {:5.2f}".format(
                               epoch,
                               train_step,
                               batch + 1,
                               optimizer.param_groups[0]["lr"],
                               elapsed * 1000 / args.log_interval,
                               cur_loss,
                           ))

                log_str += " | bpc {:9.5f}".format(cur_loss / math.log(2))
                logging(log_str)

                train_losses = []
                log_start_time = time.time()

            if train_step % args.eval_interval == 0:
                val_loss = evaluate(va_iter, model)
                log_val(val_loss,
                        step=train_step,
                        compute=parent_model.compute)
                # Save the model if the validation loss is the best we've seen so
                # far.
                if not best_val_loss or val_loss < best_val_loss:
                    best_val_loss = val_loss
                    if not args.debug:
                        if args.fp16:
                            with open(
                                    os.path.join(args.work_dir,
                                                 "amp_checkpoint.pt"),
                                    "wb",
                            ) as f:
                                checkpoint = {
                                    "model": model.state_dict(),
                                    "optimizer": optimizer.state_dict(),
                                    "amp": amp.state_dict(),
                                }
                                torch.save(checkpoint, f)
                        else:
                            with open(os.path.join(args.work_dir, "model.pt"),
                                      "wb") as f:
                                torch.save(parent_model, f)
                            with open(
                                    os.path.join(args.work_dir,
                                                 "optimizer.pt"),
                                    "wb",
                            ) as f:
                                torch.save(optimizer.state_dict(), f)

                # dev-performance based learning rate annealing
                if args.scheduler == "dev_perf":
                    scheduler.step(val_loss)

                eval_start_time = time.time()

            if train_step == args.max_step:
                break

    def expand_model(
        strategy,
        integration,
        integration_length,
        n_add,
        model: MemTransformerLM,
        optimizers,
        schedulers,
        tr_iter,
        va_iter,
        epoch,
        step,
    ):
        optimizer, _ = optimizers
        scheduler, _ = schedulers
        if integration:
            if not integration_length or integration_length <= 0:
                warnings.warn(
                    f"integration {integration} passed but integration_length is {integration_length}"
                )
            else:
                logging(
                    f"applying integration strategy {integration} with integration length {integration_length}"
                )

        # pre-expansion validation
        logging(f"evaluating before expanding")
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, step=step, compute=model.compute)

        # infer example logits for reverse distillation
        if "reverse_distil" in integration:
            first_logits = get_original_batches(model, tr_iter,
                                                integration_length)

        # expansion
        logging(
            f"adding {n_add} layers before starting epoch {epoch} with method {strategy}"
        )
        new_layers = model.expand_layers(n_add,
                                         strategy=strategy,
                                         function=initialization_func)

        # optimizer update
        optimizer.add_param_group({
            "params": new_layers.parameters(),
            "lr": optimizer.param_groups[0]["lr"],
            "initial_lr": optimizer.param_groups[0]["initial_lr"],
        })
        scheduler.base_lrs.append(optimizer.param_groups[-1]["initial_lr"])
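        # the scheduler keeps one base LR per param group, so the group appended
        # above also needs its own entry in scheduler.base_lrs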

        # training loop for reverse distillation
        if "reverse_distil" in integration:
            fit_to_previous_model(model, new_layers, tr_iter, first_logits,
                                  integration)

        # freeze the pre-existing parameters for a "frozen" restart; this is done
        # after the optimizer update, otherwise the new layers would also end up
        # copied without gradients
        if "freeze" in integration and integration_length > 0:
            for param_group in optimizer.param_groups[:-1]:
                for parameter in param_group["params"]:
                    parameter.requires_grad = False
            model.freeze_countdown = integration_length

        # post-expansion validation
        logging(f"reevaluating")
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, step=step, compute=model.compute)

    def expand_state(param, state):
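        # Tile a stored optimizer state tensor to the shape of a widened
        # parameter. With hypothetical shapes: a widened weight of (2048, 512)
        # and a saved exp_avg of (1024, 512) give ratios == [2, 1], so the state
        # is repeated twice along dim 0. The integer division assumes every new
        # dimension is an exact multiple of the old one.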
        if param.shape != state.shape:
            ratios = [
                param.shape[i] // state.shape[i]
                for i in range(len(param.shape))
            ]
            return state.repeat(*ratios)
        else:
            return state

    def widen_model(
        strategy,
        ratio,
        model: MemTransformerLM,
        optimizers,
        va_iter,
        epoch,
        step,
    ):
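        # Widen the model by adding attention heads (model.add_heads) and tile
        # the Adam moment buffers so that their shapes match the widened
        # parameters; validation runs before and after for comparison.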
        optimizer, _ = optimizers

        # pre-widening validation
        logging("evaluating before widening")
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, compute=model.compute, step=step)

        # widening
        logging(
            f"adding heads (ratio {ratio}) before starting epoch {epoch} with method {strategy}"
        )
        model.add_heads(ratio, strategy=strategy, function=initialization_func)

        # optimizer update
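        # exp_avg / exp_avg_sq are Adam's first and second moment buffers; if
        # they are left at their old shapes, the next optimizer.step() would
        # fail with a shape mismatch on the widened parameters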
        for param, states in optimizer.state.items():
            if isinstance(param, nn.Parameter):
                states["exp_avg"] = expand_state(param, states["exp_avg"])
                states["exp_avg_sq"] = expand_state(param,
                                                    states["exp_avg_sq"])

        # post-widening validation
        logging("reevaluating")
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, step=step, compute=model.compute)

    # reverse distillation trainer
    def fit_to_previous_model(model, new_layers, tr_iter, first_logits,
                              integration):
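        # Reverse distillation: train the expanded model to reproduce the logits
        # the pre-expansion model produced on the same batches (first_logits,
        # collected by get_original_batches), using an MSE loss. With a
        # "partial" integration only the newly added layers get their own
        # optimizer; otherwise the whole model is optimized.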
        mse_loss = torch.nn.MSELoss()
        if "partial" in integration:
            distil_optimizer, distil_optimizer_sparse = build_optimizer(
                new_layers, reload=False)
        else:
            distil_optimizer, distil_optimizer_sparse = build_optimizer(
                model, reload=False)
        if args.cuda and args.fp16:
            model, distil_optimizer = amp.initialize(model,
                                                     distil_optimizer,
                                                     opt_level=args.fp16)

        model.train()
        if args.batch_chunk > 1:
            mems = [None for _ in range(args.batch_chunk)]
        else:
            mems = None
        train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
        for batch, (data, _, _) in enumerate(train_iter):
            if batch == len(first_logits):
                break
            model.zero_grad()
            if args.batch_chunk > 1:
                data_chunks = torch.chunk(data, args.batch_chunk, 1)
                for i in range(args.batch_chunk):
                    data_i = data_chunks[i].contiguous()
                    logits, mems[i] = model._forward(data_i, mems=mems[i])
                    target_logits = first_logits[i][batch].to(logits.device)
                    loss = mse_loss(logits, target_logits) / args.batch_chunk
                    if args.fp16:
                        with amp.scale_loss(loss,
                                            distil_optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
            else:
                logits, mems = model._forward(data, mems=mems)
                target_logits = first_logits[batch].to(logits.device)
                loss = mse_loss(logits, target_logits)
                if args.fp16:
                    with amp.scale_loss(loss, distil_optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

            if args.fp16:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(distil_optimizer), args.clip)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            distil_optimizer.step()

    ###################################################################################
    #
    # main()
    #
    args.tied = not args.not_tied

    if args.d_embed < 0:
        args.d_embed = args.d_model

    # Validate `--fp16` option
    if args.fp16:
        if not args.cuda:
            print("WARNING: --fp16 requires --cuda, ignoring --fp16 option")
            args.fp16 = False
        else:
            try:
                from apex import amp

                if args.fp16 == "O1":
                    amp.register_half_function(torch, "einsum")
            except ImportError:
                print("WARNING: apex not installed, ignoring --fp16 option")
                args.fp16 = False

    device = torch.device("cuda" if args.cuda else "cpu")

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        if not args.cuda:
            print(
                "WARNING: You have a CUDA device, so you should probably run "
                "with --cuda ")
        else:
            torch.cuda.manual_seed_all(args.seed)

    ############################################################################
    # Logging
    ############################################################################

    assert args.ext_len >= 0, "extended context length must be non-negative"
    assert args.d_batch % args.batch_chunk == 0
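    # batch_chunk splits each batch into chunks whose gradients are accumulated
    # before the optimizer step, so the batch dimension must divide evenly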

    args.work_dir = "{}-{}".format(args.work_dir, args.dataset)
    args.work_dir = os.path.join(args.work_dir, time.strftime("%Y%m%d-%H%M%S"))
    logging = create_exp_dir(
        args.work_dir,
        scripts_to_save=["train_ts.py", "mem_transformer.py"],
        debug=args.debug,
    )

    ############################################################################
    # Load data
    ############################################################################
    time_series = get_time_series(args.datadir, args.dataset)
    nseries = len(time_series.vocab)
    args.n_token = nseries

    eval_batch_size = 20
    tr_iter = time_series.get_iterator(
        "train",
        args.d_batch,
        args.tgt_len,
        device=device,
        ext_len=args.ext_len,
    )
    va_iter = time_series.get_iterator(
        "valid",
        eval_batch_size,
        args.eval_tgt_len,
        device=device,
        ext_len=args.ext_len,
    )
    te_iter = time_series.get_iterator(
        "test",
        eval_batch_size,
        args.eval_tgt_len,
        device=device,
        ext_len=args.ext_len,
    )

    cutoffs, tie_projs = [], [False]
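    # empty cutoffs mean the adaptive embedding/softmax degenerates to a single
    # cluster (no vocabulary splitting), which presumably suits the small
    # time-series vocabulary; tie_projs then only holds the entry for that head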

    ############################################################################
    # Define model
    ############################################################################

    initialization_func = partial(
        weights_init,
        init=args.init,
        init_range=args.init_range,
        init_std=args.init_std,
        proj_init_std=args.proj_init_std,
    )
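    # binding the init hyperparameters once lets the same initializer be reused
    # when layers or heads are added later (see expand_model / widen_model
    # above), not only for the initial model.apply below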

    if args.restart and not args.fp16:
        if args.restart_from is not None:
            model_name = f"model_{args.restart_from}.pt"
        else:
            model_name = "model.pt"
        model_file_name = os.path.join(args.restart_dir, model_name)
        logging(f"reloading {model_file_name}")
        with open(model_file_name, "rb") as f:
            model = torch.load(f)
        # backwards compatibility with older saves
        if isinstance(model, nn.DataParallel):
            model = model.module
        model.backward_compatible(tie_weight=args.tied, tie_projs=tie_projs)
        model = model.float()
        model.apply(update_dropout)
        model.apply(update_dropatt)

    else:
        model = MemTransformerLM(
            nseries,
            args.n_layer,
            args.n_head,
            args.d_model,
            args.d_head,
            args.d_inner,
            args.dropout,
            args.dropatt,
            tie_weight=args.tied,
            d_embed=args.d_embed,
            div_val=args.div_val,
            tie_projs=tie_projs,
            pre_lnorm=args.pre_lnorm,
            tgt_len=args.tgt_len,
            ext_len=args.ext_len,
            mem_len=args.mem_len,
            cutoffs=cutoffs,
            same_length=args.same_length,
            clamp_len=args.clamp_len,
        )
        model.apply(initialization_func)

        # debug
        # model.word_emb.apply(initialization_func)
        # ensure embedding init is not overridden by out_layer in case of
        # weight sharing
    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = non_emb_param_count(model, nseries)

    logging("=" * 100)
    for k, v in args.__dict__.items():
        logging("    - {} : {}".format(k, v))
    logging("=" * 100)
    logging("#params = {}".format(args.n_all_param))
    logging("#non emb params = {}".format(args.n_nonemb_param))

    para_model = parallelize_model(model, args)
    optimizers = build_optimizer(para_model,
                                 args,
                                 reload=args.restart and not args.fp16)
    optimizer, optimizer_sparse = optimizers
    schedulers = build_scheduler(optimizers, args)
    scheduler, scheduler_sparse = schedulers

    if args.cuda and args.fp16:
        para_model, optimizer = amp.initialize(para_model,
                                               optimizer,
                                               opt_level=args.fp16)

        if args.restart:
            if args.restart_from is not None:
                checkpoint_name = f"amp_checkpoint_{args.restart_from}.pt"
            else:
                checkpoint_name = "amp_checkpoint.pt"
            with open(os.path.join(args.work_dir, checkpoint_name), "rb") as f:
                checkpoint = torch.load(f)
                model.load_state_dict(checkpoint["model"])
                optimizer.load_state_dict(checkpoint["optimizer"])
                amp.load_state_dict(checkpoint["amp"])

    ############################################################################
    # Training loop
    ############################################################################

    # Loop over epochs.
    if args.reset_lr:
        # discard the reloaded LR schedule: restart train_step so that only the
        # new schedule (starting from args.lr) is applied from here on
        train_step = 0
        optimizer.defaults["lr"] = args.lr
        for param_group in optimizer.param_groups:
            param_group["lr"] = args.lr
            param_group["initial_lr"] = args.lr
        scheduler.base_lrs = [args.lr] * len(scheduler.base_lrs)
    else:
        train_step = model.training_steps

    best_val_loss = None

    # Reload the previous step number in case of args.restart
    if train_step > 0:
        logging(f"restarting from step {train_step}")

    log_start_time = time.time()
    eval_start_time = time.time()

    def run_training():
        nonlocal train_step

        for epoch in itertools.count(start=first_epoch):
            # we check before the training loop; expanding at epoch 0 means
            # before training (for debug purposes)
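            # e.g. a hypothetical --expansion_dict of {"0": 2, "3": 1} would add
            # two layers before any training and one more after epoch 3 finishes
            # (keys are epoch indices as strings, values are layer counts)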
            if args.expand and str(epoch - 1) in args.expansion_dict:
                n_add = int(args.expansion_dict[str(epoch - 1)])
                expand_model(
                    args.expand,
                    args.integration,
                    args.integration_length,
                    n_add,
                    model,
                    optimizers,
                    schedulers,
                    tr_iter,
                    va_iter,
                    epoch,
                    train_step,
                )
            if args.widen and str(epoch - 1) in args.widen_dict:
                ratio = int(args.widen_dict[str(epoch - 1)])
                widen_model(
                    args.widen,
                    ratio,
                    model,
                    optimizers,
                    va_iter,
                    epoch,
                    train_step,
                )
            epoch_loop(epoch, para_model, optimizers, schedulers)
            if train_step >= args.max_step:
                logging("-" * 100)
                logging("End of training")
                break
            if not args.debug and args.log_first_epochs:
                if epoch <= args.log_first_epochs:
                    logging(f"saving model at the end of epoch {epoch}")
                    if args.fp16:
                        with open(
                                os.path.join(args.work_dir,
                                             f"amp_checkpoint_{epoch}.pt"),
                                "wb",
                        ) as f:
                            checkpoint = {
                                "model": model.state_dict(),
                                "optimizer": optimizer.state_dict(),
                                "amp": amp.state_dict(),
                            }
                            torch.save(checkpoint, f)
                    else:
                        with open(
                                os.path.join(args.work_dir,
                                             f"model_{epoch}.pt"),
                                "wb",
                        ) as f:
                            torch.save(model, f)
                        with open(
                                os.path.join(args.work_dir,
                                             f"optimizer_{epoch}.pt"),
                                "wb",
                        ) as f:
                            torch.save(optimizer.state_dict(), f)

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        if args.restart_from:
            first_epoch = args.restart_from + 1
            print(f"restarting from epoch {first_epoch}")
        else:
            first_epoch = 1
        run_training()

    except KeyboardInterrupt:
        logging("-" * 100)
        logging("Exiting from training early")

    # Load the best model.
    if args.fp16:
        with open(os.path.join(args.work_dir, "amp_checkpoint.pt"), "rb") as f:
            checkpoint = torch.load(f)
            model.load_state_dict(checkpoint["model"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            amp.load_state_dict(checkpoint["amp"])
    else:
        with open(os.path.join(args.work_dir, "model.pt"), "rb") as f:
            model = torch.load(f)
    para_model = model.to(device)

    # Run on test data.
    test_loss = evaluate(te_iter, para_model)
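    # the test loss is a mean negative log-likelihood in nats; dividing by
    # math.log(2) below converts it to bits (reported as "bpc")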
    logging("=" * 100)
    logging("| End of training | test loss {:5.2f} | test bpc {:9.5f}".format(
        test_loss, test_loss / math.log(2)))
    logging("=" * 100)