Code Example #1
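Pretraining loop with a joint generator/discriminator loss (ELECTRA-style): apex fp16, FusedLAMB, gradient accumulation, TensorBoard logging, and periodic checkpointing.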
def train(args, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_dir = os.path.join("tensorboard", args.model_name)
        os.makedirs(tb_dir, exist_ok=True)
        tb_writer = SummaryWriter(tb_dir)

    # Split the configured batch size into per-step micro-batches
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    # Prepare optimizer and schedule (linear warmup and decay)
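    # Bias and normalization parameters are conventionally exempt from weight decay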
    no_decay = ['bias', 'LayerNorm', 'gamma', 'beta']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
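    # FusedLAMB is apex's fused LAMB optimizer (suited to large-batch pretraining);
    # WarmupLinearSchedule ramps the LR up for warmup_steps, then decays it linearly.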
    optimizer = FusedLAMB(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer=optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=args.num_steps)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(models=model,
                                          optimizers=optimizer,
                                          opt_level=args.fp16_opt_level,
                                          cast_model_outputs=torch.float16)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = DDP(
            model,
            message_size=250000000,
            gradient_predivide_factor=torch.distributed.get_world_size())

    train_dataset = LMDataset(corpus_path=args.corpus_path,
                              tokenizer=tokenizer,
                              local_rank=args.local_rank,
                              seq_len=args.max_seq_length,
                              vocab_size=args.vocab_size,
                              mask_prob=args.mask_prob)
    # Train!
    logger.info("***** Running training *****")
    logger.info("  Total optimization steps = %d", args.num_steps)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    global_step = 0
    iters = 0
    model.zero_grad()
    model.train()
    set_seed(args)  # for reproducibility (even between Python 2 and 3)
    while True:
        train_dataset.gen_segment()
        train_sampler = (RandomSampler(train_dataset) if args.local_rank == -1
                         else DistributedSampler(train_dataset))
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size,
                                      num_workers=4,
                                      pin_memory=True)
        epoch_iterator = tqdm(
            train_dataloader,
            desc=("Training (X iter) (XX / XX Steps) (Total Loss=X.X) "
                  "(Generator Loss=X.X) (Discriminator Loss=X.X)"),
            disable=args.local_rank not in [-1, 0])
        tr_loss = 0.0
        for step, batch in enumerate(epoch_iterator):
            batch = tuple(t.to(args.device) for t in batch)
            input_ids, input_mask, segment_ids, lm_label_ids = batch
            gen_loss, disc_loss = model(input_ids, segment_ids, input_mask,
                                        lm_label_ids)

            loss = gen_loss + disc_loss
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            mean_loss = tr_loss * args.gradient_accumulation_steps / (step + 1)

            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()  # advance the warmup/decay schedule after the optimizer step
                # Setting grads to None is cheaper than zeroing them
                for param in model.parameters():
                    param.grad = None
                global_step += 1

                epoch_iterator.set_description(
                    "Training (%d iter) (%d / %d Steps) (Mean Loss=%2.5f) (Generator Loss=%2.5f) (Discriminator Loss=%2.5f)"
                    % (iters, global_step, args.num_steps, mean_loss, gen_loss,
                       disc_loss / 50.0))  # presumably undoing the discriminator loss weight (50) for display

                if (args.local_rank in [-1, 0] and args.logging_steps > 0
                        and global_step % args.logging_steps == 0):
                    tb_writer.add_scalar('lr',
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('Mean_Loss', mean_loss, global_step)
                    tb_writer.add_scalar('Gen_Loss', gen_loss, global_step)
                    tb_writer.add_scalar('Disc_Loss', disc_loss / 50.0,
                                         global_step)

                if (args.local_rank in [-1, 0] and args.save_steps > 0
                        and global_step % args.save_steps == 0):
                    model_to_save = model.module if hasattr(
                        model, 'module') else model
                    model_checkpoint = os.path.join(
                        args.output_dir,
                        args.model_name + '_' + str(global_step) + '.bin')
                    model_layer_checkpoint = os.path.join(
                        args.output_dir,
                        args.model_name + '_' + str(global_step) + '_disc.bin')
                    torch.save(model_to_save.state_dict(), model_checkpoint)
                    torch.save(model_to_save.discriminator.model.state_dict(),
                               model_layer_checkpoint)
                    logger.info("Saving model checkpoint to %s",
                                args.output_dir)
            if args.num_steps > 0 and global_step == args.num_steps:
                epoch_iterator.close()
                break
        if args.num_steps > 0 and global_step == args.num_steps:
            epoch_iterator.close()
            break
        iters += 1
    if args.local_rank in [-1, 0]:
        model_to_save = model.module if hasattr(model, 'module') else model
        model_checkpoint = os.path.join(
            args.output_dir, args.model_name + '_' + str(global_step) + '.bin')
        model_layer_checkpoint = os.path.join(
            args.output_dir,
            args.model_name + '_' + str(global_step) + '_disc.bin')
        torch.save(model_to_save.state_dict(), model_checkpoint)
        torch.save(model_to_save.discriminator.model.state_dict(),
                   model_layer_checkpoint)
        logger.info("Saving model checkpoint to %s", args.output_dir)
        logger.info("End Training!")
        tb_writer.close()
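
For reference, the accumulation-and-step pattern this loop is built around can be
reduced to a few lines. A minimal plain-PyTorch sketch, with no apex and no
distributed training (all names below are illustrative, not from the example):

import torch

model = torch.nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
accum_steps = 4

# Synthetic micro-batches standing in for the DataLoader
micro_batches = [(torch.randn(8, 16), torch.randn(8, 1)) for _ in range(16)]

for step, (x, y) in enumerate(micro_batches):
    loss = torch.nn.functional.mse_loss(model(x), y)
    (loss / accum_steps).backward()  # scale so accumulated grads average out
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        for p in model.parameters():
            p.grad = None  # cheaper than optimizer.zero_grad()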
Code Example #2
File: train_pu.py, Project: sidnarayanan/grapple
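Excerpt from a training step: accumulated loss with gradient clipping, apex loss scaling when an attention band is configured, and post-processing that replaces charged-particle scores with their known labels.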
            loss, _ = metrics.compute(yhat,
                                      y,
                                      w=weight,
                                      m=loss_mask,
                                      plot_m=qm,
                                      x=x)
            loss /= config.grad_acc
            if config.attention_band is not None:
                with amp.scale_loss(loss, opt) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if (n_batch + 1) % config.grad_acc == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
                opt.step()
                opt.zero_grad()
                ready_for_lr = True

            metrics_puppi.compute(p, y, w=weight, m=loss_mask, plot_m=qm)

            avg_loss_tensor += loss

            if config.beta:
                # Two-channel ("beta") head: convert (p, q) into a probability
                p, q = yhat[:, :, 0], yhat[:, :, 1]
                # logger.info(' '.join([str(x) for x in [p.max(), p.min(), q.max(), q.min()]]))
                yhat = p / (p + q + 1e-5)

            # Clamp scores to [0, 1]; charged particles have known labels, so copy those through
            score = t2n(torch.clamp(yhat.squeeze(-1), 0, 1))
            charged_mask = ~batch['neutral_mask']
            score[charged_mask] = batch['y'][charged_mask]
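
The final masking step is worth calling out: the model's score is kept only where
the label is unknown, and known labels win elsewhere. A tiny sketch of the same
idea with synthetic tensors (illustrative values, not from the project):

import torch

scores = torch.rand(6)                           # model output, clamped to [0, 1]
labels = torch.tensor([1., 0., 1., 0., 1., 0.])  # known per-particle truth
neutral_mask = torch.tensor([True, False, True, True, False, True])

charged_mask = ~neutral_mask
scores[charged_mask] = labels[charged_mask]      # ground truth overrides the model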
Code Example #3
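End-to-end RNN-T training entry point with MLPerf-style logging: distributed setup, DALI data loaders with optional bucketing, FusedLAMB plus a warmup/hold/decay LR policy, apex amp, an EMA copy of the model for evaluation, and WER-based checkpointing.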
def main():
    logging.configure_logger('RNNT')
    logging.log_start(logging.constants.INIT_START)

    args = parse_args()

    assert torch.cuda.is_available()
    assert args.prediction_frequency is None or args.prediction_frequency % args.log_frequency == 0

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    # set up distributed training
    multi_gpu = int(os.environ.get('WORLD_SIZE', 1)) > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')
        world_size = dist.get_world_size()
        print_once(f'Distributed training with {world_size} GPUs\n')
    else:
        world_size = 1

    if args.seed is not None:
        logging.log_event(logging.constants.SEED, value=args.seed)
        torch.manual_seed(args.seed + args.local_rank)
        np.random.seed(args.seed + args.local_rank)
        random.seed(args.seed + args.local_rank)
        # np_rng is used for bucket generation and needs the same seed on every worker
        np_rng = np.random.default_rng(seed=args.seed)
    else:
        np_rng = np.random.default_rng()  # avoid a NameError when bucketing without a seed

    init_log(args)

    cfg = config.load(args.model_config)
    config.apply_duration_flags(cfg, args.max_duration)

    assert args.grad_accumulation_steps >= 1
    assert args.batch_size % args.grad_accumulation_steps == 0, f'{args.batch_size} % {args.grad_accumulation_steps} != 0'
    logging.log_event(logging.constants.GRADIENT_ACCUMULATION_STEPS, value=args.grad_accumulation_steps)
    batch_size = args.batch_size // args.grad_accumulation_steps

    logging.log_event(logging.constants.SUBMISSION_BENCHMARK, value=logging.constants.RNNT)
    logging.log_event(logging.constants.SUBMISSION_ORG, value='my-organization')
    logging.log_event(logging.constants.SUBMISSION_DIVISION, value=logging.constants.CLOSED) # closed or open
    logging.log_event(logging.constants.SUBMISSION_STATUS, value=logging.constants.ONPREM) # on-prem/cloud/research
    logging.log_event(logging.constants.SUBMISSION_PLATFORM, value='my platform')

    logging.log_end(logging.constants.INIT_STOP)
    if multi_gpu:
        torch.distributed.barrier()
    logging.log_start(logging.constants.RUN_START)
    if multi_gpu:
        torch.distributed.barrier()

    print_once('Setting up datasets...')
    (
        train_dataset_kw,
        train_features_kw,
        train_splicing_kw,
        train_specaugm_kw,
    ) = config.input(cfg, 'train')
    (
        val_dataset_kw,
        val_features_kw,
        val_splicing_kw,
        val_specaugm_kw,
    ) = config.input(cfg, 'val')

    logging.log_event(logging.constants.DATA_TRAIN_MAX_DURATION,
                      value=train_dataset_kw['max_duration'])
    logging.log_event(logging.constants.DATA_SPEED_PERTURBATON_MAX,
                      value=train_dataset_kw['speed_perturbation']['max_rate'])
    logging.log_event(logging.constants.DATA_SPEED_PERTURBATON_MIN,
                      value=train_dataset_kw['speed_perturbation']['min_rate'])
    logging.log_event(logging.constants.DATA_SPEC_AUGMENT_FREQ_N,
                      value=train_specaugm_kw['freq_masks'])
    logging.log_event(logging.constants.DATA_SPEC_AUGMENT_FREQ_MIN,
                      value=train_specaugm_kw['min_freq'])
    logging.log_event(logging.constants.DATA_SPEC_AUGMENT_FREQ_MAX,
                      value=train_specaugm_kw['max_freq'])
    logging.log_event(logging.constants.DATA_SPEC_AUGMENT_TIME_N,
                      value=train_specaugm_kw['time_masks'])
    logging.log_event(logging.constants.DATA_SPEC_AUGMENT_TIME_MIN,
                      value=train_specaugm_kw['min_time'])
    logging.log_event(logging.constants.DATA_SPEC_AUGMENT_TIME_MAX,
                      value=train_specaugm_kw['max_time'])
    logging.log_event(logging.constants.GLOBAL_BATCH_SIZE,
                      value=batch_size * world_size * args.grad_accumulation_steps)

    tokenizer_kw = config.tokenizer(cfg)
    tokenizer = Tokenizer(**tokenizer_kw)

    class PermuteAudio(torch.nn.Module):
        def forward(self, x):
            return (x[0].permute(2, 0, 1), *x[1:])

    train_augmentations = torch.nn.Sequential(
        features.SpecAugment(optim_level=args.amp, **train_specaugm_kw)
        if train_specaugm_kw else torch.nn.Identity(),
        features.FrameSplicing(optim_level=args.amp, **train_splicing_kw),
        PermuteAudio(),
    )
    val_augmentations = torch.nn.Sequential(
        features.SpecAugment(optim_level=args.amp, **val_specaugm_kw)
        if val_specaugm_kw else torch.nn.Identity(),
        features.FrameSplicing(optim_level=args.amp, **val_splicing_kw),
        PermuteAudio(),
    )

    logging.log_event(logging.constants.DATA_TRAIN_NUM_BUCKETS, value=args.num_buckets)

    if args.num_buckets is not None:
        sampler = dali_sampler.BucketingSampler(
            args.num_buckets,
            batch_size,
            world_size,
            args.epochs,
            np_rng
        )
    else:
        sampler = dali_sampler.SimpleSampler()

    train_loader = DaliDataLoader(gpu_id=args.local_rank,
                                  dataset_path=args.dataset_dir,
                                  config_data=train_dataset_kw,
                                  config_features=train_features_kw,
                                  json_names=args.train_manifests,
                                  batch_size=batch_size,
                                  sampler=sampler,
                                  grad_accumulation_steps=args.grad_accumulation_steps,
                                  pipeline_type="train",
                                  device_type=args.dali_device,
                                  tokenizer=tokenizer)

    val_loader = DaliDataLoader(gpu_id=args.local_rank,
                                dataset_path=args.dataset_dir,
                                config_data=val_dataset_kw,
                                config_features=val_features_kw,
                                json_names=args.val_manifests,
                                batch_size=args.val_batch_size,
                                sampler=dali_sampler.SimpleSampler(),
                                pipeline_type="val",
                                device_type=args.dali_device,
                                tokenizer=tokenizer)

    train_feat_proc = train_augmentations
    val_feat_proc   = val_augmentations

    train_feat_proc.cuda()
    val_feat_proc.cuda()

    steps_per_epoch = len(train_loader) // args.grad_accumulation_steps

    logging.log_event(logging.constants.TRAIN_SAMPLES, value=train_loader.dataset_size)
    logging.log_event(logging.constants.EVAL_SAMPLES, value=val_loader.dataset_size)

    # set up the model
    rnnt_config = config.rnnt(cfg)
    logging.log_event(logging.constants.MODEL_WEIGHTS_INITIALIZATION_SCALE, value=args.weights_init_scale)
    if args.weights_init_scale is not None:
        rnnt_config['weights_init_scale'] = args.weights_init_scale
    if args.hidden_hidden_bias_scale is not None:
        rnnt_config['hidden_hidden_bias_scale'] = args.hidden_hidden_bias_scale
    model = RNNT(n_classes=tokenizer.num_labels + 1, **rnnt_config)
    model.cuda()
    blank_idx = tokenizer.num_labels
    loss_fn = RNNTLoss(blank_idx=blank_idx)
    logging.log_event(logging.constants.EVAL_MAX_PREDICTION_SYMBOLS, value=args.max_symbol_per_sample)
    greedy_decoder = RNNTGreedyDecoder(
        blank_idx=blank_idx,
        max_symbol_per_sample=args.max_symbol_per_sample)

    print_once(f'Model size: {num_weights(model) / 10**6:.1f}M params\n')

    opt_eps = 1e-9
    logging.log_event(logging.constants.OPT_NAME, value='lamb')
    logging.log_event(logging.constants.OPT_BASE_LR, value=args.lr)
    logging.log_event(logging.constants.OPT_LAMB_EPSILON, value=opt_eps)
    logging.log_event(logging.constants.OPT_LAMB_LR_DECAY_POLY_POWER, value=args.lr_exp_gamma)
    logging.log_event(logging.constants.OPT_LR_WARMUP_EPOCHS, value=args.warmup_epochs)
    logging.log_event(logging.constants.OPT_LAMB_LR_HOLD_EPOCHS, value=args.hold_epochs)
    logging.log_event(logging.constants.OPT_LAMB_BETA_1, value=args.beta1)
    logging.log_event(logging.constants.OPT_LAMB_BETA_2, value=args.beta2)
    logging.log_event(logging.constants.OPT_GRADIENT_CLIP_NORM, value=args.clip_norm)
    logging.log_event(logging.constants.OPT_LR_ALT_DECAY_FUNC, value=True)
    logging.log_event(logging.constants.OPT_LR_ALT_WARMUP_FUNC, value=True)
    logging.log_event(logging.constants.OPT_LAMB_LR_MIN, value=args.min_lr)
    logging.log_event(logging.constants.OPT_WEIGHT_DECAY, value=args.weight_decay)

    # optimization
    kw = {'params': model.param_groups(args.lr), 'lr': args.lr,
          'weight_decay': args.weight_decay}

    initial_lrs = [group['lr'] for group in kw['params']]

    print_once(f'Starting with LRs: {initial_lrs}')
    optimizer = FusedLAMB(betas=(args.beta1, args.beta2), eps=opt_eps, **kw)

    adjust_lr = lambda step, epoch: lr_policy(
        step, epoch, initial_lrs, optimizer, steps_per_epoch=steps_per_epoch,
        warmup_epochs=args.warmup_epochs, hold_epochs=args.hold_epochs,
        min_lr=args.min_lr, exp_gamma=args.lr_exp_gamma)

    if args.amp:
        model, optimizer = amp.initialize(
            models=model,
            optimizers=optimizer,
            opt_level='O1',
            max_loss_scale=512.0)

    if args.ema > 0:
        # Keep an exponential-moving-average copy of the weights for evaluation
        ema_model = copy.deepcopy(model).cuda()
    else:
        ema_model = None
    logging.log_event(logging.constants.MODEL_EVAL_EMA_FACTOR, value=args.ema)

    if multi_gpu:
        model = DistributedDataParallel(model)

    # load checkpoint
    meta = {'best_wer': 10**6, 'start_epoch': 0}
    checkpointer = Checkpointer(args.output_dir, 'RNN-T',
                                args.keep_milestones, args.amp)
    if args.resume:
        args.ckpt = checkpointer.last_checkpoint() or args.ckpt

    if args.ckpt is not None:
        checkpointer.load(args.ckpt, model, ema_model, optimizer, meta)

    start_epoch = meta['start_epoch']
    best_wer = meta['best_wer']
    last_wer = meta['best_wer']
    epoch = 1
    step = start_epoch * steps_per_epoch + 1

    # training loop
    model.train()
    for epoch in range(start_epoch + 1, args.epochs + 1):

        logging.log_start(logging.constants.BLOCK_START,
                          metadata=dict(first_epoch_num=epoch,
                                        epoch_count=1))
        logging.log_start(logging.constants.EPOCH_START,
                          metadata=dict(epoch_num=epoch))

        epoch_utts = 0
        accumulated_batches = 0
        epoch_start_time = time.time()

        for batch in train_loader:

            if accumulated_batches == 0:
                adjust_lr(step, epoch)
                optimizer.zero_grad()
                step_utts = 0
                step_start_time = time.time()
                all_feat_lens = []

            audio, audio_lens, txt, txt_lens = batch

            feats, feat_lens = train_feat_proc([audio, audio_lens])
            all_feat_lens += feat_lens

            log_probs, log_prob_lens = model(feats, feat_lens, txt, txt_lens)
            loss = loss_fn(log_probs[:, :log_prob_lens.max().item()],
                           log_prob_lens, txt, txt_lens)

            loss /= args.grad_accumulation_steps

            del log_probs, log_prob_lens

            if torch.isnan(loss).any():
                print_once('WARNING: loss is NaN; skipping update')
            else:
                if args.amp:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                loss_item = loss.item()
                del loss
                step_utts += batch[0].size(0) * world_size
                epoch_utts += batch[0].size(0) * world_size
                accumulated_batches += 1

            # Compare with == (not % ==) so an all-NaN window, where
            # accumulated_batches is still 0, does not trigger an empty step
            if accumulated_batches == args.grad_accumulation_steps:

                if args.clip_norm is not None:
                    # Note: under apex amp, the documented pattern is to clip
                    # amp.master_params(optimizer) rather than the model parameters
                    torch.nn.utils.clip_grad_norm_(
                        getattr(model, 'module', model).parameters(),
                        max_norm=args.clip_norm,
                        norm_type=2)

                total_norm = 0.0

                # If requested, log the global gradient L2 norm (computed post-clip)
                try:
                    if args.log_norm:
                        for p in getattr(model, 'module', model).parameters():
                            param_norm = p.grad.data.norm(2)
                            total_norm += param_norm.item() ** 2
                        total_norm = total_norm ** 0.5
                except AttributeError as e:  # e.g. a parameter whose grad is None
                    print_once(f'Exception happened: {e}')
                    total_norm = 0.0

                optimizer.step()
                apply_ema(model, ema_model, args.ema)

                if step % args.log_frequency == 0:

                    if args.prediction_frequency is None or step % args.prediction_frequency == 0:
                        preds = greedy_decoder.decode(model, feats, feat_lens)
                        wer, pred_utt, ref = greedy_wer(
                                preds,
                                txt,
                                txt_lens,
                                tokenizer.detokenize)
                        print_once(f'  Decoded:   {pred_utt[:90]}')
                        print_once(f'  Reference: {ref[:90]}')
                        wer = {'wer': 100 * wer}
                    else:
                        wer = {}

                    step_time = time.time() - step_start_time

                    log((epoch, step % steps_per_epoch or steps_per_epoch, steps_per_epoch),
                        step, 'train',
                        {'loss': loss_item,
                         **wer,  # optional entry
                         'throughput': step_utts / step_time,
                         'took': step_time,
                         'grad-norm': total_norm,
                         'seq-len-min': min(all_feat_lens).item(),
                         'seq-len-max': max(all_feat_lens).item(),
                         'lrate': optimizer.param_groups[0]['lr']})

                step_start_time = time.time()

                step += 1
                accumulated_batches = 0
                # end of step

        logging.log_end(logging.constants.EPOCH_STOP,
                        metadata=dict(epoch_num=epoch))

        epoch_time = time.time() - epoch_start_time
        log((epoch,), None, 'train_avg', {'throughput': epoch_utts / epoch_time,
                                          'took': epoch_time})

        if epoch % args.val_frequency == 0:
            wer = evaluate(epoch, step, val_loader, val_feat_proc,
                           tokenizer.detokenize, ema_model, loss_fn,
                           greedy_decoder, args.amp)

            last_wer = wer
            if wer < best_wer and epoch >= args.save_best_from:
                checkpointer.save(model, ema_model, optimizer, epoch,
                                  step, best_wer, is_best=True)
                best_wer = wer

        save_this_epoch = (args.save_frequency is not None and epoch % args.save_frequency == 0) \
                       or (epoch in args.keep_milestones)
        if save_this_epoch:
            checkpointer.save(model, ema_model, optimizer, epoch, step, best_wer)

        logging.log_end(logging.constants.BLOCK_STOP, metadata=dict(first_epoch_num=epoch))

        if last_wer <= args.target:
            logging.log_end(logging.constants.RUN_STOP, metadata={'status': 'success'})
            print_once(f'Target WER {args.target} reached after {epoch - start_epoch} epochs.')
            break
        if 0 < args.epochs_this_job <= epoch - start_epoch:
            print_once(f'Finished after {args.epochs_this_job} epochs.')
            break
        # end of epoch

    log((), None, 'train_avg', {'throughput': epoch_utts / epoch_time})

    if last_wer > args.target:
        logging.log_end(logging.constants.RUN_STOP, metadata={'status': 'aborted'})

    if epoch == args.epochs:
        evaluate(epoch, step, val_loader, val_feat_proc, tokenizer.detokenize,
                 ema_model, loss_fn, greedy_decoder, args.amp)

    flush_log()
    if args.save_at_the_end:
        checkpointer.save(model, ema_model, optimizer, epoch, step, best_wer)
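
The lr_policy function itself is not part of this listing. Judging from its
arguments (warmup_epochs, hold_epochs, min_lr, exp_gamma), a warmup-hold-decay
schedule consistent with them could look like the sketch below; this is an
assumption, not the reference implementation:

def lr_policy(step, epoch, initial_lrs, optimizer, steps_per_epoch,
              warmup_epochs, hold_epochs, min_lr, exp_gamma):
    warmup_steps = warmup_epochs * steps_per_epoch
    hold_steps = hold_epochs * steps_per_epoch
    if step < warmup_steps:
        scale = (step + 1) / warmup_steps        # linear warmup
    elif step < warmup_steps + hold_steps:
        scale = 1.0                              # hold at the base LR
    else:
        decay_epochs = max(epoch - warmup_epochs - hold_epochs, 0)
        scale = exp_gamma ** decay_epochs        # exponential decay per epoch
    for group, base_lr in zip(optimizer.param_groups, initial_lrs):
        group['lr'] = max(base_lr * scale, min_lr)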