Example #1
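A compact, MNIST-style training loop: it runs one epoch over `train_loader`, computes NLL loss and steps the optimizer, and every `args.log_interval` batches logs the loss and LAMB optimizer statistics (via `log_lamb_rs`) to a TensorBoard `event_writer`, using the global sample count as the step.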
import tqdm
import torch.nn.functional as F
from pytorch_lamb import log_lamb_rs  # assumed: the helper ships with the pytorch-lamb package


def train(args, model, device, train_loader, optimizer, epoch, event_writer):
    model.train()
    tqdm_bar = tqdm.tqdm(train_loader)
    for batch_idx, (data, target) in enumerate(tqdm_bar):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            step = batch_idx * len(data) + (epoch-1) * len(train_loader.dataset)
            log_lamb_rs(optimizer, event_writer, step)
            event_writer.add_scalar('loss', loss.item(), step)
            tqdm_bar.set_description(
                f'Train epoch {epoch} Loss: {loss.item():.6f}')
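A minimal sketch of how this function might be wired up, assuming an MNIST setup. The model, hyperparameters, and `Lamb` import below are illustrative, not part of the original example; `Lamb` is assumed to come from the same pytorch-lamb package as `log_lamb_rs`:

import argparse

import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from pytorch_lamb import Lamb  # assumed companion of log_lamb_rs

args = argparse.Namespace(log_interval=10)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Any nn.Module whose forward returns log-probabilities works here,
# since the loop above uses F.nll_loss.
model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(28 * 28, 10),
    torch.nn.LogSoftmax(dim=1)).to(device)
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=64, shuffle=True)
optimizer = Lamb(model.parameters(), lr=1e-3)
writer = SummaryWriter()

for epoch in range(1, 3):
    train(args, model, device, train_loader, optimizer, epoch, writer)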
Example #2
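A full distributed training entry point for a Transformer-XL language model (`MemTransformerLM`): it initializes the process group, builds the model and optimizer (or restores them from a checkpoint), optionally wraps them for FP16, selects a scheduler (cosine, LR finder, or constant), wraps the model in `DataParallel` or `DistributedDataParallel`, then runs a token-budgeted training loop with periodic evaluation, TensorBoard logging, and per-epoch checkpointing, returning the list of per-step losses.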
def main_loop():
    util.cancel_shutdown()
    losses = []

    args = g.args

    if not args.local:
        g.logger.info(
            f'Distributed initializing process group with '
            f'{args.dist_backend}, {args.dist_url}, {util.get_world_size()}')
        dist.init_process_group(
            backend=args.dist_backend,
            # init_method and world_size are left to the default env://
            # initialization, which reads MASTER_ADDR, MASTER_PORT, RANK
            # and WORLD_SIZE from the environment.
            # init_method=args.dist_url,
            # world_size=util.get_world_size()
        )
        assert util.get_world_size() == dist.get_world_size()
        g.logger.info(
            f"Distributed: success ({args.local_rank}/{dist.get_world_size()})"
        )

    g.logger.info("creating new model")
    g.state = TrainState(args)
    g.state.model = MemTransformerLM(g.ntokens,
                                     args.n_layer,
                                     args.n_head,
                                     args.d_model,
                                     args.d_head,
                                     args.d_inner,
                                     args.dropout,
                                     args.dropatt,
                                     tie_weight=args.tied,
                                     d_embed=args.d_embed,
                                     div_val=args.div_val,
                                     tie_projs=g.tie_projs,
                                     pre_lnorm=args.pre_lnorm,
                                     tgt_len=args.tgt_len,
                                     ext_len=args.ext_len,
                                     mem_len=args.mem_len,
                                     cutoffs=g.cutoffs,
                                     same_length=args.same_length,
                                     attn_type=args.attn_type,
                                     clamp_len=args.clamp_len,
                                     sample_softmax=args.sample_softmax,
                                     freeze_below=args.freeze_below)
    g.state.model.to(g.device)
    optimizer_setup(g.state)
    if args.checkpoint:
        if args.checkpoint_secondary:
            g.logger.info(f"restoring extra checkpoint")
            util.restore_from_checkpoint(g.state.model, g.state.optimizer,
                                         args.checkpoint_secondary,
                                         args.optim_state_dict)
        g.logger.info(f"Restoring model from {args.checkpoint}" +
                      f" and optimizer from {args.optim_state_dict}" if args.
                      optim_state_dict else "")
        util.restore_from_checkpoint(g.state.model, g.state.optimizer,
                                     args.checkpoint, args.optim_state_dict)

    else:
        g.state.model.apply(weights_init)
        # ensure embedding init is not overridden by out_layer in case of weight sharing
        g.state.model.word_emb.apply(weights_init)

    model: MemTransformerLM = g.state.model
    optimizer = g.state.optimizer

    if g.state.args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(
            optimizer,
            static_loss_scale=g.state.args.static_loss_scale,
            dynamic_loss_scale=g.state.args.dynamic_loss_scale,
            dynamic_loss_args={'init_scale': 2**16},
            verbose=False)

    # log model info
    # n_all_param = sum([p.nelement() for p in model.parameters()])
    # log_tb('sizes/params', n_all_param)
    # n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    # log_tb('sizes/non_emb_params', n_nonemb_param)
    # g.logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # scheduler
    if args.scheduler == 'cosine':
        # T_max is measured in millions of tokens so that it lines up with
        # the scheduler.step(token_count // 10**6) calls in the training loop.
        g.state.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.max_tokens // 10**6, eta_min=args.eta_min)
    elif args.scheduler == 'finder':
        g.state.scheduler: LRFinder = LRFinder(optimizer,
                                               args.max_tokens,
                                               init_value=args.lr / 1e3)
    else:
        assert args.scheduler == 'constant'
        g.state.scheduler = util.NoOp()

    # Setup distributed model
    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank)  # , find_unused_parameters=True)

    if util.get_global_rank() == 0:
        if not args.test:
            wandb.config.update(vars(args))
            # wandb.watch(model)

    g.event_writer.add_text('args', str(args))  # TODO: replace with log_tb

    accumulated_loss = 0
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=g.state.last_epoch):
            print(f"epoch -- {epoch}, token_count -- {g.state.token_count}")
            model.train()

            log_tb('sizes/batch_size', args.batch_size)
            log_tb('sizes/seq_size', args.tgt_len)

            if g.state.partial_epoch:
                # reuse previously loaded tr_iter and states
                assert g.state.tr_iter is not None
                assert g.state.mems is not None
            else:
                g.state.tr_iter = g.corpus.get_dist_iterator(
                    'train',
                    rank=util.get_global_rank(),
                    max_rank=util.get_world_size(),
                    bsz=args.batch_size,
                    bptt=args.tgt_len,
                    device=g.device,
                    ext_len=args.ext_len,
                    skip_files=g.args.skip_files)
                g.state.mems = tuple()
            g.state.last_epoch = epoch

            log_start_time = time.time()
            tokens_per_epoch = 0
            for batch, (data, target, seq_len) in enumerate(g.state.tr_iter):
                # assert seq_len == data.shape[0]
                # for i in range(1, data.shape[0]):
                #     assert torch.all(torch.eq(data[i], target[i - 1]))
                #     break

                # print(g.state.token_count, data)

                if g.state.train_step % args.eval_interval == 0:
                    evaluate_and_log(model,
                                     g.va_iter,
                                     'val_short-mem-1',
                                     generate_text=False,
                                     reset_mems_interval=1)
                    evaluate_and_log(model,
                                     g.va_iter,
                                     'val_short-mem-2',
                                     generate_text=False,
                                     reset_mems_interval=2)
                    evaluate_and_log(model,
                                     g.va_iter,
                                     'val_short-mem-3',
                                     generate_text=False,
                                     reset_mems_interval=3)
                    evaluate_and_log(model, g.va_iter, 'val')
                    if g.va_custom_iter:
                        evaluate_and_log(g.state.model,
                                         g.va_custom_iter,
                                         g.args.valid_custom,
                                         generate_text=False)

                batch_total = torch.tensor(data.shape[1]).to(g.device)
                if args.local:  # TODO(y): factor out (need way to see if dist was inited)
                    batch_total = batch_total.sum()
                else:
                    batch_total = util.dist_sum_tensor(
                        batch_total)  # global batch size
                batch_total = util.toscalar(batch_total)

                should_log = (g.state.train_step < args.verbose_log_steps) or \
                             (g.state.train_step + 1) % args.log_interval == 0

                model.zero_grad()

                ret = model(data, target, *g.state.mems)
                loss, g.state.mems = ret[0], ret[1:]

                loss: torch.Tensor = loss.float().mean().type_as(loss)
                with timeit('backwards', noop=not should_log):
                    if args.fp16:
                        optimizer.backward(loss)
                    else:
                        loss.backward()
                loss0 = util.toscalar(loss)
                util.record('loss', loss0)

                util.record('params', torch.sum(util.flat_param(model)).item())
                losses.append(loss0)
                accumulated_loss += loss0

                if args.fp16:
                    optimizer.clip_master_grads(args.clip)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.clip)

                # step-wise learning rate annealing
                if hasattr(optimizer, 'overflow') and optimizer.overflow:
                    g.logger.info("skipped iteration")
                else:
                    if args.scheduler in ['cosine', 'constant', 'dev_perf']:
                        # linear warmup stage
                        if g.state.token_count < args.warmup_tokens:
                            curr_lr = args.lr * float(
                                g.state.token_count) / args.warmup_tokens
                            optimizer.param_groups[0]['lr'] = curr_lr
                        elif args.scheduler == 'cosine':
                            # step in units of a million tokens to match T_max
                            g.state.scheduler.step(g.state.token_count //
                                                   10**6)
                    else:
                        g.state.scheduler.step(g.state.token_count)

                optimizer.step()
                g.state.train_step += 1

                consumed_tokens = data.shape[0] * data.shape[1]
                world_size = int(os.environ.get("WORLD_SIZE", "8"))
                if world_size > 8:  # correction factor for multiple machines
                    consumed_tokens = consumed_tokens * (world_size // 8)
                tokens_per_epoch += consumed_tokens
                g.state.token_count += consumed_tokens
                g.token_count = g.state.token_count
                if g.state.token_count >= args.max_tokens:
                    g.state.partial_epoch = True
                    raise StopIteration  # break out of parent train loop

                if should_log:
                    elapsed_time = time.time() - log_start_time
                    elapsed_steps = g.state.train_step - g.state.last_log_step

                    # compute average loss over last logging interval
                    cur_loss = accumulated_loss / elapsed_steps
                    cur_loss_mean = util.dist_mean(cur_loss)
                    log_str = f'| epoch {epoch:3d} step {g.state.train_step:>8d} ' \
                              f'| {batch:>6d} batches ' \
                              f'| lr {optimizer.param_groups[0]["lr"]:.3g} ' \
                              f'| ms/batch {elapsed_time * 1000 / elapsed_steps:5.2f} ' \
                              f'| loss {cur_loss:5.2f}'
                    if args.dataset in ['enwik8', 'text8']:
                        log_str += f' | bpc {cur_loss / math.log(2):9.5f}'
                    else:
                        log_str += f' | ppl {math.exp(cur_loss):9.3f}'
                    g.logger.info(log_str)
                    log_tb('learning/epoch', epoch)
                    log_tb('_loss', cur_loss_mean)  # the most important thing
                    log_tb('learning/loss', cur_loss_mean)
                    log_tb('learning/ppl', math.exp(cur_loss_mean))

                    # currently step timings are not synchronized in multi-machine
                    # case (see #4). Can add torch.distributed.barrier() to get
                    # more accurate timings, but this may add slowness.
                    log_tb('times/step', 1000 * elapsed_time / elapsed_steps)
                    current_lr = optimizer.param_groups[0]['lr']

                    log_tb('learning/lr', current_lr)

                    # 32 is the "canonical" batch size
                    linear_scaling_factor = batch_total / 32  # TODO(y): merge logic from master
                    log_tb('learning/base_lr',
                           current_lr / linear_scaling_factor)
                    if args.optim == 'lamb':
                        log_lamb_rs(optimizer, g.event_writer,
                                    g.state.token_count)

                    time_per_batch = elapsed_time / elapsed_steps
                    time_per_sample = time_per_batch / args.batch_size
                    time_per_token = time_per_sample / args.tgt_len

                    log_tb('times/batches_per_sec', 1 / time_per_batch)
                    log_tb('times/samples_per_sec', 1 / time_per_sample)
                    log_tb('times/tokens_per_sec', 1 / time_per_token)

                    if str(g.device) == 'cuda':
                        log_tb("memory/allocated_gb",
                               torch.cuda.memory_allocated() / 1e9)
                        log_tb("memory/max_allocated_gb",
                               torch.cuda.max_memory_allocated() / 1e9)
                        log_tb("memory/cached_gb",
                               torch.cuda.memory_cached() / 1e9)
                        log_tb("memory/max_cached_gb",
                               torch.cuda.max_memory_cached() / 1e9)

                    accumulated_loss = 0
                    log_start_time = time.time()
                    g.state.last_log_step = g.state.train_step

            if args.checkpoint_each_epoch:
                g.logger.info(f'Saving checkpoint for epoch {epoch}')
                util.dist_save_checkpoint(model,
                                          optimizer,
                                          args.logdir,
                                          suffix=f'{epoch}')
            if tokens_per_epoch == 0:
                g.logger.info("Zero tokens in last epoch, breaking")
                break

            g.state.partial_epoch = False

    except KeyboardInterrupt:
        g.logger.info('-' * 100)
        g.logger.info('Exiting from training early')
    except StopIteration:
        pass

    return losses
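The LR schedule above combines a linear, token-based warmup with cosine annealing stepped in millions of tokens. A self-contained sketch of that pattern, with all names and hyperparameters illustrative rather than taken from the example:

import torch
from torch import optim

base_lr, warmup_tokens, max_tokens = 1e-3, 10**6, 10**8
model = torch.nn.Linear(8, 8)
optimizer = optim.SGD(model.parameters(), lr=base_lr)
# T_max is measured in millions of tokens, matching the step() calls below.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=max_tokens // 10**6, eta_min=0.0)

def update_lr(token_count: int) -> None:
    if token_count < warmup_tokens:
        # Linear warmup: scale the base LR by the fraction of warmup done.
        optimizer.param_groups[0]['lr'] = base_lr * token_count / warmup_tokens
    else:
        # Cosine annealing, stepped in units of a million tokens
        # (scheduler.step with an explicit epoch, as in the examples above).
        scheduler.step(token_count // 10**6)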
Example #3
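A variant of the same Transformer-XL training loop, written against module-level globals instead of a `TrainState` object: one epoch over a distributed iterator with FP16-aware backward, gradient clipping, warmup/cosine LR annealing keyed to the global token count, periodic TensorBoard logging and evaluation, and a switch to a constant LR (or early stop) once the token budget is exhausted.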
def train(va_iter, optimizer, scheduler):
    global global_token_count, event_writer, train_loss, best_val_loss, \
        train_step, last_log_step, epoch
    # Turn on training mode which enables dropout.
    model.train()

    log_tb('sizes/batch_size', args.batch_size)
    log_tb('sizes/seq_size', args.tgt_len)

    tr_iter = corpus.get_dist_iterator('train',
                                       global_rank,
                                       max_rank,
                                       args.batch_size,
                                       args.tgt_len,
                                       device=device,
                                       ext_len=args.ext_len)
    mems = tuple()
    log_start_time = time.time()
    for batch, (data, target, seq_len) in enumerate(tr_iter):
        assert seq_len == data.shape[0]
        # sanity check: target is the input shifted by one token
        # (the break limits the check to the first shifted row)
        for i in range(1, data.shape[0]):
            assert torch.all(torch.eq(data[i], target[i - 1]))
            break

        # batch size lives on the device so NCCL can reduce it
        batch_total = torch.tensor(data.shape[1]).to(device)
        if args.local:
            batch_total = batch_total.sum()
        else:
            batch_total = util.dist_sum_tensor(
                batch_total)  # global batch size
        batch_total = util.toscalar(batch_total)

        total_tokens = batch_total * seq_len
        should_log = train_step < args.verbose_log_steps or train_step % args.log_interval == 0

        global_token_count += total_tokens
        model.zero_grad()
        ret = model(data, target, *mems)
        loss, mems = ret[0], ret[1:]
        loss = loss.float().mean().type_as(loss)
        with timeit('backwards', noop=not should_log):
            if args.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()
        train_loss += loss.float().item()

        if args.fp16:
            optimizer.clip_master_grads(args.clip)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        optimizer.step()

        train_step += 1

        # step-wise learning rate annealing
        if args.fp16 and optimizer.overflow:
            logger.info("skipped iteration")
        else:
            if args.scheduler in ['cosine', 'constant', 'dev_perf']:
                # linear warmup stage
                if global_token_count < args.warmup_tokens:
                    curr_lr = args.lr * global_token_count / args.warmup_tokens
                    optimizer.param_groups[0]['lr'] = curr_lr
                elif args.scheduler == 'cosine':
                    # step in units of a million tokens to match T_max
                    scheduler.step(global_token_count // 10**6)
            else:
                scheduler.step(global_token_count)

        if should_log:
            elapsed_time = time.time() - log_start_time
            elapsed_steps = train_step - last_log_step

            # compute average loss over last logging interval
            cur_loss = train_loss / elapsed_steps
            log_str = f'| epoch {epoch:3d} step {train_step:>8d} ' \
                      f'| {batch:>6d} batches ' \
                      f'| lr {optimizer.param_groups[0]["lr"]:.3g} ' \
                      f'| ms/batch {elapsed_time * 1000 / elapsed_steps:5.2f} ' \
                      f'| loss {cur_loss:5.2f}'
            if args.dataset in ['enwik8', 'text8']:
                log_str += f' | bpc {cur_loss / math.log(2):9.5f}'
            else:
                log_str += f' | ppl {math.exp(cur_loss):9.3f}'
            logger.info(log_str)
            log_tb('learning/epoch', epoch)
            log_tb('_loss', cur_loss)  # the most important thing
            log_tb('learning/loss', cur_loss)
            log_tb('learning/ppl', math.exp(cur_loss))

            # currently step timings are not synchronized in multi-machine
            # case (see #4). Can add torch.distributed.barrier() to get
            # more accurate timings, but this may add slowness.
            log_tb('times/step', 1000 * elapsed_time / elapsed_steps)
            current_lr = optimizer.param_groups[0]['lr']

            log_tb('learning/lr', current_lr)

            # 32 is the "canonical" batch size
            linear_scaling_factor = batch_total / 32
            log_tb('learning/base_lr', current_lr / linear_scaling_factor)
            if args.optim == 'lamb':
                log_lamb_rs(optimizer, event_writer, global_token_count)

            time_per_batch = elapsed_time / elapsed_steps
            time_per_sample = time_per_batch / args.batch_size
            time_per_token = time_per_sample / args.tgt_len

            log_tb('times/batches_per_sec', 1 / time_per_batch)
            log_tb('times/samples_per_sec', 1 / time_per_sample)
            log_tb('times/tokens_per_sec', 1 / time_per_token)

            if str(device) == 'cuda':
                log_tb("memory/allocated_gb",
                       torch.cuda.memory_allocated() / 1e9)
                log_tb("memory/max_allocated_gb",
                       torch.cuda.max_memory_allocated() / 1e9)
                log_tb("memory/cached_gb", torch.cuda.memory_cached() / 1e9)
                log_tb("memory/max_cached_gb",
                       torch.cuda.max_memory_cached() / 1e9)

            train_loss = 0
            log_start_time = time.time()
            last_log_step = train_step

        if train_step % args.eval_interval == 0:
            evaluate_and_log(optimizer, va_iter, 'val', train_step)

        if global_token_count >= args.max_tokens:
            if args.eta_min == 0:
                raise StopIteration
            logger.info('End of schedule, staying at current LR')
            args.scheduler = 'constant'

    if args.checkpoint_each_epoch:
        logger.info(f'Saving checkpoint for epoch {epoch}')
        util.dist_save_checkpoint(model,
                                  optimizer,
                                  args.logdir,
                                  suffix=f'{epoch}')
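Both Transformer-XL examples derive the global batch size with `util.dist_sum_tensor` and `util.toscalar`. The `util` module is not shown, so the following is only a plausible minimal sketch of those helpers: an all-reduce sum plus a tensor-to-scalar conversion.

import torch
import torch.distributed as dist

def dist_sum_tensor(t: torch.Tensor) -> torch.Tensor:
    # Sum `t` across all ranks in place and return it; every rank
    # ends up with the global total (e.g. the global batch size).
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t

def toscalar(t: torch.Tensor) -> float:
    # Move a 0-d tensor back to a Python number.
    return t.item()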