Example #1
def main():

    args = parse()
    args_pt = copy.deepcopy(args)
    args_teacher = copy.deepcopy(args)

    # Load a conf file
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)
    recog_params = vars(args)

    # Automatically reduce batch size in multi-GPU setting
    if args.n_gpus > 1:
        args.batch_size -= 10
        args.print_step //= args.n_gpus

    # Compute subsampling factor
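    # NOTE: subsample_factor is the total frame-rate reduction of the encoder
    # output relative to the input features (CNN pooling or per-layer
    # subsampling); the *_sub1/_sub2 variants give the reduction at the
    # intermediate encoder outputs used by the auxiliary sub-tasks.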
    subsample_factor = 1
    subsample_factor_sub1 = 1
    subsample_factor_sub2 = 1
    subsample = [int(s) for s in args.subsample.split('_')]
    if args.conv_poolings and 'conv' in args.enc_type:
        for p in args.conv_poolings.split('_'):
            subsample_factor *= int(p.split(',')[0].replace('(', ''))
    else:
        subsample_factor = np.prod(subsample)
    if args.train_set_sub1:
        if args.conv_poolings and 'conv' in args.enc_type:
            subsample_factor_sub1 = subsample_factor * np.prod(
                subsample[:args.enc_n_layers_sub1 - 1])
        else:
            subsample_factor_sub1 = subsample_factor
    if args.train_set_sub2:
        if args.conv_poolings and 'conv' in args.enc_type:
            subsample_factor_sub2 = subsample_factor * np.prod(
                subsample[:args.enc_n_layers_sub2 - 1])
        else:
            subsample_factor_sub2 = subsample_factor

    skip_thought = 'skip' in args.enc_type

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        tsv_path_sub1=args.train_set_sub1,
                        tsv_path_sub2=args.train_set_sub2,
                        dict_path=args.dict,
                        dict_path_sub1=args.dict_sub1,
                        dict_path_sub2=args.dict_sub2,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        unit_sub1=args.unit_sub1,
                        unit_sub2=args.unit_sub2,
                        wp_model=args.wp_model,
                        wp_model_sub1=args.wp_model_sub1,
                        wp_model_sub2=args.wp_model_sub2,
                        batch_size=args.batch_size * args.n_gpus,
                        n_epochs=args.n_epochs,
                        min_n_frames=args.min_n_frames,
                        max_n_frames=args.max_n_frames,
                        sort_by='input',
                        short2long=True,
                        sort_stop_epoch=args.sort_stop_epoch,
                        dynamic_batching=args.dynamic_batching,
                        ctc=args.ctc_weight > 0,
                        ctc_sub1=args.ctc_weight_sub1 > 0,
                        ctc_sub2=args.ctc_weight_sub2 > 0,
                        subsample_factor=subsample_factor,
                        subsample_factor_sub1=subsample_factor_sub1,
                        subsample_factor_sub2=subsample_factor_sub2,
                        discourse_aware=args.discourse_aware,
                        skip_thought=skip_thought)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      tsv_path_sub1=args.dev_set_sub1,
                      tsv_path_sub2=args.dev_set_sub2,
                      dict_path=args.dict,
                      dict_path_sub1=args.dict_sub1,
                      dict_path_sub2=args.dict_sub2,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      unit_sub1=args.unit_sub1,
                      unit_sub2=args.unit_sub2,
                      wp_model=args.wp_model,
                      wp_model_sub1=args.wp_model_sub1,
                      wp_model_sub2=args.wp_model_sub2,
                      batch_size=args.batch_size * args.n_gpus,
                      min_n_frames=args.min_n_frames,
                      max_n_frames=args.max_n_frames,
                      ctc=args.ctc_weight > 0,
                      ctc_sub1=args.ctc_weight_sub1 > 0,
                      ctc_sub2=args.ctc_weight_sub2 > 0,
                      subsample_factor=subsample_factor,
                      subsample_factor_sub1=subsample_factor_sub1,
                      subsample_factor_sub2=subsample_factor_sub2,
                      discourse_aware=args.discourse_aware,
                      skip_thought=skip_thought)
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [
            Dataset(corpus=args.corpus,
                    tsv_path=s,
                    dict_path=args.dict,
                    nlsyms=args.nlsyms,
                    unit=args.unit,
                    wp_model=args.wp_model,
                    batch_size=1,
                    discourse_aware=args.discourse_aware,
                    skip_thought=skip_thought,
                    is_test=True)
        ]

    args.vocab = train_set.vocab
    args.vocab_sub1 = train_set.vocab_sub1
    args.vocab_sub2 = train_set.vocab_sub2
    args.input_dim = train_set.input_dim

    # Load a LM conf file for LM fusion & LM initialization
    if not args.resume and (args.lm_fusion or args.lm_init):
        if args.lm_fusion:
            lm_conf = load_config(
                os.path.join(os.path.dirname(args.lm_fusion), 'conf.yml'))
        elif args.lm_init:
            lm_conf = load_config(
                os.path.join(os.path.dirname(args.lm_init), 'conf.yml'))
        args.lm_conf = argparse.Namespace()
        for k, v in lm_conf.items():
            setattr(args.lm_conf, k, v)
        assert args.unit == args.lm_conf.unit
        assert args.vocab == args.lm_conf.vocab

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = set_asr_model_name(args, subsample_factor)
        save_path = mkdir_join(
            args.model_save_dir,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    logger = set_logger(os.path.join(save_path, 'train.log'),
                        key='training',
                        stdout=args.stdout)

    # Model setting
    if skip_thought:
        model = SkipThought(args, save_path)
    else:
        model = Speech2Text(args, save_path)

    if args.resume:
        # Set optimizer
        epoch = int(args.resume.split('-')[-1])
        optimizer = set_optimizer(
            model, 'sgd' if epoch > conf['convert_to_sgd_epoch'] else
            conf['optimizer'], conf['lr'], conf['weight_decay'])

        # Wrap optimizer by learning rate scheduler
        noam = ('transformer' in conf['enc_type']
                or conf['dec_type'] == 'transformer')
        optimizer = LRScheduler(
            optimizer,
            conf['lr'],
            decay_type=conf['lr_decay_type'],
            decay_start_epoch=conf['lr_decay_start_epoch'],
            decay_rate=conf['lr_decay_rate'],
            decay_patient_n_epochs=conf['lr_decay_patient_n_epochs'],
            early_stop_patient_n_epochs=conf['early_stop_patient_n_epochs'],
            warmup_start_lr=conf['warmup_start_lr'],
            warmup_n_steps=conf['warmup_n_steps'],
            model_size=conf['d_model'],
            factor=conf['lr_factor'],
            noam=noam)

        # Restore the last saved model
        model, optimizer = load_checkpoint(model,
                                           args.resume,
                                           optimizer,
                                           resume=True)

        # Resume between epoch convert_to_sgd_epoch - 1 and convert_to_sgd_epoch
        if epoch == conf['convert_to_sgd_epoch']:
            n_epochs = optimizer.n_epochs
            n_steps = optimizer.n_steps
            optimizer = set_optimizer(model, 'sgd', args.lr,
                                      conf['weight_decay'])
            optimizer = LRScheduler(optimizer,
                                    args.lr,
                                    decay_type='always',
                                    decay_start_epoch=0,
                                    decay_rate=0.5)
            optimizer._epoch = n_epochs
            optimizer._step = n_steps
            logger.info('========== Convert to SGD ==========')
    else:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(save_path, 'conf.yml'))
        if args.lm_fusion:
            save_config(args.lm_conf, os.path.join(save_path, 'conf_lm.yml'))

        # Save the nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(save_path, 'nlsyms.txt'))
        for sub in ['', '_sub1', '_sub2']:
            if getattr(args, 'dict' + sub):
                shutil.copy(getattr(args, 'dict' + sub),
                            os.path.join(save_path, 'dict' + sub + '.txt'))
            if getattr(args, 'unit' + sub) == 'wp':
                shutil.copy(getattr(args, 'wp_model' + sub),
                            os.path.join(save_path, 'wp' + sub + '.model'))

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info(model)

        # Initialize with pre-trained model's parameters
        if args.pretrained_model and os.path.isfile(args.pretrained_model):
            # Load the ASR model
            conf_pt = load_config(
                os.path.join(os.path.dirname(args.pretrained_model),
                             'conf.yml'))
            for k, v in conf_pt.items():
                setattr(args_pt, k, v)
            model_pt = Speech2Text(args_pt)
            model_pt = load_checkpoint(model_pt, args.pretrained_model)[0]

            # Overwrite parameters
            only_enc = (args.enc_n_layers != args_pt.enc_n_layers) or (
                args.unit != args_pt.unit) or args_pt.ctc_weight == 1
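            # NOTE: copy only encoder parameters when the encoder depth or the
            # output unit differs from the pre-trained model, or when the
            # pre-trained model is CTC-only (no attention decoder to transfer).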
            param_dict = dict(model_pt.named_parameters())
            for n, p in model.named_parameters():
                if n in param_dict.keys() and p.size() == param_dict[n].size():
                    if only_enc and 'enc' not in n:
                        continue
                    if args.lm_fusion_type == 'cache' and 'output' in n:
                        continue
                    p.data = param_dict[n].data
                    logger.info('Overwrite %s' % n)

        # Set optimizer
        optimizer = set_optimizer(model, args.optimizer, args.lr,
                                  args.weight_decay)

        # Wrap optimizer by learning rate scheduler
        noam = 'transformer' in args.enc_type or args.dec_type == 'transformer'
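        # NOTE: when either the encoder or the decoder is a Transformer, the
        # scheduler switches to the Noam schedule (linear warmup over
        # warmup_n_steps, then inverse square-root decay scaled by d_model)
        # instead of the epoch-based decay.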
        optimizer = LRScheduler(
            optimizer,
            args.lr,
            decay_type=args.lr_decay_type,
            decay_start_epoch=args.lr_decay_start_epoch,
            decay_rate=args.lr_decay_rate,
            decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
            early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
            warmup_start_lr=args.warmup_start_lr,
            warmup_n_steps=args.warmup_n_steps,
            model_size=args.d_model,
            factor=args.lr_factor,
            noam=noam)

    # Load the teacher ASR model
    teacher = None
    if args.teacher and os.path.isfile(args.teacher):
        conf_teacher = load_config(
            os.path.join(os.path.dirname(args.teacher), 'conf.yml'))
        for k, v in conf_teacher.items():
            setattr(args_teacher, k, v)
        # Settings for knowledge distillation: disable scheduled sampling in the
        # teacher and label smoothing in the student so the soft targets stay intact
        args_teacher.ss_prob = 0
        args.lsm_prob = 0
        teacher = Speech2Text(args_teacher)
        teacher = load_checkpoint(teacher, args.teacher)[0]

    # Load the teacher LM
    teacher_lm = None
    if args.teacher_lm and os.path.isfile(args.teacher_lm):
        conf_lm = load_config(
            os.path.join(os.path.dirname(args.teacher_lm), 'conf.yml'))
        args_lm = argparse.Namespace()
        for k, v in conf_lm.items():
            setattr(args_lm, k, v)
        teacher_lm = build_lm(args_lm)
        teacher_lm = load_checkpoint(teacher_lm, args.teacher_lm)[0]

    # GPU setting
    if args.n_gpus >= 1:
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus, 1)),
                                   deterministic=False,
                                   benchmark=True)
        model.cuda()
        if teacher is not None:
            teacher.cuda()
        if teacher_lm is not None:
            teacher_lm.cuda()

    # Set process name
    logger.info('PID: %s' % os.getpid())
    logger.info('HOSTNAME: %s' % os.uname()[1])
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(save_path, tensorboard=True)

    if args.mtl_per_batch:
        # NOTE: from easier to harder tasks
        tasks = []
        if 1 - args.bwd_weight - args.ctc_weight - args.sub1_weight - args.sub2_weight > 0:
            tasks += ['ys']
        if args.bwd_weight > 0:
            tasks = ['ys.bwd'] + tasks
        if args.ctc_weight > 0:
            tasks = ['ys.ctc'] + tasks
        if args.lmobj_weight > 0:
            tasks = ['ys.lmobj'] + tasks
        for sub in ['sub1', 'sub2']:
            if getattr(args, 'train_set_' + sub):
                if getattr(args, sub + '_weight') - getattr(
                        args, 'ctc_weight_' + sub) > 0:
                    tasks = ['ys_' + sub] + tasks
                if getattr(args, 'ctc_weight_' + sub) > 0:
                    tasks = ['ys_' + sub + '.ctc'] + tasks
    else:
        tasks = ['all']

    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    pbar_epoch = tqdm(total=len(train_set))
    accum_n_tokens = 0
    while True:
        # Compute loss in the training set
        batch_train, is_new_epoch = train_set.next()
        accum_n_tokens += sum([len(y) for y in batch_train['ys']])
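        # NOTE: accum_n_tokens counts target tokens across mini-batches; the
        # optimizer step below is deferred until args.accum_grad_n_tokens is
        # reached (gradient accumulation), unless accumulation is disabled (0).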

        # Change mini-batch depending on task
        for task in tasks:
            if skip_thought:
                loss, reporter = model(batch_train['ys'],
                                       ys_prev=batch_train['ys_prev'],
                                       ys_next=batch_train['ys_next'],
                                       reporter=reporter)
            else:
                loss, reporter = model(batch_train,
                                       reporter=reporter,
                                       task=task,
                                       teacher=teacher,
                                       teacher_lm=teacher_lm)
            # loss /= args.accum_grad_n_steps
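            # NOTE: with multiple GPUs the wrapped model returns one loss per
            # device, so backward() is given a vector of ones to backpropagate
            # the sum of the per-device losses.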
            if len(model.device_ids) > 1:
                loss.backward(torch.ones(len(model.device_ids)))
            else:
                loss.backward()
            loss.detach()  # Truncate the graph
            if args.accum_grad_n_tokens == 0 or accum_n_tokens >= args.accum_grad_n_tokens:
                if args.clip_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(model.module.parameters(),
                                                   args.clip_grad_norm)
                optimizer.step()
                optimizer.zero_grad()
                accum_n_tokens = 0
            loss_train = loss.item()
            del loss
        reporter.step()

        if optimizer.n_steps % args.print_step == 0:
            # Compute loss in the dev set
            batch_dev = dev_set.next()[0]
            # Change mini-batch depending on task
            for task in tasks:
                if skip_thought:
                    loss, reporter = model(batch_dev['ys'],
                                           ys_prev=batch_dev['ys_prev'],
                                           ys_next=batch_dev['ys_next'],
                                           reporter=reporter,
                                           is_eval=True)
                else:
                    loss, reporter = model(batch_dev,
                                           reporter=reporter,
                                           task=task,
                                           is_eval=True)
                loss_dev = loss.item()
                del loss
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            if args.input_type == 'speech':
                xlen = max(len(x) for x in batch_train['xs'])
                ylen = max(len(y) for y in batch_train['ys'])
            elif args.input_type == 'text':
                xlen = max(len(x) for x in batch_train['ys'])
                ylen = max(len(y) for y in batch_train['ys_sub1'])
            logger.info(
                "step:%d(ep:%.2f) loss:%.3f(%.3f)/lr:%.5f/bs:%d/xlen:%d/ylen:%d (%.2f min)"
                %
                (optimizer.n_steps, optimizer.n_epochs +
                 train_set.epoch_detail, loss_train, loss_dev, optimizer.lr,
                 len(batch_train['utt_ids']), xlen, ylen, duration_step / 60))
            start_time_step = time.time()
        pbar_epoch.update(len(batch_train['utt_ids']))

        # Save figures of loss and accuracy
        if optimizer.n_steps % (args.print_step * 10) == 0:
            reporter.snapshot()
            model.module.plot_attention()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (optimizer.n_epochs + 1, duration_epoch / 60))

            if optimizer.n_epochs + 1 < args.eval_start_epoch:
                optimizer.epoch()
                reporter.epoch()

                # Save the model
                save_checkpoint(model,
                                save_path,
                                optimizer,
                                optimizer.n_epochs,
                                remove_old_checkpoints=not noam)
            else:
                start_time_eval = time.time()
                # dev
                metric_dev = eval_epoch([model.module], dev_set, recog_params,
                                        args, optimizer.n_epochs + 1, logger)
                optimizer.epoch(metric_dev)
                reporter.epoch(metric_dev)

                if optimizer.is_best:
                    # Save the model
                    save_checkpoint(model,
                                    save_path,
                                    optimizer,
                                    optimizer.n_epochs,
                                    remove_old_checkpoints=not noam)

                    # test
                    for eval_set in eval_sets:
                        eval_epoch([model.module], eval_set, recog_params,
                                   args, optimizer.n_epochs, logger)

                    # start scheduled sampling
                    if args.ss_prob > 0:
                        model.module.scheduled_sampling_trigger()

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if optimizer.is_early_stop:
                    break

                # Convert to fine-tuning stage
                if optimizer.n_epochs == args.convert_to_sgd_epoch:
                    n_epochs = optimizer.n_epochs
                    n_steps = optimizer.n_steps
                    optimizer = set_optimizer(model, 'sgd', args.lr,
                                              args.weight_decay)
                    optimizer = LRScheduler(optimizer,
                                            args.lr,
                                            decay_type='always',
                                            decay_start_epoch=0,
                                            decay_rate=0.5)
                    optimizer._epoch = n_epochs
                    optimizer._step = n_steps
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))

            if optimizer.n_epochs == args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    if reporter.tensorboard:
        reporter.tf_writer.close()
    pbar_epoch.close()

    return save_path
Example #2
def main():

    args = parse()
    hvd.init()
    torch.cuda.set_device(hvd.local_rank())
    hvd_rank = hvd.rank()
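    # NOTE: Horovod runs one process per GPU; set_device pins this process to
    # its local GPU, and hvd.rank() identifies the worker across all nodes.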
    # Load a conf file
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)
    recog_params = vars(args)

    # Compute subsampling factor
    subsample_factor = 1

    subsample = [int(s) for s in args.subsample.split('_')]
    if args.conv_poolings and 'conv' in args.enc_type:
        for p in args.conv_poolings.split('_'):
            subsample_factor *= int(p.split(',')[0].replace('(', ''))
    else:
        subsample_factor = np.prod(subsample)

    skip_thought = 'skip' in args.enc_type
    batch_per_allreduce = args.batch_size
    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        tsv_path_sub1=args.train_set_sub1,
                        tsv_path_sub2=args.train_set_sub2,
                        dict_path=args.dict,
                        dict_path_sub1=args.dict_sub1,
                        dict_path_sub2=args.dict_sub2,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        unit_sub1=args.unit_sub1,
                        unit_sub2=args.unit_sub2,
                        wp_model=args.wp_model,
                        wp_model_sub1=args.wp_model_sub1,
                        wp_model_sub2=args.wp_model_sub2,
                        batch_size=args.batch_size,
                        n_epochs=args.n_epochs,
                        min_n_frames=args.min_n_frames,
                        max_n_frames=args.max_n_frames,
                        sort_by='no_sort',
                        short2long=True,
                        sort_stop_epoch=args.sort_stop_epoch,
                        dynamic_batching=args.dynamic_batching,
                        ctc=args.ctc_weight > 0,
                        ctc_sub1=args.ctc_weight_sub1 > 0,
                        ctc_sub2=args.ctc_weight_sub2 > 0,
                        subsample_factor=subsample_factor,
                        discourse_aware=args.discourse_aware,
                        skip_thought=skip_thought)

    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      tsv_path_sub1=args.dev_set_sub1,
                      tsv_path_sub2=args.dev_set_sub2,
                      dict_path=args.dict,
                      dict_path_sub1=args.dict_sub1,
                      dict_path_sub2=args.dict_sub2,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      unit_sub1=args.unit_sub1,
                      unit_sub2=args.unit_sub2,
                      wp_model=args.wp_model,
                      wp_model_sub1=args.wp_model_sub1,
                      wp_model_sub2=args.wp_model_sub2,
                      batch_size=args.batch_size,
                      min_n_frames=args.min_n_frames,
                      max_n_frames=args.max_n_frames,
                      ctc=args.ctc_weight > 0,
                      ctc_sub1=args.ctc_weight_sub1 > 0,
                      ctc_sub2=args.ctc_weight_sub2 > 0,
                      subsample_factor=subsample_factor,
                      discourse_aware=args.discourse_aware,
                      skip_thought=skip_thought)
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [
            Dataset(corpus=args.corpus,
                    tsv_path=s,
                    dict_path=args.dict,
                    nlsyms=args.nlsyms,
                    unit=args.unit,
                    wp_model=args.wp_model,
                    batch_size=1,
                    discourse_aware=args.discourse_aware,
                    skip_thought=skip_thought,
                    is_test=True)
        ]

    args.vocab = train_set.vocab
    args.vocab_sub1 = train_set.vocab_sub1
    args.vocab_sub2 = train_set.vocab_sub2
    args.input_dim = train_set.input_dim
    # Horovod: use DistributedSampler to partition data among workers. Manually specify
    # `num_replicas=hvd.size()` and `rank=hvd.rank()`.
    train_loader = SeqDataloader(train_set,
                                 batch_size=args.batch_size,
                                 num_workers=1,
                                 distributed=True,
                                 num_stacks=args.n_stacks,
                                 num_splices=args.n_splices,
                                 num_skips=args.n_skips,
                                 pin_memory=False,
                                 shuffle=False)
    val_loader = SeqDataloader(dev_set,
                               batch_size=args.batch_size,
                               num_workers=1,
                               distributed=True,
                               num_stacks=args.n_stacks,
                               num_splices=args.n_splices,
                               num_skips=args.n_skips,
                               pin_memory=False,
                               shuffle=False)

    # Load a LM conf file for LM fusion & LM initialization
    if not args.resume and (args.lm_fusion or args.lm_init):
        if args.lm_fusion:
            lm_conf = load_config(
                os.path.join(os.path.dirname(args.lm_fusion), 'conf.yml'))
        elif args.lm_init:
            lm_conf = load_config(
                os.path.join(os.path.dirname(args.lm_init), 'conf.yml'))
        args.lm_conf = argparse.Namespace()
        for k, v in lm_conf.items():
            setattr(args.lm_conf, k, v)
        assert args.unit == args.lm_conf.unit
        assert args.vocab == args.lm_conf.vocab

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = set_asr_model_name(args, subsample_factor)
        save_path = mkdir_join(
            args.model_save_dir,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        if hvd_rank == 0:
            save_path = set_save_path(save_path)  # avoid overwriting
    # Set logger
    if hvd_rank == 0:
        logger = set_logger(os.path.join(save_path, 'train.log'),
                            key='training',
                            stdout=args.stdout)
        # Set process name
        logger.info('PID: %s' % os.getpid())
        logger.info('HOSTNAME: %s' % os.uname()[1])
        logger.info('NUMBER_DEVICES: %s' % hvd.size())

    setproctitle(args.job_name if args.job_name else dir_name)
    # Model setting
    model = Speech2Text(args, save_path)
    # GPU setting
    if args.n_gpus >= 1:
        torch.backends.cudnn.benchmark = True
        model.cuda()

    if args.resume:
        # Set optimizer
        epochs = int(args.resume.split('-')[-1])
        # optimizer = set_optimizer(model, 'sgd' if epochs >= conf['convert_to_sgd_epoch'] else conf['optimizer'],

        model, _ = load_checkpoint(model, args.resume, resume=True)
        optimizer = set_optimizer(model, 'sgd', conf['lr'],
                                  conf['weight_decay'])
        # Wrap with Horovod's distributed optimizer and broadcast the restored state from rank 0
        optimizer = hvd.DistributedOptimizer(
            optimizer, named_parameters=model.named_parameters())

        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)
        # Wrap optimizer by learning rate scheduler
        noam = 'transformer' in args.enc_type or args.dec_type == 'transformer'
        optimizer = LRScheduler(
            optimizer,
            args.lr,
            decay_type=args.lr_decay_type,
            decay_start_epoch=args.lr_decay_start_epoch,
            decay_rate=args.lr_decay_rate,
            decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
            early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
            warmup_start_lr=args.warmup_start_lr,
            warmup_n_steps=args.warmup_n_steps,
            model_size=args.d_model,
            factor=args.lr_factor,
            noam=noam)

    else:
        # Save the conf file as a yaml file
        if hvd_rank == 0:
            save_config(vars(args), os.path.join(save_path, 'conf.yml'))
        if args.lm_fusion:
            save_config(args.lm_conf, os.path.join(save_path, 'conf_lm.yml'))

        if hvd_rank == 0:
            for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
                logger.info('%s: %s' % (k, str(v)))

            # Count total parameters
            for n in sorted(list(model.num_params_dict.keys())):
                n_params = model.num_params_dict[n]
                logger.info("%s %d" % (n, n_params))
            logger.info("Total %.2f M parameters" %
                        (model.total_parameters / 1000000))
            logger.info(model)

        # Set optimizer
        optimizer = set_optimizer(model, args.optimizer, args.lr,
                                  args.weight_decay)

        optimizer = hvd.DistributedOptimizer(
            optimizer,
            named_parameters=model.named_parameters(),
            compression=hvd.Compression.none,
            backward_passes_per_step=batch_per_allreduce)
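        # NOTE: DistributedOptimizer averages gradients across workers via
        # allreduce; backward_passes_per_step allows several local backward
        # passes to accumulate before the gradients are synchronized.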

        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)
        # Wrap optimizer by learning rate scheduler
        noam = 'transformer' in args.enc_type or args.dec_type == 'transformer'
        optimizer = LRScheduler(
            optimizer,
            args.lr,
            decay_type=args.lr_decay_type,
            decay_start_epoch=args.lr_decay_start_epoch,
            decay_rate=args.lr_decay_rate,
            decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
            early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
            warmup_start_lr=args.warmup_start_lr,
            warmup_n_steps=args.warmup_n_steps,
            model_size=args.d_model,
            factor=args.lr_factor,
            noam=noam)
    # Set reporter
    reporter = Reporter(save_path)
    if args.mtl_per_batch:
        # NOTE: from easier to harder tasks
        tasks = []
        if 1 - args.bwd_weight - args.ctc_weight - args.sub1_weight - args.sub2_weight > 0:
            tasks += ['ys']
        if args.bwd_weight > 0:
            tasks = ['ys.bwd'] + tasks
        if args.ctc_weight > 0:
            tasks = ['ys.ctc'] + tasks
    else:
        tasks = ['all']

    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    accum_n_tokens = 0

    verbose = 1 if hvd_rank == 0 else 0
    data_size = len(train_set)
    while True:
        model.train()
        with tqdm(total=data_size // hvd.size(),
                  desc='Train Epoch     #{}'.format(optimizer.n_epochs + 1),
                  disable=not verbose) as pbar_epoch:
            # Compute loss in the training set
            for _, batch_train in enumerate(train_loader):
                accum_n_tokens += sum([len(y) for y in batch_train['ys']])
                # Change mini-batch depending on task
                for task in tasks:
                    if skip_thought:
                        loss, reporter = model(batch_train['ys'],
                                               ys_prev=batch_train['ys_prev'],
                                               ys_next=batch_train['ys_next'],
                                               reporter=reporter)
                    else:
                        loss, reporter = model(batch_train, reporter, task)
                    loss.backward()
                    loss.detach()  # Truncate the graph
                    if args.accum_grad_n_tokens == 0 or accum_n_tokens >= args.accum_grad_n_tokens:
                        if args.clip_grad_norm > 0:
                            total_norm = torch.nn.utils.clip_grad_norm_(
                                model.parameters(), args.clip_grad_norm)
                        optimizer.step()
                        optimizer.zero_grad()

                        accum_n_tokens = 0
                    loss_train = loss.item()
                    del loss

                if optimizer.n_steps % args.print_step == 0:
                    # Compute loss in the dev set
                    model.eval()
                    batch_dev = dev_set.next()[0]
                    # Change mini-batch depending on task
                    for task in tasks:
                        if skip_thought:
                            loss, reporter = model(
                                batch_dev['ys'],
                                ys_prev=batch_dev['ys_prev'],
                                ys_next=batch_dev['ys_next'],
                                reporter=reporter,
                                is_eval=True)
                        else:
                            loss, reporter = model(batch_dev,
                                                   reporter,
                                                   task,
                                                   is_eval=True)
                        loss_dev = loss.item()
                        del loss

                    duration_step = time.time() - start_time_step
                    if args.input_type == 'speech':
                        xlen = max(len(x) for x in batch_train['xs'])
                        ylen = max(len(y) for y in batch_train['ys'])
                    elif args.input_type == 'text':
                        xlen = max(len(x) for x in batch_train['ys'])
                        ylen = max(len(y) for y in batch_train['ys_sub1'])

                    if hvd_rank == 0:
                        logger.info(
                            "step:%d(ep:%.2f) loss:%.3f(%.3f)/lr:%.5f/bs:%d/xlen:%d/ylen:%d (%.2f min)"
                            % (optimizer.n_steps,
                               optimizer.n_steps * args.batch_size /
                               (data_size / hvd.size()), loss_train, loss_dev,
                               optimizer.lr, len(batch_train['utt_ids']), xlen,
                               ylen, duration_step / 60))
                    start_time_step = time.time()
                pbar_epoch.update(len(batch_train['utt_ids']))

                # Save figures of loss and accuracy
                if optimizer.n_steps % (args.print_step *
                                        10) == 0 and hvd.rank() == 0:
                    model.plot_attention()
                start_time_step = time.time()
            # reset dev set
            dev_set.reset()
            # Save checkpoint and evaluate model per epoch
            duration_epoch = time.time() - start_time_epoch
            if hvd_rank == 0:
                logger.info('========== EPOCH:%d (%.2f min) ==========' %
                            (optimizer.n_epochs + 1, duration_epoch / 60))

            if optimizer.n_epochs + 1 < args.eval_start_epoch:
                optimizer.epoch()
                if hvd_rank == 0:
                    save_checkpoint(model,
                                    save_path,
                                    optimizer,
                                    optimizer.n_epochs,
                                    remove_old_checkpoints=not noam)
            else:
                start_time_eval = time.time()
                # dev
                metric_dev = eval_epoch([model], val_loader, recog_params,
                                        args, optimizer.n_epochs + 1)
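                # Average the dev metric over all workers so every rank drives
                # the LR scheduler and checkpointing with the same value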
                metric_dev = hvd.allreduce(
                    np2tensor(np.array([metric_dev], dtype=float),
                              hvd.local_rank()))

                loss_dev = metric_dev.item()
                if hvd_rank == 0:
                    logger.info('Loss : %.2f %%' % (loss_dev))
                optimizer.epoch(loss_dev)
                if hvd.rank() == 0:
                    save_checkpoint(model,
                                    save_path,
                                    optimizer,
                                    optimizer.n_epochs,
                                    remove_old_checkpoints=False)
                if not optimizer.is_best:
                    model, _ = load_checkpoint(
                        model, save_path + '/model.epoch-' +
                        str(optimizer.best_epochs))

                duration_eval = time.time() - start_time_eval
                if hvd_rank == 0:
                    logger.info('Evaluation time: %.2f min' %
                                (duration_eval / 60))

                # Early stopping
                if optimizer.is_early_stop:
                    break
            # Convert to fine-tuning stage
            if optimizer.n_epochs == args.convert_to_sgd_epoch:
                n_epochs = optimizer.n_epochs
                n_steps = optimizer.n_steps
                optimizer = set_optimizer(model, 'sgd', args.lr,
                                          args.weight_decay)
                optimizer = hvd.DistributedOptimizer(
                    optimizer,
                    named_parameters=model.named_parameters(),
                    compression=hvd.Compression.none,
                    backward_passes_per_step=batch_per_allreduce)

                hvd.broadcast_parameters(model.state_dict(), root_rank=0)
                hvd.broadcast_optimizer_state(optimizer, root_rank=0)
                optimizer = LRScheduler(
                    optimizer,
                    args.lr,
                    decay_type=args.lr_decay_type,
                    decay_start_epoch=args.lr_decay_start_epoch,
                    decay_rate=args.lr_decay_rate,
                    decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
                    early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
                    warmup_start_lr=args.warmup_start_lr,
                    warmup_n_steps=args.warmup_n_steps,
                    model_size=args.d_model,
                    factor=args.lr_factor,
                    noam=noam)

                optimizer._epoch = n_epochs
                optimizer._step = n_steps
                if hvd_rank == 0:
                    logger.info('========== Convert to SGD ==========')

            if optimizer.n_epochs == args.n_epochs:
                break
            start_time_step = time.time()
            start_time_epoch = time.time()

    duration_train = time.time() - start_time_train
    if hvd_rank == 0:
        logger.info('Total time: %.2f hour' % (duration_train / 3600))

    return save_path
Example #3
def main():

    args = parse()

    # Load a conf file
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = set_lm_name(args)
        save_path = mkdir_join(
            args.model_save_dir,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    logger = set_logger(os.path.join(save_path, 'train.log'),
                        key='training',
                        stdout=args.stdout)

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        dict_path=args.dict,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        wp_model=args.wp_model,
                        batch_size=args.batch_size * args.n_gpus,
                        n_epochs=args.n_epochs,
                        min_n_tokens=args.min_n_tokens,
                        bptt=args.bptt,
                        backward=args.backward,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      dict_path=args.dict,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      wp_model=args.wp_model,
                      batch_size=args.batch_size * args.n_gpus,
                      bptt=args.bptt,
                      backward=args.backward,
                      serialize=args.serialize)
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [
            Dataset(corpus=args.corpus,
                    tsv_path=s,
                    dict_path=args.dict,
                    nlsyms=args.nlsyms,
                    unit=args.unit,
                    wp_model=args.wp_model,
                    batch_size=1,
                    bptt=args.bptt,
                    backward=args.backward,
                    serialize=args.serialize)
        ]

    args.vocab = train_set.vocab

    # Model setting
    model = build_lm(args, save_path)

    if args.resume:
        # Set optimizer
        epoch = int(args.resume.split('-')[-1])
        optimizer = set_optimizer(
            model, 'sgd' if epoch > conf['convert_to_sgd_epoch'] else
            conf['optimizer'], conf['lr'], conf['weight_decay'])

        # Wrap optimizer by learning rate scheduler
        optimizer = LRScheduler(
            optimizer,
            conf['lr'],
            decay_type=conf['lr_decay_type'],
            decay_start_epoch=conf['lr_decay_start_epoch'],
            decay_rate=conf['lr_decay_rate'],
            decay_patient_n_epochs=conf['lr_decay_patient_n_epochs'],
            early_stop_patient_n_epochs=conf['early_stop_patient_n_epochs'],
            warmup_start_lr=conf['warmup_start_lr'],
            warmup_n_steps=conf['warmup_n_steps'],
            model_size=conf['d_model'],
            factor=conf['lr_factor'],
            noam=conf['lm_type'] == 'transformer')

        # Restore the last saved model
        model, optimizer = load_checkpoint(model,
                                           args.resume,
                                           optimizer,
                                           resume=True)

        # Resume between epoch convert_to_sgd_epoch - 1 and convert_to_sgd_epoch
        if epoch == conf['convert_to_sgd_epoch']:
            n_epochs = optimizer.n_epochs
            n_steps = optimizer.n_steps
            optimizer = set_optimizer(model, 'sgd', args.lr,
                                      conf['weight_decay'])
            optimizer = LRScheduler(optimizer,
                                    args.lr,
                                    decay_type='always',
                                    decay_start_epoch=0,
                                    decay_rate=0.5)
            optimizer._epoch = n_epochs
            optimizer._step = n_steps
            logger.info('========== Convert to SGD ==========')
    else:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(save_path, 'conf.yml'))

        # Save the nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(save_path, 'nlsyms.txt'))
        shutil.copy(args.dict, os.path.join(save_path, 'dict.txt'))
        if args.unit == 'wp':
            shutil.copy(args.wp_model, os.path.join(save_path, 'wp.model'))

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info(model)

        # Set optimizer
        optimizer = set_optimizer(model, args.optimizer, args.lr,
                                  args.weight_decay)

        # Wrap optimizer by learning rate scheduler
        optimizer = LRScheduler(
            optimizer,
            args.lr,
            decay_type=args.lr_decay_type,
            decay_start_epoch=args.lr_decay_start_epoch,
            decay_rate=args.lr_decay_rate,
            decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
            early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
            warmup_start_lr=args.warmup_start_lr,
            warmup_n_steps=args.warmup_n_steps,
            model_size=args.d_model,
            factor=args.lr_factor,
            noam=args.lm_type == 'transformer')

    # GPU setting
    if args.n_gpus >= 1:
        torch.backends.cudnn.benchmark = True
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus)))
        model.cuda()

    # Set process name
    logger.info('PID: %s' % os.getpid())
    logger.info('HOSTNAME: %s' % os.uname()[1])
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(save_path)

    hidden = None
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    pbar_epoch = tqdm(total=len(train_set))
    accum_n_tokens = 0
    while True:
        # Compute loss in the training set
        ys_train, is_new_epoch = train_set.next()
        accum_n_tokens += sum([len(y) for y in ys_train])
        optimizer.zero_grad()
        loss, hidden, reporter = model(ys_train, hidden, reporter)
        loss.backward()
        loss.detach()  # Truncate the graph
        if args.accum_grad_n_tokens == 0 or accum_n_tokens >= args.accum_grad_n_tokens:
            if args.clip_grad_norm > 0:
                total_norm = torch.nn.utils.clip_grad_norm_(
                    model.module.parameters(), args.clip_grad_norm)
                reporter.add_tensorboard_scalar('total_norm', total_norm)
            optimizer.step()
            optimizer.zero_grad()
            accum_n_tokens = 0
        loss_train = loss.item()
        del loss
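        # NOTE: repackaging detaches the recurrent hidden state from the graph
        # so backpropagation is truncated at the BPTT boundary while the state
        # values are carried over to the next chunk.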
        hidden = model.module.repackage_state(hidden)
        reporter.add_tensorboard_scalar('learning_rate', optimizer.lr)
        # NOTE: loss/acc/ppl are already added in the model
        reporter.step()

        if optimizer.n_steps % args.print_step == 0:
            # Compute loss in the dev set
            ys_dev = dev_set.next()[0]
            loss, _, reporter = model(ys_dev, None, reporter, is_eval=True)
            loss_dev = loss.item()
            del loss
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
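            # NOTE: the perplexities logged below are exp(cross-entropy loss)
            # for the current training and dev batches.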
            logger.info(
                "step:%d(ep:%.2f) loss:%.3f(%.3f)/ppl:%.3f(%.3f)/lr:%.5f/bs:%d (%.2f min)"
                % (optimizer.n_steps,
                   optimizer.n_epochs + train_set.epoch_detail, loss_train,
                   loss_dev, np.exp(loss_train), np.exp(loss_dev),
                   optimizer.lr, ys_train.shape[0], duration_step / 60))
            start_time_step = time.time()
        pbar_epoch.update(ys_train.shape[0] * (ys_train.shape[1] - 1))

        # Save figures of loss and accuracy
        if optimizer.n_steps % (args.print_step * 10) == 0:
            reporter.snapshot()
            if args.lm_type == 'transformer':
                model.module.plot_attention()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (optimizer.n_epochs + 1, duration_epoch / 60))

            if optimizer.n_epochs + 1 < args.eval_start_epoch:
                optimizer.epoch()  # lr decay
                reporter.epoch()  # plot

                # Save the model
                save_checkpoint(
                    model,
                    save_path,
                    optimizer,
                    optimizer.n_epochs,
                    remove_old_checkpoints=args.lm_type != 'transformer')
            else:
                start_time_eval = time.time()
                # dev
                ppl_dev, _ = eval_ppl([model.module],
                                      dev_set,
                                      batch_size=1,
                                      bptt=args.bptt)
                logger.info('PPL (%s, epoch:%d): %.2f' %
                            (dev_set.set, optimizer.n_epochs, ppl_dev))
                optimizer.epoch(ppl_dev)  # lr decay
                reporter.epoch(ppl_dev, name='perplexity')  # plot

                if optimizer.is_best:
                    # Save the model
                    save_checkpoint(
                        model,
                        save_path,
                        optimizer,
                        optimizer.n_epochs,
                        remove_old_checkpoints=args.lm_type != 'transformer')

                    # test
                    ppl_test_avg = 0.
                    for eval_set in eval_sets:
                        ppl_test, _ = eval_ppl([model.module],
                                               eval_set,
                                               batch_size=1,
                                               bptt=args.bptt)
                        logger.info(
                            'PPL (%s, epoch:%d): %.2f' %
                            (eval_set.set, optimizer.n_epochs, ppl_test))
                        ppl_test_avg += ppl_test
                    if len(eval_sets) > 0:
                        logger.info('PPL (avg., epoch:%d): %.2f' %
                                    (optimizer.n_epochs,
                                     ppl_test_avg / len(eval_sets)))

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if optimizer.is_early_stop:
                    break

                # Convert to fine-tuning stage
                if optimizer.n_epochs == args.convert_to_sgd_epoch:
                    n_epochs = optimizer.n_epochs
                    n_steps = optimizer.n_steps
                    optimizer = set_optimizer(model, 'sgd', args.lr,
                                              args.weight_decay)
                    optimizer = LRScheduler(optimizer,
                                            args.lr,
                                            decay_type='always',
                                            decay_start_epoch=0,
                                            decay_rate=0.5)
                    optimizer._epoch = n_epochs
                    optimizer._step = n_steps
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))

            if optimizer.n_epochs == args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    reporter.tf_writer.close()
    pbar_epoch.close()

    return save_path
Example #4
def main():

    args = parse()

    hvd.init()
    torch.cuda.set_device(hvd.local_rank())
    hvd_rank = hvd.rank()
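    # NOTE: as in the ASR example above, each Horovod process is pinned to its
    # local GPU and identified by its global rank.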
    # Load a conf file
    if args.resume:
        conf = load_config(os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        dict_path=args.dict,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        wp_model=args.wp_model,
                        batch_size=args.batch_size,
                        n_epochs=args.n_epochs,
                        min_n_tokens=args.min_n_tokens,
                        bptt=args.bptt,
                        n_customers=hvd.size(),
                        backward=args.backward,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      dict_path=args.dict,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      wp_model=args.wp_model,
                      batch_size=args.batch_size,
                      bptt=args.bptt,
                      n_customers=hvd.size(),
                      backward=args.backward,
                      serialize=args.serialize)

    eval_set = Dataset(corpus=args.corpus,
                       tsv_path=args.eval_set,
                       dict_path=args.dict,
                       nlsyms=args.nlsyms,
                       unit=args.unit,
                       wp_model=args.wp_model,
                       batch_size=args.batch_size,
                       bptt=args.bptt,
                       n_customers=hvd.size(),
                       backward=args.backward,
                       serialize=args.serialize)

    args.vocab = train_set.vocab

    train_loader = ChunkDataloader(train_set,
                                   batch_size=1,
                                   num_workers=1,
                                   distributed=True,
                                   shuffle=False)

    eval_loader = ChunkDataloader(eval_set,
                                  batch_size=1,
                                  num_workers=1,
                                  distributed=True)

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = set_lm_name(args)
        save_path = mkdir_join(args.model_save_dir, '_'.join(
            os.path.basename(args.train_set).split('.')[:-1]), dir_name)
        if hvd.rank() == 0:
            save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    if hvd_rank == 0:
        logger = set_logger(os.path.join(save_path, 'train.log'),
                            key='training', stdout=args.stdout)
        # Set process name
        logger.info('PID: %s' % os.getpid())
        logger.info('HOSTNAME: %s' % os.uname()[1])
        logger.info('NUMBER_DEVICES: %s' % hvd.size())
    setproctitle(args.job_name if args.job_name else dir_name)
    # Model setting
    model = build_lm(args, save_path)
    # GPU setting
    if args.n_gpus >= 1:
        torch.backends.cudnn.benchmark = True
        model.cuda()

    if args.resume:
        # Set optimizer
        epoch = int(args.resume.split('-')[-1])
        optimizer = set_optimizer(model, 'sgd' if epoch > conf['convert_to_sgd_epoch'] else conf['optimizer'],
                                  conf['lr'], conf['weight_decay'])

        # Restore the last saved model
        if hvd_rank == 0:
            model, optimizer = load_checkpoint(model, args.resume, optimizer, resume=True)
        # Wrap with Horovod's distributed optimizer and broadcast the restored state from rank 0
        optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())

        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)
        # Wrap optimizer by learning rate scheduler
        optimizer = LRScheduler(optimizer, conf['lr'],
                                decay_type=conf['lr_decay_type'],
                                decay_start_epoch=conf['lr_decay_start_epoch'],
                                decay_rate=conf['lr_decay_rate'],
                                decay_patient_n_epochs=conf['lr_decay_patient_n_epochs'],
                                early_stop_patient_n_epochs=conf['early_stop_patient_n_epochs'],
                                warmup_start_lr=conf['warmup_start_lr'],
                                warmup_n_steps=conf['warmup_n_steps'],
                                model_size=conf['d_model'],
                                factor=conf['lr_factor'],
                                noam=conf['lm_type'] == 'transformer')

        # Resume between convert_to_sgd_epoch - 1 and convert_to_sgd_epoch
        if epoch == conf['convert_to_sgd_epoch']:
            n_epochs = optimizer.n_epochs
            n_steps = optimizer.n_steps
            optimizer = set_optimizer(model, 'sgd', args.lr, conf['weight_decay'])
            optimizer = LRScheduler(optimizer, args.lr,
                                    decay_type='always',
                                    decay_start_epoch=0,
                                    decay_rate=0.5)
            optimizer._epoch = n_epochs
            optimizer._step = n_steps
            if hvd_rank == 0:
                logger.info('========== Convert to SGD ==========')
            # Broadcast the converted optimizer state from rank 0 to all workers
            optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
            hvd.broadcast_parameters(model.state_dict(), root_rank=0)
            hvd.broadcast_optimizer_state(optimizer, root_rank=0)
    else:
        # Save the conf file as a yaml file
        if hvd_rank == 0:
            save_config(vars(args), os.path.join(save_path, 'conf.yml'))
            # Save the nlsyms, dictionary, and wp_model
            if args.nlsyms:
                shutil.copy(args.nlsyms, os.path.join(save_path, 'nlsyms.txt'))
            shutil.copy(args.dict, os.path.join(save_path, 'dict.txt'))
            if args.unit == 'wp':
                shutil.copy(args.wp_model, os.path.join(save_path, 'wp.model'))
            for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
                logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        if hvd_rank == 0:
            for n in sorted(model.num_params_dict.keys()):
                n_params = model.num_params_dict[n]
                logger.info("%s %d" % (n, n_params))
            logger.info("Total %.2f M parameters" % (model.total_parameters / 1000000))
            logger.info(model)

        # Set optimizer
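        # Standard Horovod set-up: broadcast the freshly initialized weights from
        # rank 0 so all workers start from identical parameters, wrap the local
        # optimizer in hvd.DistributedOptimizer so gradients are averaged across
        # workers via allreduce at every step, then broadcast the optimizer state
        # from rank 0 as well.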
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        optimizer = set_optimizer(model, args.optimizer, args.lr, args.weight_decay)
        optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)
        # Wrap optimizer by learning rate scheduler
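        # With lm_type == 'transformer', noam=True presumably selects the
        # inverse-square-root warmup schedule from "Attention Is All You Need":
        # lr = factor * d_model**-0.5 * min(step**-0.5, step * warmup_n_steps**-1.5)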
        optimizer = LRScheduler(optimizer, args.lr,
                                decay_type=args.lr_decay_type,
                                decay_start_epoch=args.lr_decay_start_epoch,
                                decay_rate=args.lr_decay_rate,
                                decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
                                early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
                                warmup_start_lr=args.warmup_start_lr,
                                warmup_n_steps=args.warmup_n_steps,
                                model_size=args.d_model,
                                factor=args.lr_factor,
                                noam=args.lm_type == 'transformer')

    # Set reporter
    reporter = Reporter(save_path)

    hidden = None
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    data_size = len(train_set)
    accum_n_tokens = 0
    verbose = 1 if hvd_rank == 0 else 0
    while True:
        model.train()
        with tqdm(total=data_size / hvd.size(),
                  desc='Train Epoch     #{}'.format(optimizer.n_epochs + 1),
                  disable=not verbose) as pbar_epoch:
            # Compute loss in the training set
            for ys_train in train_loader:
                accum_n_tokens += sum([len(y) for y in ys_train])
                optimizer.zero_grad()
                loss, hidden, reporter = model(ys_train, hidden, reporter)
                loss.backward()
                loss = loss.detach()  # Truncate the graph
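                # Gradient accumulation by token count: gradients from several
                # mini-batches are summed until at least `accum_grad_n_tokens`
                # tokens have been processed, then clipped and applied in a
                # single optimizer step (0 disables accumulation).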
                if args.accum_grad_n_tokens == 0 or accum_n_tokens >= args.accum_grad_n_tokens:
                    if args.clip_grad_norm > 0:
                        total_norm = torch.nn.utils.clip_grad_norm_(
                            model.parameters(), args.clip_grad_norm)
                        #reporter.add_tensorboard_scalar('total_norm', total_norm)
                    optimizer.step()
                    optimizer.zero_grad()
                    accum_n_tokens = 0
                loss_train = loss.item()
                del loss
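                # Truncated BPTT: the hidden state is carried over to the next
                # chunk but (presumably) detached from the graph by
                # repackage_state, so backprop never reaches past the chunk
                # boundary.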
                hidden = model.repackage_state(hidden)
                
                if optimizer.n_steps % args.print_step == 0:
                    model.eval()
                    # Compute loss in the dev set
                    ys_dev = dev_set.next()[0]
                    loss, _, reporter = model(ys_dev, None, reporter, is_eval=True)
                    loss_dev = loss.item()
                    del loss

                    duration_step = time.time() - start_time_step
                    if hvd_rank == 0:
                        logger.info("step:%d(ep:%.2f) loss:%.3f(%.3f)/ppl:%.3f(%.3f)/lr:%.5f/bs:%d (%.2f min)" %
                                    (optimizer.n_steps, optimizer.n_steps / data_size * hvd.size(),
                                     loss_train, loss_dev,
                                     np.exp(loss_train), np.exp(loss_dev),
                                     optimizer.lr, len(ys_train), duration_step / 60))
                    start_time_step = time.time()
                    model.train()  # Back to training mode after the dev-loss check

                pbar_epoch.update(1)

            # Save checkpoint and evaluate model per epoch
            duration_epoch = time.time() - start_time_epoch
            if hvd_rank == 0:
                logger.info('========== EPOCH:%d (%.2f min) ==========' % (optimizer.n_epochs + 1, duration_epoch / 60))

            if optimizer.n_epochs + 1 < args.eval_start_epoch:
                # Advance the epoch counter on every rank so the workers stay in
                # sync; save the model from rank 0 only
                optimizer.epoch()
                if hvd_rank == 0:
                    save_checkpoint(model, save_path, optimizer, optimizer.n_epochs,
                                    remove_old_checkpoints=args.lm_type != 'transformer')
            else:
                start_time_eval = time.time()
                # dev
                model.eval()
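                # Each worker computes perplexity on its shard of the held-out
                # data; hvd.allreduce (which averages by default) then yields a
                # single global PPL that every rank uses for LR scheduling and
                # early stopping.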
                ppl_dev, _ = eval_ppl_parallel([model], eval_loader, optimizer.n_epochs, batch_size=args.batch_size)
                ppl_dev = hvd.allreduce(np2tensor(np.array([ppl_dev], dtype=float), hvd.local_rank())).item()

                if hvd_rank == 0:
                    logger.info('PPL: %.2f' % ppl_dev)
                optimizer.epoch(ppl_dev)

                if optimizer.is_best and hvd.rank() == 0:
                    # Save the model
                    save_checkpoint(model, save_path, optimizer, optimizer.n_epochs,
                                    remove_old_checkpoints=args.lm_type != 'transformer')

                duration_eval = time.time() - start_time_eval

                if hvd_rank == 0:
                    logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if optimizer.is_early_stop:
                    break

                # Convert to fine-tuning stage
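                # After `convert_to_sgd_epoch` epochs the adaptive optimizer is
                # replaced with plain SGD whose learning rate is halved every
                # epoch (decay_type='always', decay_rate=0.5), a common trick for
                # fine-tuning an LM at the end of training.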
                if optimizer.n_epochs == args.convert_to_sgd_epoch:

                    n_epochs = optimizer.n_epochs
                    n_steps = optimizer.n_steps
                    optimizer = set_optimizer(model, 'sgd', args.lr, args.weight_decay)

                    optimizer = hvd.DistributedOptimizer(
                                    optimizer, named_parameters=model.named_parameters())
                    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
                    hvd.broadcast_optimizer_state(optimizer, root_rank=0)
                    optimizer = LRScheduler(optimizer, args.lr,
                                            decay_type='always',
                                            decay_start_epoch=0,
                                            decay_rate=0.5)
                    optimizer._epoch = n_epochs
                    optimizer._step = n_steps
                    if hvd_rank == 0:
                        logger.info('========== Convert to SGD ==========')
                if optimizer.n_epochs == args.n_epochs:
                    break

            start_time_step = time.time()
            start_time_epoch = time.time()

    duration_train = time.time() - start_time_train
    if hvd_rank == 0:
        logger.info('Total time: %.2f hour' % (duration_train / 3600))

    reporter.tf_writer.close()

    return save_path
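

# Minimal entry-point sketch; this is an assumption, since the example above
# only shows the body of the training function (named main() here, as in the
# first example). Horovod must be initialized and each process pinned to its
# local GPU before hvd.size()/hvd.rank() are queried.
if __name__ == '__main__':
    hvd.init()
    if torch.cuda.is_available():
        torch.cuda.set_device(hvd.local_rank())
    main()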