コード例 #1
0
ファイル: train.py プロジェクト: mbencherif/neural_sp
def main():
    """Train a language model end to end from the command-line arguments.

    The function parses ``sys.argv``, builds the train/dev/eval datasets,
    the model, the optimizer and its learning-rate scheduler, optionally
    restores a checkpoint given by ``args.resume``, and then runs the epoch
    loop with per-epoch checkpointing, perplexity evaluation, early stopping
    and the optional switch to SGD.

    Returns:
        ``args.save_path``, the directory that holds the log, the saved
        configuration and the checkpoints of this run.
    """

    args = parse_args_train(sys.argv[1:])

    # Seed both the CPU and all CUDA generators
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Load a conf file
    if args.resume:
        # The configuration is expected next to the checkpoint file
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            # Saved values override the command line, except for the
            # checkpoint path itself
            if k != 'resume':
                setattr(args, k, v)

    # for multi-GPUs
    # The batch is scaled up by the number of GPUs and the number of gradient
    # accumulation steps is scaled down accordingly (never below 1)
    if args.n_gpus > 1:
        batch_size = args.batch_size * args.n_gpus
        accum_grad_n_steps = max(1, args.accum_grad_n_steps // args.n_gpus)
    else:
        batch_size = args.batch_size
        accum_grad_n_steps = args.accum_grad_n_steps

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        dict_path=args.dict,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        wp_model=args.wp_model,
                        batch_size=batch_size,
                        n_epochs=args.n_epochs,
                        min_n_tokens=args.min_n_tokens,
                        bptt=args.bptt,
                        shuffle=args.shuffle,
                        backward=args.backward,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      dict_path=args.dict,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      wp_model=args.wp_model,
                      batch_size=batch_size,
                      bptt=args.bptt,
                      backward=args.backward,
                      serialize=args.serialize)
    # One dataset per evaluation file, always built with batch size 1
    eval_sets = [
        Dataset(corpus=args.corpus,
                tsv_path=s,
                dict_path=args.dict,
                nlsyms=args.nlsyms,
                unit=args.unit,
                wp_model=args.wp_model,
                batch_size=1,
                bptt=args.bptt,
                backward=args.backward,
                serialize=args.serialize) for s in args.eval_sets
    ]

    # The vocabulary size is taken from the training set
    args.vocab = train_set.vocab

    # Set save path
    if args.resume:
        # Keep writing into the directory of the resumed checkpoint
        args.save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(args.save_path)
    else:
        dir_name = set_lm_name(args)
        # <model_save_dir>/<train file name without its last extension,
        # remaining dots replaced by '_'>/<dir_name>
        args.save_path = mkdir_join(
            args.model_save_dir,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        args.save_path = set_save_path(args.save_path)  # avoid overwriting

    # Set logger
    set_logger(os.path.join(args.save_path, 'train.log'), stdout=args.stdout)

    # Model setting
    model = build_lm(args, args.save_path)

    if not args.resume:
        # Save nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(args.save_path,
                                                  'nlsyms.txt'))
        shutil.copy(args.dict, os.path.join(args.save_path, 'dict.txt'))
        if args.unit == 'wp':
            shutil.copy(args.wp_model, os.path.join(args.save_path,
                                                    'wp.model'))

        # Log every option in alphabetical order
        for k, v in sorted(args.items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info('torch version: %s' % str(torch.__version__))
        logger.info(model)

    # Set optimizer
    # The epoch to resume from is the integer after the last '-' of the
    # checkpoint path; SGD is selected when that epoch is already past
    # convert_to_sgd_epoch
    resume_epoch = int(args.resume.split('-')[-1]) if args.resume else 0
    optimizer = set_optimizer(
        model,
        'sgd' if resume_epoch > args.convert_to_sgd_epoch else args.optimizer,
        args.lr, args.weight_decay)

    # Wrap optimizer by learning rate scheduler
    is_transformer = args.lm_type in ['transformer', 'transformer_xl']
    scheduler = LRScheduler(
        optimizer,
        args.lr,
        decay_type=args.lr_decay_type,
        decay_start_epoch=args.lr_decay_start_epoch,
        decay_rate=args.lr_decay_rate,
        decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
        early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
        warmup_start_lr=args.warmup_start_lr,
        warmup_n_steps=args.warmup_n_steps,
        model_size=args.get('transformer_d_model', 0),
        factor=args.lr_factor,
        noam=args.optimizer == 'noam',
        save_checkpoints_topk=10 if is_transformer else 1)

    if args.resume:
        # Restore the last saved model
        load_checkpoint(args.resume, model, scheduler)

        # Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch
        if resume_epoch == args.convert_to_sgd_epoch:
            scheduler.convert_to_sgd(model,
                                     args.lr,
                                     args.weight_decay,
                                     decay_type='always',
                                     decay_rate=0.5)

    # GPU setting
    # NOTE: despite its name, use_apex only means that train_dtype is one of
    # the opt levels O0-O3; with torch >= 1.6.0 a GradScaler is created
    # instead of importing apex (see below)
    args.use_apex = args.train_dtype in ["O0", "O1", "O2", "O3"]
    amp, scaler = None, None
    if args.n_gpus >= 1:
        model.cudnn_setting(
            deterministic=not (is_transformer or args.cudnn_benchmark),
            benchmark=not is_transformer and args.cudnn_benchmark)

        # Mixed precision training setting
        if args.use_apex:
            if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
                scaler = torch.cuda.amp.GradScaler()
            else:
                from apex import amp
                model, scheduler.optimizer = amp.initialize(
                    model, scheduler.optimizer, opt_level=args.train_dtype)
                amp.init()
                if args.resume:
                    load_checkpoint(args.resume, amp=amp)
        model.cuda()
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus)))
    else:
        model = CPUWrapperLM(model)

    # Set process name
    logger.info('PID: %s' % os.getpid())
    # NOTE: os.uname()[1] is the node (host) name, although the label in the
    # log line says USERNAME
    logger.info('USERNAME: %s' % os.uname()[1])
    logger.info('#GPU: %d' % torch.cuda.device_count())
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(args, model)
    if args.resume:
        # Convert the scheduler's step count back to a per-batch step count
        n_steps = scheduler.n_steps * accum_grad_n_steps
        reporter.resume(n_steps, resume_epoch)

    # Save conf file as a yaml file
    if not args.resume:
        save_config(args, os.path.join(args.save_path, 'conf.yml'))
        # NOTE: save after reporter for wandb ID

    hidden = None
    start_time_train = time.time()
    for ep in range(resume_epoch, args.n_epochs):
        # NOTE(review): ys_train and is_new_epoch are not used in the loop
        # body; train() receives train_set itself - confirm this is intended
        for ys_train, is_new_epoch in train_set:
            hidden = train(model, train_set, dev_set, scheduler, reporter,
                           logger, args, accum_grad_n_steps, amp, scaler,
                           hidden)

        # Save checkpoint and validate model per epoch
        if reporter.n_epochs + 1 < args.eval_start_epoch:
            # Before eval_start_epoch: no evaluation, always save
            scheduler.epoch()  # lr decay
            reporter.epoch()  # plot

            # Save model
            scheduler.save_checkpoint(model,
                                      args.save_path,
                                      remove_old=not is_transformer
                                      and args.remove_old_checkpoints,
                                      amp=amp)
        else:
            start_time_eval = time.time()
            # dev
            model.module.reset_length(args.bptt)
            ppl_dev, _ = eval_ppl([model.module],
                                  dev_set,
                                  batch_size=1,
                                  bptt=args.bptt)
            model.module.reset_length(args.bptt)
            scheduler.epoch(ppl_dev)  # lr decay
            reporter.epoch(ppl_dev, name='perplexity')  # plot
            reporter.add_scalar('dev/perplexity', ppl_dev)
            logger.info('PPL (%s, ep:%d): %.2f' %
                        (dev_set.set, reporter.n_epochs, ppl_dev))

            # Checkpoint and test only when the scheduler flags this epoch
            # as top-k, or unconditionally for transformer models
            if scheduler.is_topk or is_transformer:
                # Save model
                scheduler.save_checkpoint(model,
                                          args.save_path,
                                          remove_old=not is_transformer
                                          and args.remove_old_checkpoints,
                                          amp=amp)

                # test
                ppl_test_avg = 0.
                for eval_set in eval_sets:
                    model.module.reset_length(args.bptt)
                    ppl_test, _ = eval_ppl([model.module],
                                           eval_set,
                                           batch_size=1,
                                           bptt=args.bptt)
                    model.module.reset_length(args.bptt)
                    logger.info('PPL (%s, ep:%d): %.2f' %
                                (eval_set.set, reporter.n_epochs, ppl_test))
                    ppl_test_avg += ppl_test
                if len(eval_sets) > 0:
                    logger.info(
                        'PPL (avg., ep:%d): %.2f' %
                        (reporter.n_epochs, ppl_test_avg / len(eval_sets)))

            logger.info('Evaluation time: %.2f min' %
                        ((time.time() - start_time_eval) / 60))

            # Early stopping
            if scheduler.is_early_stop:
                break

            # Convert to fine-tuning stage
            if reporter.n_epochs == args.convert_to_sgd_epoch:
                scheduler.convert_to_sgd(model,
                                         args.lr,
                                         args.weight_decay,
                                         decay_type='always',
                                         decay_rate=0.5)

        # Stop once the reporter has counted the requested number of epochs
        if reporter.n_epochs >= args.n_epochs:
            break

    logger.info('Total time: %.2f hour' %
                ((time.time() - start_time_train) / 3600))
    reporter.close()

    return args.save_path
コード例 #2
0
def main():
    """Train a language model end to end from the command-line arguments.

    The function parses ``sys.argv``, builds the train/dev/eval datasets,
    the model and the optimizer (wrapped by ``LRScheduler``), optionally
    restores the checkpoint given by ``args.resume``, and then runs the
    training loop. At every epoch boundary it decays the learning rate,
    saves a checkpoint, evaluates perplexity on the dev/eval sets (from
    ``args.eval_start_epoch`` on), and handles early stopping and the
    switch to SGD at ``args.convert_to_sgd_epoch``.

    Returns:
        The directory that holds the log, the saved configuration and the
        checkpoints of this run.
    """

    args = parse_args_train(sys.argv[1:])

    # Load a conf file
    if args.resume:
        # The configuration is expected next to the checkpoint file
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            # Saved values override the command line, except for the
            # checkpoint path itself
            if k != 'resume':
                setattr(args, k, v)

    # Load dataset
    # The batch is scaled up by the number of GPUs
    batch_size = args.batch_size * args.n_gpus if args.n_gpus >= 1 else args.batch_size
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        dict_path=args.dict,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        wp_model=args.wp_model,
                        batch_size=batch_size,
                        n_epochs=args.n_epochs,
                        min_n_tokens=args.min_n_tokens,
                        bptt=args.bptt,
                        shuffle=args.shuffle,
                        backward=args.backward,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      dict_path=args.dict,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      wp_model=args.wp_model,
                      batch_size=batch_size,
                      bptt=args.bptt,
                      backward=args.backward,
                      serialize=args.serialize)
    # One dataset per evaluation file, always built with batch size 1
    eval_sets = [
        Dataset(corpus=args.corpus,
                tsv_path=s,
                dict_path=args.dict,
                nlsyms=args.nlsyms,
                unit=args.unit,
                wp_model=args.wp_model,
                batch_size=1,
                bptt=args.bptt,
                backward=args.backward,
                serialize=args.serialize) for s in args.eval_sets
    ]

    # The vocabulary size is taken from the training set
    args.vocab = train_set.vocab

    # Set save path
    if args.resume:
        # Keep writing into the directory of the resumed checkpoint
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = set_lm_name(args)
        save_path = mkdir_join(
            args.model_save_dir,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    set_logger(os.path.join(save_path, 'train.log'), stdout=args.stdout)

    # Model setting
    model = build_lm(args, save_path)

    if not args.resume:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(save_path, 'conf.yml'))

        # Save the nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(save_path, 'nlsyms.txt'))
        shutil.copy(args.dict, os.path.join(save_path, 'dict.txt'))
        if args.unit == 'wp':
            shutil.copy(args.wp_model, os.path.join(save_path, 'wp.model'))

        # Log every option in alphabetical order
        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info(model)

    # Set optimizer
    # The epoch to resume from is the integer after the last '-' of the
    # checkpoint path. It has to be stored in resume_epoch because the
    # convert_to_sgd check right after load_checkpoint() compares against it.
    resume_epoch = int(args.resume.split('-')[-1]) if args.resume else 0
    if args.resume and resume_epoch > args.convert_to_sgd_epoch:
        # Already past the conversion epoch: resume with SGD
        optimizer_name = 'sgd'
    else:
        optimizer_name = args.optimizer
    optimizer = set_optimizer(model, optimizer_name, args.lr,
                              args.weight_decay)

    # Wrap optimizer by learning rate scheduler
    is_transformer = args.lm_type in ['transformer', 'transformer_xl']
    optimizer = LRScheduler(
        optimizer,
        args.lr,
        decay_type=args.lr_decay_type,
        decay_start_epoch=args.lr_decay_start_epoch,
        decay_rate=args.lr_decay_rate,
        decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
        early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
        warmup_start_lr=args.warmup_start_lr,
        warmup_n_steps=args.warmup_n_steps,
        model_size=getattr(args, 'transformer_d_model', 0),
        factor=args.lr_factor,
        noam=is_transformer,
        save_checkpoints_topk=1)

    if args.resume:
        # Restore the last saved model
        load_checkpoint(args.resume, model, optimizer)

        # Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch
        if resume_epoch == args.convert_to_sgd_epoch:
            optimizer.convert_to_sgd(model,
                                     args.lr,
                                     args.weight_decay,
                                     decay_type='always',
                                     decay_rate=0.5)

    # GPU setting
    use_apex = args.train_dtype in ["O0", "O1", "O2", "O3"]
    amp = None
    if args.n_gpus >= 1:
        model.cudnn_setting(
            deterministic=not (is_transformer or args.cudnn_benchmark),
            benchmark=args.cudnn_benchmark)
        model.cuda()

        # Mixed precision training setting
        if use_apex:
            from apex import amp
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype)
            amp.init()
            if args.resume:
                load_checkpoint(args.resume, amp=amp)
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus)))
    else:
        model = CPUWrapperLM(model)

    # Set process name
    logger.info('PID: %s' % os.getpid())
    # NOTE: os.uname()[1] is the node (host) name, although the label in the
    # log line says USERNAME
    logger.info('USERNAME: %s' % os.uname()[1])
    logger.info('#GPU: %d' % torch.cuda.device_count())
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(save_path)

    hidden = None
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    pbar_epoch = tqdm(total=len(train_set))
    accum_n_steps = 0
    # Convert the scheduler's step count back to a per-batch step count
    n_steps = optimizer.n_steps * args.accum_grad_n_steps
    while True:
        # Compute loss in the training set
        ys_train, is_new_epoch = train_set.next()
        accum_n_steps += 1

        loss, hidden, observation = model(ys_train, hidden)
        reporter.add(observation)
        if use_apex:
            with amp.scale_loss(loss, optimizer.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        # NOTE: no loss.detach() here. detach() returns a new tensor, so
        # calling it without using the result has no effect; the carried
        # state is presumably cut from the graph by repackage_state() below.
        # Update the parameters once enough gradients are accumulated
        if args.accum_grad_n_steps == 1 or accum_n_steps >= args.accum_grad_n_steps:
            if args.clip_grad_norm > 0:
                total_norm = torch.nn.utils.clip_grad_norm_(
                    model.module.parameters(), args.clip_grad_norm)
                reporter.add_tensorboard_scalar('total_norm', total_norm)
            optimizer.step()
            optimizer.zero_grad()
            accum_n_steps = 0
        loss_train = loss.item()
        del loss
        hidden = model.module.repackage_state(hidden)
        reporter.add_tensorboard_scalar('learning_rate', optimizer.lr)
        # NOTE: loss/acc/ppl are already added in the model
        reporter.step()
        pbar_epoch.update(ys_train.shape[0] * (ys_train.shape[1] - 1))
        n_steps += 1
        # NOTE: n_steps is different from the step counter in Noam Optimizer

        if n_steps % args.print_step == 0:
            # Compute loss in the dev set
            ys_dev = dev_set.next(bptt=args.bptt)[0]
            loss, _, observation = model(ys_dev, None, is_eval=True)
            reporter.add(observation, is_eval=True)
            loss_dev = loss.item()
            del loss
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            logger.info(
                "step:%d(ep:%.2f) loss:%.3f(%.3f)/lr:%.5f/bs:%d (%.2f min)" %
                (n_steps, optimizer.n_epochs + train_set.epoch_detail,
                 loss_train, loss_dev, optimizer.lr, ys_train.shape[0],
                 duration_step / 60))
            start_time_step = time.time()

        # Save figures of loss and accuracy
        if n_steps % (args.print_step * 10) == 0:
            reporter.snapshot()
            model.module.plot_attention()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (optimizer.n_epochs + 1, duration_epoch / 60))

            if optimizer.n_epochs + 1 < args.eval_start_epoch:
                # Before eval_start_epoch: no evaluation, always save
                optimizer.epoch()  # lr decay
                reporter.epoch()  # plot

                # Save the model
                optimizer.save_checkpoint(model,
                                          save_path,
                                          remove_old=not is_transformer,
                                          amp=amp)
            else:
                start_time_eval = time.time()
                # dev
                model.module.reset_length(args.bptt)
                ppl_dev, _ = eval_ppl([model.module],
                                      dev_set,
                                      batch_size=1,
                                      bptt=args.bptt)
                model.module.reset_length(args.bptt)
                optimizer.epoch(ppl_dev)  # lr decay
                reporter.epoch(ppl_dev, name='perplexity')  # plot
                logger.info('PPL (%s, ep:%d): %.2f' %
                            (dev_set.set, optimizer.n_epochs, ppl_dev))

                # Checkpoint and test only when the scheduler flags this
                # epoch as top-k, or unconditionally for transformer models
                if optimizer.is_topk or is_transformer:
                    # Save the model
                    optimizer.save_checkpoint(model,
                                              save_path,
                                              remove_old=not is_transformer,
                                              amp=amp)

                    # test
                    ppl_test_avg = 0.
                    for eval_set in eval_sets:
                        model.module.reset_length(args.bptt)
                        ppl_test, _ = eval_ppl([model.module],
                                               eval_set,
                                               batch_size=1,
                                               bptt=args.bptt)
                        model.module.reset_length(args.bptt)
                        logger.info(
                            'PPL (%s, ep:%d): %.2f' %
                            (eval_set.set, optimizer.n_epochs, ppl_test))
                        ppl_test_avg += ppl_test
                    if len(eval_sets) > 0:
                        logger.info('PPL (avg., ep:%d): %.2f' %
                                    (optimizer.n_epochs,
                                     ppl_test_avg / len(eval_sets)))

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if optimizer.is_early_stop:
                    break

                # Convert to fine-tuning stage
                if optimizer.n_epochs == args.convert_to_sgd_epoch:
                    optimizer.convert_to_sgd(model,
                                             args.lr,
                                             args.weight_decay,
                                             decay_type='always',
                                             decay_rate=0.5)

            # Close the finished bar before opening the one for the next
            # epoch so that progress bars do not pile up
            pbar_epoch.close()
            pbar_epoch = tqdm(total=len(train_set))

            if optimizer.n_epochs >= args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    reporter.tf_writer.close()
    pbar_epoch.close()

    return save_path
コード例 #3
0
                (reporter.n_epochs + 1, (time.time() - start_time_epoch) / 60))
    if args.local_rank == 0:
        pbar_epoch.close()


def spmd_main(args):
    """Run ``main(args)`` between process-group setup and teardown.

    The rendezvous variables are read from the environment only to be
    printed; a missing variable raises ``KeyError``. The process group is
    initialized with ``args.dist_backend`` and destroyed after ``main``
    returns.

    Args:
        args: parsed options; ``dist_backend`` is read here and the whole
            object is forwarded to ``main``.
    """
    # Parameters used to initialize the process group
    rendezvous = {}
    for name in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE"):
        rendezvous[name] = os.environ[name]
    pid = os.getpid()
    print(f"[{pid}] Initializing process group with: {rendezvous}")

    dist.init_process_group(backend=args.dist_backend)
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    backend = dist.get_backend()
    print(f"[{pid}] world_size = {world_size}, "
          f"rank = {rank}, backend={backend}")

    main(args)

    # Tear down the process group
    dist.destroy_process_group()


if __name__ == '__main__':
    # Script entry point: parse the options, then pick the launcher
    args = parse_args_train(sys.argv[1:])

    # Distributed mode needs several GPUs and several local workers
    args.distributed = args.n_gpus > 1 and args.local_world_size > 1

    # spmd_main wraps main with process-group setup and teardown
    (spmd_main if args.distributed else main)(args)