Code example #1
File: train.py  Project: wangkanger/transformer-xl
def evaluate(eval_iter, split, train_step=-1):
    global best_val_loss
    eval_start_time = time.time()
    # Turn on evaluation mode which disables dropout.
    model.eval()

    # Have to unwrap twice: DDP & FP16
    model_to_reset = model.module.module if args.fp16 else model.module
    # If the model does not use memory at all, make the ext_len longer.
    # Otherwise, make the mem_len longer and keep the ext_len the same.
    if args.mem_len == 0:
        model_to_reset.reset_length(
            args.eval_tgt_len, args.ext_len + args.tgt_len - args.eval_tgt_len,
            args.mem_len)
    else:
        model_to_reset.reset_length(
            args.eval_tgt_len, args.ext_len,
            args.mem_len + args.tgt_len - args.eval_tgt_len)

    # Evaluation
    total_len, total_loss = 0, 0.
    with torch.no_grad():
        mems = tuple()
        bar = tqdm.tqdm(eval_iter, leave=False, desc="Eval")
        for i, (data, target, seq_len) in enumerate(bar):
            if args.max_eval_steps > 0:
                if i >= args.max_eval_steps:
                    break
            ret = model(data, target, *mems)
            loss, mems = ret[0], ret[1:]
            loss = loss.mean()
            bar.set_description(f'Eval loss {loss:.2f}')
            total_loss += seq_len * loss.float().item()
            total_len += seq_len

    # Switch back to the training mode
    model_to_reset.reset_length(args.tgt_len, args.ext_len, args.mem_len)
    model.train()

    # Log all the things.
    mean_loss = total_loss / total_len
    logger.info('-' * 100)
    log_str = (
        f'| Eval {train_step // args.eval_interval:3d} at step {train_step:>8d} | '
        + f'time: {time.time() - eval_start_time:5.2f}s ' +
        f'| {split} loss {mean_loss:5.2f}')
    if args.dataset in ['enwik8', 'text8']:
        log_str += f' | bpc {mean_loss / math.log(2):9.5f}'
    else:
        log_str += f' | {split} ppl {math.exp(mean_loss):9.3f}'
    logger.info(log_str)
    logger.info('-' * 100)
    log_tb(f'learning/{split}_loss', mean_loss)
    log_tb(f'learning/{split}_ppl', math.exp(mean_loss))

    # Update checkpoint if validation loss improved.
    if split == 'val' and (not best_val_loss or mean_loss < best_val_loss):
        logger.info('Saving checkpoint for new best loss')
        util.dist_save_checkpoint(model, optimizer, args.logdir, suffix='best')
        best_val_loss = mean_loss
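
The two reset_length branches above keep the total attention span (tgt_len + ext_len + mem_len) constant while switching to the shorter eval_tgt_len. Below is a minimal sketch of that arithmetic; eval_lengths is an illustrative helper name, not part of the project:

def eval_lengths(tgt_len, ext_len, mem_len, eval_tgt_len):
    """Return (tgt, ext, mem) lengths for evaluation so that the total
    attention span tgt + ext + mem matches the one used in training."""
    shortfall = tgt_len - eval_tgt_len
    if mem_len == 0:
        # no memory: push the shortfall into the extended context
        return eval_tgt_len, ext_len + shortfall, mem_len
    # otherwise extend the memory and keep ext_len unchanged
    return eval_tgt_len, ext_len, mem_len + shortfall

# Example: tgt_len=150, ext_len=0, mem_len=150, eval_tgt_len=64
# -> (64, 0, 236), and 64 + 0 + 236 == 150 + 0 + 150
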
Code example #2
def evaluate_and_log(optimizer, eval_iter, split, train_step=-1):
    global best_val_loss
    eval_start_time = time.time()

    # Have to unwrap DDP & FP16, if using.
    def unwrap(module):
        if isinstance(module, MemTransformerLM):
            return module
        return unwrap(module.module)

    model_to_reset = unwrap(model)
    # If the model does not use memory at all, make the ext_len longer.
    # Otherwise, make the mem_len longer and keep the ext_len the same.
    if args.mem_len == 0:
        model_to_reset.reset_length(
            args.eval_tgt_len, args.ext_len + args.tgt_len - args.eval_tgt_len,
            args.mem_len)
    else:
        model_to_reset.reset_length(
            args.eval_tgt_len, args.ext_len,
            args.mem_len + args.tgt_len - args.eval_tgt_len)

    total_loss, total_len = evaluate(model, eval_iter, split,
                                     args.max_eval_steps)

    # Switch back to the training mode
    model_to_reset.reset_length(args.tgt_len, args.ext_len, args.mem_len)
    model.train()

    # Log all the things.
    mean_loss = total_loss / total_len
    logger.info('-' * 100)
    log_str = (
        f'| Eval {train_step // args.eval_interval:3d} at step {train_step:>8d} | '
        + f'time: {time.time() - eval_start_time:5.2f}s ' +
        f'| {split} loss {mean_loss:5.2f}')
    if args.dataset in ['enwik8', 'text8']:
        log_str += f' | bpc {mean_loss / math.log(2):9.5f}'
    else:
        log_str += f' | {split} ppl {math.exp(mean_loss):9.3f}'
    logger.info(log_str)
    logger.info('-' * 100)
    log_tb(f'learning/{split}_loss', mean_loss)
    log_tb(f'learning/{split}_ppl', math.exp(mean_loss))

    # Update checkpoint if validation loss improved.
    if split == 'val' and (not best_val_loss or mean_loss < best_val_loss):
        logger.info('Saving checkpoint for new best loss')
        util.dist_save_checkpoint(model, optimizer, args.logdir, suffix='best')
        best_val_loss = mean_loss
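
Unlike example #1, this variant delegates the evaluation loop to a separate evaluate(model, eval_iter, split, max_eval_steps) helper that is not shown on this page. The sketch below is a hedged reconstruction following the inline loop from example #1; the project's actual helper may differ:

import torch
import tqdm

def evaluate(model, eval_iter, split, max_eval_steps=0):
    """Run the model over eval_iter and return (loss sum weighted by
    sequence length, total number of predicted positions)."""
    model.eval()  # disables dropout
    total_len, total_loss = 0, 0.
    with torch.no_grad():
        mems = tuple()
        bar = tqdm.tqdm(eval_iter, leave=False, desc=f"Eval {split}")
        for i, (data, target, seq_len) in enumerate(bar):
            if 0 < max_eval_steps <= i:
                break
            ret = model(data, target, *mems)
            loss, mems = ret[0], ret[1:]
            loss = loss.mean()
            bar.set_description(f'Eval loss {loss:.2f}')
            total_loss += seq_len * loss.float().item()
            total_len += seq_len
    return total_loss, total_len
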
Code example #3
def main_loop():
    util.cancel_shutdown()
    losses = []

    args = g.args

    if not args.local:
        g.logger.info(
            f'Distributed initializing process group with '
            f'{args.dist_backend}, {args.dist_url}, {util.get_world_size()}')
        dist.init_process_group(
            backend=args.dist_backend,
            #init_method=args.dist_url,
            #world_size=util.get_world_size()
        )
        assert (util.get_world_size() == dist.get_world_size())
        g.logger.info(
            f"Distributed: success ({args.local_rank}/{dist.get_world_size()})"
        )

    g.logger.info("creating new model")
    g.state = TrainState(args)
    g.state.model = MemTransformerLM(g.ntokens,
                                     args.n_layer,
                                     args.n_head,
                                     args.d_model,
                                     args.d_head,
                                     args.d_inner,
                                     args.dropout,
                                     args.dropatt,
                                     tie_weight=args.tied,
                                     d_embed=args.d_embed,
                                     div_val=args.div_val,
                                     tie_projs=g.tie_projs,
                                     pre_lnorm=args.pre_lnorm,
                                     tgt_len=args.tgt_len,
                                     ext_len=args.ext_len,
                                     mem_len=args.mem_len,
                                     cutoffs=g.cutoffs,
                                     same_length=args.same_length,
                                     attn_type=args.attn_type,
                                     clamp_len=args.clamp_len,
                                     sample_softmax=args.sample_softmax,
                                     freeze_below=args.freeze_below)
    g.state.model.to(g.device)
    optimizer_setup(g.state)
    if args.checkpoint:
        if args.checkpoint_secondary:
            g.logger.info(f"restoring extra checkpoint")
            util.restore_from_checkpoint(g.state.model, g.state.optimizer,
                                         args.checkpoint_secondary,
                                         args.optim_state_dict)
        g.logger.info(
            f"Restoring model from {args.checkpoint}" +
            (f" and optimizer from {args.optim_state_dict}"
             if args.optim_state_dict else ""))
        util.restore_from_checkpoint(g.state.model, g.state.optimizer,
                                     args.checkpoint, args.optim_state_dict)

    else:
        g.state.model.apply(weights_init)
        # ensure embedding init is not overridden by out_layer in case of weight sharing
        g.state.model.word_emb.apply(weights_init)

    model: MemTransformerLM = g.state.model
    optimizer = g.state.optimizer

    if g.state.args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(
            optimizer,
            static_loss_scale=g.state.args.static_loss_scale,
            dynamic_loss_scale=g.state.args.dynamic_loss_scale,
            dynamic_loss_args={'init_scale': 2**16},
            verbose=False)

    # log model info
    # n_all_param = sum([p.nelement() for p in model.parameters()])
    # log_tb('sizes/params', n_all_param)
    # n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    # log_tb('sizes/non_emb_params', n_nonemb_param)
    # g.logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # scheduler
    if args.scheduler == 'cosine':
        # Divide by 1e6 for numerical stability.
        g.state.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.max_tokens // 1e6, eta_min=args.eta_min)
    elif args.scheduler == 'finder':
        g.state.scheduler: LRFinder = LRFinder(optimizer,
                                               args.max_tokens,
                                               init_value=args.lr / 1e3)
    else:
        assert args.scheduler == 'constant'
        g.state.scheduler = util.NoOp()

    # Setup distributed model
    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank)  # , find_unused_parameters=True)

    if util.get_global_rank() == 0:
        if not args.test:
            wandb.config.update(vars(args))
            # wandb.watch(model)

    g.event_writer.add_text('args', str(args))  # TODO: replace with log_tb

    accumulated_loss = 0
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=g.state.last_epoch):
            print(f"epoch -- {epoch}, token_count -- {g.state.token_count}")
            model.train()

            log_tb('sizes/batch_size', args.batch_size)
            log_tb('sizes/seq_size', args.tgt_len)

            if g.state.partial_epoch:
                # reuse previously loaded tr_iter and states
                assert g.state.tr_iter is not None
                assert g.state.mems is not None
            else:
                g.state.tr_iter = g.corpus.get_dist_iterator(
                    'train',
                    rank=util.get_global_rank(),
                    max_rank=util.get_world_size(),
                    bsz=args.batch_size,
                    bptt=args.tgt_len,
                    device=g.device,
                    ext_len=args.ext_len,
                    skip_files=g.args.skip_files)
                g.state.mems = tuple()
            g.state.last_epoch = epoch

            log_start_time = time.time()
            tokens_per_epoch = 0
            for batch, (data, target, seq_len) in enumerate(g.state.tr_iter):
                # assert seq_len == data.shape[0]
                # for i in range(1, data.shape[0]):
                #     assert torch.all(torch.eq(data[i], target[i - 1]))
                #     break

                # print(g.state.token_count, data)

                if g.state.train_step % args.eval_interval == 0:
                    evaluate_and_log(model,
                                     g.va_iter,
                                     'val_short-mem-1',
                                     generate_text=False,
                                     reset_mems_interval=1)
                    evaluate_and_log(model,
                                     g.va_iter,
                                     'val_short-mem-2',
                                     generate_text=False,
                                     reset_mems_interval=2)
                    evaluate_and_log(model,
                                     g.va_iter,
                                     'val_short-mem-3',
                                     generate_text=False,
                                     reset_mems_interval=3)
                    evaluate_and_log(model, g.va_iter, 'val')
                    if g.va_custom_iter:
                        evaluate_and_log(g.state.model,
                                         g.va_custom_iter,
                                         g.args.valid_custom,
                                         generate_text=False)

                batch_total = torch.tensor(data.shape[1]).to(g.device)
                if args.local:  # TODO(y): factor out (need way to see if dist was inited)
                    batch_total = batch_total.sum()
                else:
                    batch_total = util.dist_sum_tensor(
                        batch_total)  # global batch size
                batch_total = util.toscalar(batch_total)

                should_log = (g.state.train_step < args.verbose_log_steps) or \
                             (g.state.train_step + 1) % args.log_interval == 0

                model.zero_grad()

                ret = model(data, target, *g.state.mems)
                loss, g.state.mems = ret[0], ret[1:]

                loss: torch.Tensor = loss.float().mean().type_as(loss)
                with timeit('backwards', noop=not should_log):
                    if args.fp16:
                        optimizer.backward(loss)
                    else:
                        loss.backward()
                loss0 = util.toscalar(loss)
                util.record('loss', loss0)

                util.record('params', torch.sum(util.flat_param(model)).item())
                losses.append(loss0)
                accumulated_loss += loss0

                if args.fp16:
                    optimizer.clip_master_grads(args.clip)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.clip)

                # step-wise learning rate annealing
                if hasattr(optimizer, 'overflow') and optimizer.overflow:
                    g.logger.info("skipped iteration")
                else:
                    if args.scheduler in ['cosine', 'constant', 'dev_perf']:
                        # linear warmup stage
                        if g.state.token_count < args.warmup_tokens:
                            curr_lr = args.lr * float(
                                g.state.token_count) / args.warmup_tokens
                            optimizer.param_groups[0]['lr'] = curr_lr
                        elif args.scheduler == 'cosine':
                            # Divide by 1e6 for numerical stability.
                            g.state.scheduler.step(g.state.token_count //
                                                   1000 // 1000)
                    else:
                        g.state.scheduler.step(g.state.token_count)

                optimizer.step()
                g.state.train_step += 1

                consumed_tokens = data.shape[0] * data.shape[1]
                world_size = int(os.environ.get("WORLD_SIZE", "8"))
                if world_size > 8:  # correction factor for multiple machines
                    consumed_tokens = consumed_tokens * (world_size // 8)
                tokens_per_epoch += consumed_tokens
                g.state.token_count += consumed_tokens
                g.token_count = g.state.token_count
                if g.state.token_count >= args.max_tokens:
                    g.state.partial_epoch = True
                    raise StopIteration  # break out of parent train loop

                if should_log:
                    elapsed_time = time.time() - log_start_time
                    elapsed_steps = g.state.train_step - g.state.last_log_step

                    # compute average loss over last logging interval
                    cur_loss = accumulated_loss / elapsed_steps
                    cur_loss_mean = util.dist_mean(cur_loss)
                    log_str = f'| epoch {epoch:3d} step {g.state.train_step:>8d} ' \
                              f'| {batch:>6d} batches ' \
                              f'| lr {optimizer.param_groups[0]["lr"]:.3g} ' \
                              f'| ms/batch {elapsed_time * 1000 / elapsed_steps:5.2f} ' \
                              f'| loss {cur_loss:5.2f}'
                    if args.dataset in ['enwik8', 'text8']:
                        log_str += f' | bpc {cur_loss / math.log(2):9.5f}'
                    else:
                        log_str += f' | ppl {math.exp(cur_loss):9.3f}'
                    g.logger.info(log_str)
                    log_tb('learning/epoch', epoch)
                    log_tb('_loss', cur_loss_mean)  # the most important thing
                    log_tb('learning/loss', cur_loss_mean)
                    log_tb('learning/ppl', math.exp(cur_loss_mean))

                    # currently step timings are not synchronized in multi-machine
                    # case (see #4). Can add torch.distributed.barrier() to get
                    # more accurate timings, but this may add slowness.
                    log_tb('times/step', 1000 * elapsed_time / elapsed_steps)
                    current_lr = optimizer.param_groups[0]['lr']

                    log_tb('learning/lr', current_lr)

                    # 32 is the "canonical" batch size
                    linear_scaling_factor = batch_total / 32  # TODO(y): merge logic from master
                    log_tb('learning/base_lr',
                           current_lr / linear_scaling_factor)
                    if args.optim == 'lamb':
                        log_lamb_rs(optimizer, g.event_writer,
                                    g.state.token_count)

                    time_per_batch = elapsed_time / elapsed_steps
                    time_per_sample = time_per_batch / args.batch_size
                    time_per_token = time_per_sample / args.tgt_len

                    log_tb('times/batches_per_sec', 1 / time_per_batch)
                    log_tb('times/samples_per_sec', 1 / time_per_sample)
                    log_tb('times/tokens_per_sec', 1 / time_per_token)

                    if str(g.device) == 'cuda':
                        log_tb("memory/allocated_gb",
                               torch.cuda.memory_allocated() / 1e9)
                        log_tb("memory/max_allocated_gb",
                               torch.cuda.max_memory_allocated() / 1e9)
                        log_tb("memory/cached_gb",
                               torch.cuda.memory_cached() / 1e9)
                        log_tb("memory/max_cached_gb",
                               torch.cuda.max_memory_cached() / 1e9)

                    accumulated_loss = 0
                    log_start_time = time.time()
                    g.state.last_log_step = g.state.train_step

            if args.checkpoint_each_epoch:
                g.logger.info(f'Saving checkpoint for epoch {epoch}')
                util.dist_save_checkpoint(model,
                                          optimizer,
                                          args.logdir,
                                          suffix=f'{epoch}')
            if tokens_per_epoch == 0:
                g.logger.info("Zero tokens in last epoch, breaking")
                break

            g.state.partial_epoch = False

    except KeyboardInterrupt:
        g.logger.info('-' * 100)
        g.logger.info('Exiting from training early')
    except StopIteration:
        pass

    return losses
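
The scheduler handling inside the batch loop amounts to a linear warmup over warmup_tokens followed by cosine annealing driven by the token count (stepped in millions of tokens). A standalone sketch of that policy under those assumptions; compute_lr is an illustrative name, not a function from the project:

import math

def compute_lr(base_lr, eta_min, token_count, warmup_tokens, max_tokens):
    """Linear warmup to base_lr, then cosine decay towards eta_min,
    both expressed in terms of consumed tokens."""
    if token_count < warmup_tokens:
        return base_lr * float(token_count) / warmup_tokens
    # cosine phase: roughly what CosineAnnealingLR does when stepped with
    # token_count // 1e6 against a horizon of max_tokens // 1e6
    progress = min(float(token_count) / max_tokens, 1.0)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * progress))
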
Code example #4
def evaluate_and_log(model: torch.nn.Module,
                     eval_iter,
                     split: str,
                     generate_text: bool = True,
                     reset_mems_interval: int = None):
    args = g.args
    state = g.state
    optimizer = g.state.optimizer
    eval_start_time = time.time()

    model_to_reset = util.unwrap_model(model)
    # If the model does not use memory at all, make the ext_len longer.
    # Otherwise, make the mem_len longer and keep the ext_len the same.
    if g.args.mem_len == 0:
        model_to_reset.reset_length(
            args.eval_tgt_len, args.ext_len + args.tgt_len - args.eval_tgt_len,
            args.mem_len)
    else:
        model_to_reset.reset_length(
            args.eval_tgt_len, args.ext_len,
            args.mem_len + args.tgt_len - args.eval_tgt_len)

    # Calculate metrics
    ret = evaluate(model,
                   eval_iter,
                   split,
                   args.max_eval_steps,
                   reset_mems_interval=reset_mems_interval)
    total_loss, accuracy_top1, accuracy_top5, MRR, total_len = \
        ret["total_loss"], ret["accuracy_top1"], ret["accuracy_top5"], ret["MRR_top5"], ret["total_len"]
    # Switch back to the training mode
    model_to_reset.reset_length(args.tgt_len, args.ext_len, args.mem_len)
    model.train()

    # Log all the things.
    loss = total_loss / total_len
    mean_loss = util.dist_mean(loss)
    mean_accuracy_top1 = util.dist_mean(accuracy_top1)
    mean_accuracy_top5 = util.dist_mean(accuracy_top5)
    mean_MRR = util.dist_mean(MRR)
    g.logger.info('-' * 100)
    log_str = (
        f'| Eval {g.state.train_step // args.eval_interval:3d} at step {g.state.train_step:>8d} | '
        f'time: {time.time() - eval_start_time:5.2f}s '
        f'| {split} loss {loss:5.2f}')
    log_tb(f'learning/{split}_loss', mean_loss)
    if args.dataset in ['enwik8', 'text8']:
        log_str += f' | bpc {loss / math.log(2):9.5f}'
        log_tb(f'learning/{split}_bpc', mean_loss / math.log(2))
    elif args.dataset == 'git':
        log_str += f' | accuracy@1 {accuracy_top1:.2f} ' \
                   f'| accuracy@5 {accuracy_top5:.2f} ' \
                   f'| MRR@5 {MRR:.2f}'
        log_tb(f'learning/{split}_acc@1', mean_accuracy_top1)
        log_tb(f'learning/{split}_acc@5', mean_accuracy_top5)
        log_tb(f'learning/{split}_MRR@5', mean_MRR)
    else:
        log_str += f' | {split} ppl {math.exp(loss):9.3f}'
        log_tb(f'learning/{split}_ppl', math.exp(mean_loss))
    g.logger.info(log_str)
    g.logger.info('-' * 100)

    # Update checkpoint if validation loss improved.
    if split == 'val' and (not state.best_val_loss
                           or mean_loss < state.best_val_loss):
        g.logger.info('Saving checkpoint for new best loss')
        util.dist_save_checkpoint(model, optimizer, args.logdir, suffix='best')
        state.best_val_loss = mean_loss
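
For the 'git' dataset this version of evaluate() also reports accuracy@1, accuracy@5 and MRR@5. The project's metric code is not shown here; the sketch below only illustrates how such per-batch metrics are commonly computed from next-token logits and is an assumption about the interface:

import torch

def topk_metrics(logits, targets, k=5):
    """accuracy@1, accuracy@k and MRR@k for a batch of predictions.

    logits:  [N, vocab_size] float tensor
    targets: [N] long tensor of ground-truth token ids
    """
    topk = logits.topk(k, dim=-1).indices              # [N, k]
    hits = topk.eq(targets.unsqueeze(-1))               # [N, k] bool
    acc1 = hits[:, 0].float().mean().item()
    acck = hits.any(dim=-1).float().mean().item()
    # reciprocal rank: 1 / (position + 1) if the target is in the top-k, else 0
    ranks = hits.float().argmax(dim=-1).float() + 1.0   # [N]
    rr = torch.where(hits.any(dim=-1), 1.0 / ranks, torch.zeros_like(ranks))
    return acc1, acck, rr.mean().item()
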
Code example #5
def main():
    global global_token_count, event_writer, train_step, train_loss, last_log_step, \
        best_val_loss, epoch, model

    if args.local_rank > 0:
        pass  # skip cancelling the shutdown when rank is explicitly set and non-zero
    else:
        os.system('shutdown -c')

    if not args.local:
        logger.info(
            f'Distributed initializing process group with {args.dist_backend}, {args.dist_url}, {util.get_world_size()}'
        )
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=util.get_world_size())
        assert (util.get_world_size() == dist.get_world_size())
        logger.info(
            f"Distributed: success ({args.local_rank}/{dist.get_world_size()})"
        )

    model = MemTransformerLM(ntokens,
                             args.n_layer,
                             args.n_head,
                             args.d_model,
                             args.d_head,
                             args.d_inner,
                             args.dropout,
                             args.dropatt,
                             tie_weight=args.tied,
                             d_embed=args.d_embed,
                             div_val=args.div_val,
                             tie_projs=tie_projs,
                             pre_lnorm=args.pre_lnorm,
                             tgt_len=args.tgt_len,
                             ext_len=args.ext_len,
                             mem_len=args.mem_len,
                             cutoffs=cutoffs,
                             same_length=args.same_length,
                             attn_type=args.attn_type,
                             clamp_len=args.clamp_len,
                             sample_softmax=args.sample_softmax)

    # log model info
    n_all_param = sum([p.nelement() for p in model.parameters()])
    log_tb('sizes/params', n_all_param)
    n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    log_tb('sizes/non_emb_params', n_nonemb_param)
    logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # optimizer
    if args.optim.lower() == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.mom)
    elif args.optim.lower() == 'lamb':
        optimizer = Lamb(model.parameters(), lr=args.lr, weight_decay=args.wd)
    else:
        assert args.optim.lower() == 'adam'
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.wd)

    # scheduler
    if args.scheduler == 'cosine':
        # Divide by 1e6 for numerical stability.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         args.max_tokens //
                                                         1e6,
                                                         eta_min=args.eta_min)
    elif args.scheduler == 'finder':
        scheduler = LRFinder(optimizer,
                             args.max_tokens,
                             init_value=args.lr / 1e3)
    elif args.scheduler == 'constant':
        scheduler = None  # constant LR: scheduler.step() is never called

    model.apply(weights_init)
    model.word_emb.apply(
        weights_init
    )  # ensure embedding init is not overridden by out_layer in case of weight sharing

    if args.checkpoint:
        if global_rank == 0:
            util.restore_from_checkpoint(model=model,
                                         checkpoint_fn=args.checkpoint)

    model = model.to(device)
    if args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={'init_scale': 2**16},
                                   verbose=False)

    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank)  #, find_unused_parameters=True)

    if global_rank == 0:
        event_writer = SummaryWriter(args.logdir)

    event_writer.add_text('args', str(args))

    # test checkpoint writing
    if args.checkpoint_each_epoch:
        logger.info(f'Saving checkpoint for epoch {epoch}')
        util.dist_save_checkpoint(model, optimizer, args.logdir, suffix=f'{0}')

    # Loop over epochs.
    train_step = 0
    train_loss = 0
    last_log_step = 0
    best_val_loss = None
    va_iter, te_iter = [
        corpus.get_dist_iterator(split,
                                 global_rank,
                                 max_rank,
                                 args.batch_size * 2,
                                 args.tgt_len,
                                 device=device,
                                 ext_len=args.ext_len)
        for split in ('valid', 'test')
    ]

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=1):
            train(va_iter, optimizer, scheduler)
    except KeyboardInterrupt:
        logger.info('-' * 100)
        logger.info('Exiting from training early')
    except StopIteration:
        pass

    # Eval one more time.
    evaluate_and_log(optimizer, va_iter, 'val', train_step=-1)

    # Load the best saved model.
    logger.info("Loading best checkpoint")
    model_file = os.path.join(args.logdir, 'model-best.pt')
    if os.path.exists(model_file):
        with open(model_file, 'rb') as model_f:
            with timeit('load'):
                if args.local:
                    model = torch.load(model_f)
                else:
                    model = torch.load(model_f,
                                       map_location=lambda storage, loc:
                                       storage.cuda(args.local_rank))
                    model = DistributedDataParallel(
                        model,
                        device_ids=[args.local_rank],
                        output_device=args.local_rank)
    else:
        logger.warning('no model file, using current model for loss')

    # Run on test data.
    evaluate_and_log(optimizer, te_iter, 'test', -1)
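
After training, the best checkpoint is reloaded with a map_location that places every storage on this process's GPU before the model is rewrapped in DistributedDataParallel. A minimal standalone sketch of that pattern (the path and rank are placeholders):

import torch
from torch.nn.parallel import DistributedDataParallel

def load_best_model(model_file, local_rank):
    """Load a whole-model checkpoint onto the local GPU and rewrap it for DDP."""
    model = torch.load(
        model_file,
        map_location=lambda storage, loc: storage.cuda(local_rank))
    return DistributedDataParallel(
        model, device_ids=[local_rank], output_device=local_rank)

# usage (assuming the process group is already initialized):
# model = load_best_model(os.path.join(args.logdir, 'model-best.pt'), args.local_rank)
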
Code example #6
def train(va_iter, optimizer, scheduler):
    global global_token_count, event_writer, train_loss, best_val_loss, \
        train_step, last_log_step, epoch
    # Turn on training mode which enables dropout.
    model.train()

    log_tb('sizes/batch_size', args.batch_size)
    log_tb('sizes/seq_size', args.tgt_len)

    tr_iter = corpus.get_dist_iterator('train',
                                       global_rank,
                                       max_rank,
                                       args.batch_size,
                                       args.tgt_len,
                                       device=device,
                                       ext_len=args.ext_len)
    mems = tuple()
    log_start_time = time.time()
    for batch, (data, target, seq_len) in enumerate(tr_iter):
        assert seq_len == data.shape[0]
        for i in range(1, data.shape[0]):
            assert torch.all(torch.eq(data[i], target[i - 1]))
            break

        batch_total = torch.tensor(data.shape[1]).to(device)  # needed for NCCL sync
        if args.local:
            batch_total = batch_total.sum()
        else:
            batch_total = util.dist_sum_tensor(
                batch_total)  # global batch size
        batch_total = util.toscalar(batch_total)

        total_tokens = batch_total * seq_len
        should_log = train_step < args.verbose_log_steps or train_step % args.log_interval == 0

        global_token_count += total_tokens
        model.zero_grad()
        ret = model(data, target, *mems)
        loss, mems = ret[0], ret[1:]
        loss = loss.float().mean().type_as(loss)
        with timeit('backwards', noop=not should_log):
            if args.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()
        train_loss += loss.float().item()

        if args.fp16:
            optimizer.clip_master_grads(args.clip)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        optimizer.step()

        train_step += 1

        # step-wise learning rate annealing
        if args.fp16 and optimizer.overflow:
            logger.info("skipped iteration")
        else:
            if args.scheduler in ['cosine', 'constant', 'dev_perf']:
                # linear warmup stage
                if global_token_count < args.warmup_tokens:
                    curr_lr = args.lr * global_token_count / args.warmup_tokens
                    optimizer.param_groups[0]['lr'] = curr_lr
                elif args.scheduler == 'cosine':
                    # Divide by 1e6 for numerical stability.
                    scheduler.step(global_token_count // 1e6)
            else:
                scheduler.step(global_token_count)

        if should_log:
            elapsed_time = time.time() - log_start_time
            elapsed_steps = train_step - last_log_step

            # compute average loss over last logging interval
            cur_loss = train_loss / elapsed_steps
            log_str = f'| epoch {epoch:3d} step {train_step:>8d} | {batch:>6d} batches | lr {optimizer.param_groups[0]["lr"]:.3g} ' \
                      f'| ms/batch {elapsed_time * 1000 / elapsed_steps:5.2f} | loss {cur_loss:5.2f}'
            if args.dataset in ['enwik8', 'text8']:
                log_str += f' | bpc {cur_loss / math.log(2):9.5f}'
            else:
                log_str += f' | ppl {math.exp(cur_loss):9.3f}'
            logger.info(log_str)
            log_tb('learning/epoch', epoch)
            log_tb('_loss', cur_loss)  # the most important thing
            log_tb('learning/loss', cur_loss)
            log_tb('learning/ppl', math.exp(cur_loss))

            # currently step timings are not synchronized in multi-machine
            # case (see #4). Can add torch.distributed.barrier() to get
            # more accurate timings, but this may add slowness.
            log_tb('times/step', 1000 * elapsed_time / elapsed_steps)
            current_lr = optimizer.param_groups[0]['lr']

            log_tb('learning/lr', current_lr)

            # 32 is the "canonical" batch size
            linear_scaling_factor = batch_total / 32
            log_tb('learning/base_lr', current_lr / linear_scaling_factor)
            if args.optim == 'lamb':
                log_lamb_rs(optimizer, event_writer, global_token_count)

            time_per_batch = elapsed_time / elapsed_steps
            time_per_sample = time_per_batch / args.batch_size
            time_per_token = time_per_sample / args.tgt_len

            log_tb('times/batches_per_sec', 1 / time_per_batch)
            log_tb('times/samples_per_sec', 1 / time_per_sample)
            log_tb('times/tokens_per_sec', 1 / time_per_token)

            if str(device) == 'cuda':
                log_tb("memory/allocated_gb",
                       torch.cuda.memory_allocated() / 1e9)
                log_tb("memory/max_allocated_gb",
                       torch.cuda.max_memory_allocated() / 1e9)
                log_tb("memory/cached_gb", torch.cuda.memory_cached() / 1e9)
                log_tb("memory/max_cached_gb",
                       torch.cuda.max_memory_cached() / 1e9)

            train_loss = 0
            log_start_time = time.time()
            last_log_step = train_step

        if train_step % args.eval_interval == 0:
            evaluate_and_log(optimizer, va_iter, 'val', train_step)

        if global_token_count >= args.max_tokens:
            if args.eta_min == 0:
                raise StopIteration
            logger.info('End of schedule, staying at current LR')
            args.scheduler = 'constant'

    if args.checkpoint_each_epoch:
        logger.info(f'Saving checkpoint for epoch {epoch}')
        util.dist_save_checkpoint(model,
                                  optimizer,
                                  args.logdir,
                                  suffix=f'{epoch}')
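
The logging block above converts the wall-clock time of one logging interval into batches/samples/tokens per second, and divides the current LR by the linear scaling factor relative to the "canonical" batch size of 32. A small sketch of that bookkeeping; the function names are illustrative only:

def throughput(elapsed_time, elapsed_steps, batch_size, tgt_len):
    """Batches, samples and tokens processed per second over the interval."""
    time_per_batch = elapsed_time / elapsed_steps
    time_per_sample = time_per_batch / batch_size
    time_per_token = time_per_sample / tgt_len
    return 1 / time_per_batch, 1 / time_per_sample, 1 / time_per_token

def base_lr(current_lr, global_batch_size, canonical_batch_size=32):
    """Undo linear LR scaling to recover the per-32-sample base learning rate."""
    return current_lr / (global_batch_size / canonical_batch_size)
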