Example no. 1
# Imports used by this function. Project-specific helpers (Logger, get_model,
# count_param, get_optimizer_scheduler, load_checkpoint, save_checkpoint,
# reset_parameters, get_loader, get_eval_loader) are assumed to be defined
# elsewhere in the surrounding module and are not shown here.
import itertools
import math
import os
import re
import socket
from shutil import copyfile

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from tqdm import tqdm

try:
    from apex import amp  # NVIDIA Apex; only required when --fp16 is used
except ImportError:
    amp = None

def setup_train(i, corpus, args):
    """Setup training.

    Handles CPU, single GPU, and distributed training.

    Args:
        i: The process index. Since there is one process per GPU, this is
            also the local GPU index. For single-GPU or CPU runs it is 0.
        corpus: The corpus for training.
        args: Arguments from argparse and main().
    """
    args.device = torch.device(args.device.type, i)

    # Find rank among all processes.
    args.rank = args.node_rank * args.gpu_per_node + i

    log = Logger(i, args.tensorboard_dir)
    log.train_add_text('arguments', str(args))
    log.valid_add_text('arguments', str(args))

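    # With init_method='env://', torch.distributed rendezvous uses the
    # MASTER_ADDR and MASTER_PORT environment variables; NCCL is the GPU
    # communication backend.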
    if args.dist:
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                world_size=args.world_size,
                                rank=args.rank)
        # Bind this process to its local GPU: index i, not the global rank.
        torch.cuda.set_device(i)

    # Initialize model
    log("|  Loading model...")
    model = get_model(corpus.vocab, args)
    model.to(args.device)

    args.total_param = count_param(model)
    if hasattr(model, 'layer_pool'):
        args.layer_param = sum(
            [count_param(layer) for layer in model.layer_pool])
    elif hasattr(model, 'layer'):
        args.layer_param = count_param(model.layer)

    string = f"|  Model:\n{model}\n"
    string += f"|  Total parameters: {args.total_param}\n"
    string += f"|  Parameters without embedding and pre-softmax linear: {args.layer_param}"
    log(string)
    log.train_add_text('arguments', string)
    log.valid_add_text('arguments', string)

    # Create optimizer and scheduler.
    optimizer, scheduler = get_optimizer_scheduler(model, args)

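    # Apex AMP opt_level 'O2' runs the model in FP16 while keeping FP32 master
    # weights in the optimizer, which is why DistributedDataParallel wraps the
    # model only after amp.initialize().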
    if args.fp16:
        print("|  Floating point 16 precision setting:\n", end='')
        model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
    if args.dist:
        model = DistributedDataParallel(model,
                                        device_ids=[i],
                                        find_unused_parameters=True)

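    # Either resume model/optimizer/scheduler (and amp loss-scaler state when
    # training in FP16) from a checkpoint, or start from freshly initialized
    # parameters.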
    resume_step = 0
    resume_epoch = 0
    if args.checkpoint is not None:
        log("|  Loading checkpoint...")
        if args.fp16:
            resume_step, resume_epoch = load_checkpoint(
                args.checkpoint, args.device, model, optimizer, scheduler, amp)
        else:
            resume_step, resume_epoch = load_checkpoint(
                args.checkpoint, args.device, model, optimizer, scheduler)

        def update_dropout(module):
            # Override the checkpoint's dropout rates with the current args.
            if hasattr(module, 'dropout'):
                module.dropout = args.dropout
            if hasattr(module, 'attn_dropout'):
                module.attn_dropout = args.attn_dropout

        model.apply(update_dropout)
    else:
        model.apply(reset_parameters)  # Initialize parameters

    # Get DataLoader
    log("|  Processing data...")
    train_loader = get_loader(corpus.train, corpus.vocab, args)
    if args.valid is not None:
        valid_loader = get_eval_loader(corpus.valid, corpus.vocab, args)

    log(f"|  Training on {socket.gethostname()} with rank {args.rank}.", True)

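    # One epoch of training: gradients are accumulated over args.step_freq
    # batches per optimizer update, out-of-memory batches are skipped, and the
    # mean loss per target token is returned together with step and best_loss.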
    def train(step, epoch, best_loss):
        model.train()
        optimizer.zero_grad()

        train_loader.dataset.set_seed(epoch)
        log.init_epoch(step, epoch, train_loader.dataset.total_target)
        epoch_loss = 0
        epoch_num_target = 0
        for batch_num, batch in enumerate(train_loader):
            # TODO debug
            f = batch['feature'].data.numpy()
            t = batch['target'].data.numpy()
            n = batch['num_target']
            vocab = corpus.vocab
            # TODO print out data to test
            # feat = np.transpose(f)
            # for data in feat:
            #     print(vocab.to_text(data))
            # continue
            # TODO test dataloading

            num_target = sum(batch['num_target'])
            epoch_num_target += num_target
            log.num_target += num_target
            log.batch_size += len(batch['num_target'])
            try:
                feature = batch['feature'].to(args.device)
                target = batch['target'].to(args.device)

                assert (target != vocab.pad_idx
                        ).sum() == num_target  # TODO remove debug check

                loss = model(feature, target)
                assert loss.dtype == torch.float32  # TODO remove debug check
                batch_loss = loss.item()
                epoch_loss += batch_loss
                log.loss += batch_loss

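                # Normalize to a per-token loss, then divide by step_freq so
                # the accumulated gradient matches a single large batch.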
                loss = loss / num_target
                loss = loss / args.step_freq

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    log.oom += 1
                    print(
                        f"== Rank {args.rank}: Training out of memory. Skipping this batch. =="
                    )
                    # Release memory
                    if 'scaled_loss' in locals():
                        del scaled_loss
                    if 'loss' in locals():
                        del loss
                    if 'feature' in locals():
                        del feature
                    if 'target' in locals():
                        del target
                    for param in model.parameters():
                        if param.grad is not None:
                            param.grad = None
                    if args.cuda:
                        torch.cuda.empty_cache()
                else:
                    raise e

            if (batch_num + 1) % args.step_freq == 0:
                step, epoch, best_loss = update(step, epoch, best_loss)
                if args.max_step is not None and step >= args.max_step:
                    break
        # Flush gradients from leftover batches that don't fill a full
        # accumulation window (skipped when trim_step is set).
        if not args.trim_step and (batch_num + 1) % args.step_freq != 0:
            step, epoch, best_loss = update(step, epoch, best_loss)
        log.end_epoch(step, epoch)
        return step, epoch_loss / epoch_num_target, best_loss

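    # A single optimizer update: log and clip gradient norms, step the
    # optimizer, apply linear LR warmup (then the scheduler), and periodically
    # save checkpoints and validate.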
    def update(step, epoch, best_loss):
        loss_scale = 1
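        # With Apex, gradients are still multiplied by the dynamic loss scale
        # at this point, so the norms logged below are divided by it (note
        # that _amp_state is a private Apex attribute).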
        if args.fp16:
            loss_scale = amp._amp_state.loss_scalers[0]._loss_scale

        # Calculate norm of gradients. For logging.
        if args.log_norm:
            for name, param in model.named_parameters():
                if param.grad is None:
                    continue
                norm = param.grad.data.float().norm().item() / loss_scale
                log.train_add_scalar('norm/' + name, norm, step)

        # Clip gradient
        if args.fp16:
            norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                                  args.clip_norm)
        else:
            norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                  args.clip_norm)
        log.norm += norm
        log.clip_norm += min(args.clip_norm, norm)

        optimizer.step()
        optimizer.zero_grad()

        step += 1
        if scheduler is not None:
            if step < args.warmup_step:
                # Linear warmup
                warmup_lr = args.lr * step / args.warmup_step
                optimizer.param_groups[0]['lr'] = warmup_lr
            else:
                scheduler.step()

        lr = optimizer.param_groups[0]['lr']
        log.train(step, lr, loss_scale)

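        # Only the first process on each node (i == 0) writes checkpoints; in
        # distributed training the barrier keeps the other processes in sync.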
        if args.step_per_save != 0 and step % args.step_per_save == 0:
            if i == 0:
                path = os.path.join(args.checkpoint_dir,
                                    f'checkpoint-{epoch}-{step}.pt')
                save_checkpoint(path, step, epoch, model, optimizer, scheduler,
                                amp if args.fp16 else None)
                copyfile(
                    path,
                    os.path.join(args.checkpoint_dir, 'checkpoint_last.pt'))
            if args.dist:
                dist.barrier()

        if args.step_per_valid != 0 and step % args.step_per_valid == 0:
            # Eval on validation data.
            if args.valid is not None:
                best_loss = validate(best_loss)
        return step, epoch, best_loss

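    # Run the model over the validation set and return the mean loss per
    # target token; only process 0 displays a progress bar.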
    def evaluate():
        model.eval()
        total_loss = 0
        total_target = 0
        total = valid_loader.dataset.total_target
        if i == 0:
            progress = tqdm(desc="Evaluating", total=total, unit=' token')
        for batch in valid_loader:
            # TODO debug
            f = batch['feature'].data.numpy()
            t = batch['target'].data.numpy()
            n = batch['num_target']
            vocab = corpus.vocab
            # TODO print out data to test
            # feat = np.transpose(f)
            # for data in feat:
            #     print(vocab.to_text(data))
            # continue
            # TODO test dataloading

            num_target = sum(batch['num_target'])
            total_target += num_target

            feature = batch['feature'].to(args.device)
            target = batch['target'].to(args.device)
            loss = model(feature, target)

            total_loss += loss.item()
            if i == 0:
                progress.update(num_target)
        if i == 0:
            progress.close()
        return total_loss / total_target

    def validate(best_loss):
        with torch.no_grad():
            loss = evaluate()
        log.valid(loss, step, epoch)
        if i == 0 and best_loss > loss:
            best_loss = loss
            best_path = os.path.join(args.checkpoint_dir, 'checkpoint_best.pt')
            save_checkpoint(best_path, step, epoch, model, optimizer,
                            scheduler, amp if args.fp16 else None)
        if args.dist:
            dist.barrier()

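        # The loss is natural-log cross-entropy per target token: dividing by
        # ln(2) converts it to bits, and 2**bits is the perplexity.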
        log.valid_add_scalar('best loss', best_loss / math.log(2), step)
        log.valid_add_scalar('best ppl', 2**(best_loss / math.log(2)), step)
        return best_loss

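    # Main loop: train epoch by epoch until max_step or max_epoch is reached,
    # validating, checkpointing, and pruning old checkpoints as configured.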
    step = resume_step
    best_loss = math.inf
    # Start from epoch 1, or resume from the epoch after the checkpointed one.
    for epoch in itertools.count(resume_epoch + 1):
        # Train on training data.
        step, loss, best_loss = train(step, epoch, best_loss)
        if args.max_step is not None and step >= args.max_step:
            break

        if args.epoch_per_valid != 0 and epoch % args.epoch_per_valid == 0:
            # Eval on validation data.
            if args.valid is not None:
                if args.dist:
                    dist.barrier()
                best_loss = validate(best_loss)

        # Saving checkpoint.
        if args.epoch_per_save != 0 and epoch % args.epoch_per_save == 0:
            if i == 0:
                path = os.path.join(args.checkpoint_dir,
                                    f'checkpoint-{epoch}-{step}.pt')
                save_checkpoint(path, step, epoch, model, optimizer, scheduler,
                                amp if args.fp16 else None)
                copyfile(
                    path,
                    os.path.join(args.checkpoint_dir, 'checkpoint_last.pt'))
            if args.dist:
                dist.barrier()

        # Delete old checkpoints.
        if i == 0 and (args.keep_step is not None
                       or args.keep_epoch is not None):
            for filename in os.listdir(args.checkpoint_dir):
                if re.match(r'checkpoint-\d+-\d+\.pt', filename):
                    file_epoch, file_step = re.split(r'[-.]', filename)[1:3]
                    if args.keep_step is not None and int(file_step) <= (
                            step - args.keep_step):
                        os.remove(os.path.join(args.checkpoint_dir, filename))
                    if args.keep_epoch is not None and int(file_epoch) <= (
                            epoch - args.keep_epoch):
                        os.remove(os.path.join(args.checkpoint_dir, filename))
        if args.dist:
            dist.barrier()

        if args.max_epoch is not None and epoch >= args.max_epoch:
            break