Example #1
def get_optim(cfg, model, dataset_iter_num):
    cfg = cfg.OPTIM
    optim_name = cfg.NAME
    optimizer = None
    assert optim_name in ["FusedLAMB", "AdamW", "Adam",
                          "SGD"], "optimizer not allowed"
    parameters = filter(lambda p: p.requires_grad, model.parameters())

    if optim_name == "FusedLAMB":
        optimizer = FusedLAMB(parameters, lr=cfg.INIT_LR, eps=cfg.ADAM_EPSILON)
    if optim_name == "AdamW":
        optimizer = AdamW(parameters, lr=cfg.INIT_LR, eps=cfg.ADAM_EPSILON)
    if optim_name == "Adam":
        optimizer = Adam(parameters, lr=cfg.INIT_LR, eps=cfg.ADAM_EPSILON)
    if optim_name == "SGD":
        optimizer = SGD(parameters, lr=cfg.INIT_LR, momentum=cfg.SGD_MOMENTUM)
    warmup_step = int(cfg.WARM_UP_EPOCH * dataset_iter_num)
    max_step = cfg.MAX_EPOCH * dataset_iter_num

    if cfg.USE_LR_SCHEDULER:
        if cfg.LR_SCHEDULER_TYPE == "get_exponent_schedule_with_warmup":
            scheduler = get_exponent_schedule_with_warmup(
                optimizer, warmup_step, exponent=cfg.EXPONENT)
        else:
            scheduler = globals()[cfg.LR_SCHEDULER_TYPE](optimizer,
                                                         warmup_step, max_step)
    else:
        scheduler = None

    return optimizer, scheduler
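
The get_exponent_schedule_with_warmup factory referenced above is project-specific and its implementation is not shown here. Below is a minimal sketch of what such a warmup-then-exponential-decay schedule might compute, built on torch.optim.lr_scheduler.LambdaLR; the exponent semantics are an assumption, not the project's actual code.

from torch.optim.lr_scheduler import LambdaLR

def get_exponent_schedule_with_warmup(optimizer, warmup_step, exponent=0.999):
    # Linear warmup to the base LR over warmup_step steps, then decay the LR
    # by a factor of `exponent` on every subsequent step.
    def lr_lambda(step):
        if step < warmup_step:
            return float(step) / float(max(1, warmup_step))
        return exponent ** (step - warmup_step)
    return LambdaLR(optimizer, lr_lambda)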
Example #2
    def __init__(self, args, params):
        super().__init__(args)
        try:
            from apex.optimizers import FusedLAMB
            self._optimizer = FusedLAMB(params, **self.optimizer_config)
        except ImportError:
            raise ImportError('Please install apex to use LAMB optimizer')
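
The snippet above fails hard when apex is missing. A hedged alternative (a hypothetical helper, not part of the snippet) falls back to plain torch.optim.AdamW when FusedLAMB cannot be imported:

def build_lamb_or_adamw(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01):
    # Prefer apex's fused LAMB when available; otherwise fall back to AdamW.
    try:
        from apex.optimizers import FusedLAMB
        return FusedLAMB(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
    except ImportError:
        from torch.optim import AdamW
        return AdamW(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)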
Example #3
    def configure_optimizers(self):
        if self.optimizer == "adam":
            optimizer = torch.optim.AdamW(self.parameters(),
                                          lr=self.learning_rate,
                                          weight_decay=0.0)
        elif self.optimizer == "lamb":
            optimizer = FusedLAMB(
                self.parameters(),
                lr=self.learning_rate,
                weight_decay=0.0,
            )
        elif self.optimizer == "gremlin":
            from ..optim import GremlinAdam

            optimizer = GremlinAdam(
                [{
                    "params": self.parameters(),
                    "gremlin": True
                }],
                lr=self.learning_rate,
            )
        else:
            raise ValueError(f"Unrecognized optimizer {self.optimizer}")

        lr_scheduler = lr_schedulers.get(self.lr_scheduler)(
            optimizer, self.warmup_steps, self.trainer.max_steps)
        scheduler_dict = {
            "scheduler": lr_scheduler,
            "interval": "step",
        }
        return [optimizer], [scheduler_dict]
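
The lr_schedulers registry used in the call above is not defined in this snippet. One plausible sketch, assuming it maps scheduler names to factories called as (optimizer, warmup_steps, max_steps), can be built from Hugging Face transformers helpers:

from transformers import (get_linear_schedule_with_warmup,
                          get_cosine_schedule_with_warmup)

# name -> factory taking (optimizer, num_warmup_steps, num_training_steps)
lr_schedulers = {
    "linear": get_linear_schedule_with_warmup,
    "cosine": get_cosine_schedule_with_warmup,
}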
Example #4
    def configure_optimizers(self):
        if self.optimizer == "adam":
            optimizer = torch.optim.AdamW(
                self.parameters(), lr=self.learning_rate, weight_decay=self.l2_coeff
            )
        elif self.optimizer == "lamb":
            optimizer = FusedLAMB(
                self.parameters(),
                lr=self.learning_rate,
                weight_decay=self.l2_coeff,
            )
        else:
            raise ValueError(f"Unrecognized optimizer {self.optimizer}")
        return [optimizer]
Example #5
def get_optimizer(optimizer_name: str, parameters, learning_rate: float, weight_decay=0.0, **kwargs):
    if optimizer_name.lower() == "sgd":
        return SGD(parameters, learning_rate, momentum=0.9, weight_decay=weight_decay, **kwargs)

    if optimizer_name.lower() == "adam":
        return Adam(parameters, learning_rate, weight_decay=weight_decay, eps=1e-5, **kwargs)  # As Jeremy suggests

    if optimizer_name.lower() == "rms":
        return RMSprop(parameters, learning_rate, weight_decay=weight_decay, **kwargs)

    if optimizer_name.lower() == "adamw":
        return AdamW(parameters, learning_rate, weight_decay=weight_decay, eps=1e-5, **kwargs)

    if optimizer_name.lower() == "radam":
        return RAdam(parameters, learning_rate, weight_decay=weight_decay, eps=1e-5, **kwargs)  # As Jeremy suggests

    if optimizer_name.lower() == "ranger":
        return Ranger(parameters, learning_rate, weight_decay=weight_decay, **kwargs)

    # if optimizer_name.lower() == "qhadamw":
    #     return QHAdamW(parameters, learning_rate, weight_decay=weight_decay,
    #                    **kwargs)
    #
    if optimizer_name.lower() == "lamb":
        return Lamb(parameters, learning_rate, weight_decay=weight_decay, **kwargs)

    if optimizer_name.lower() == "fused_lamb":
        from apex.optimizers import FusedLAMB

        return FusedLAMB(parameters, learning_rate, weight_decay=weight_decay, **kwargs)

    if optimizer_name.lower() == "fused_adam":
        from apex.optimizers import FusedAdam

        return FusedAdam(parameters, learning_rate, eps=1e-5, weight_decay=weight_decay, adam_w_mode=True, **kwargs)

    if optimizer_name.lower() == "fused_sgd":
        from apex.optimizers import FusedSGD

        return FusedSGD(parameters, learning_rate, weight_decay=weight_decay, momentum=0.9, **kwargs)

    if optimizer_name.lower() == "diffgrad":
        return DiffGrad(parameters, learning_rate, eps=1e-5, weight_decay=weight_decay, **kwargs)

    if optimizer_name.lower() == "novograd":
        return Novograd(parameters, learning_rate, eps=1e-5, weight_decay=weight_decay, **kwargs)

    raise ValueError("Unsupported optimizer name " + optimizer_name)
Example #6
    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                self.hparams.weight_decay,
            },
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0,
            },
        ]
        if self.hparams.lamb:
            optimizer = FusedLAMB(optimizer_grouped_parameters,
                                  lr=self.hparams.learning_rate,
                                  eps=self.hparams.adam_epsilon)

        elif self.hparams.adafactor:
            optimizer = Adafactor(optimizer_grouped_parameters,
                                  lr=self.hparams.learning_rate,
                                  scale_parameter=False,
                                  relative_step=False)
        else:
            optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=self.hparams.learning_rate,
                                  eps=self.hparams.adam_epsilon)
        self.opt = optimizer

        scheduler = self.get_lr_scheduler()

        return [optimizer], [scheduler]
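
The get_lr_scheduler helper called above is not part of this snippet. A minimal sketch consistent with the "(linear warmup and decay)" docstring, assuming hparams.warmup_steps and a total_steps attribute exist and using the transformers linear schedule:

    def get_lr_scheduler(self):
        # warmup_steps and total_steps are assumed hyperparameters
        from transformers import get_linear_schedule_with_warmup
        scheduler = get_linear_schedule_with_warmup(
            self.opt,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.total_steps,
        )
        return {"scheduler": scheduler, "interval": "step", "frequency": 1}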
Example #7
def create_optimizer(args, model, filter_bias_and_bn=True):
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if weight_decay and filter_bias_and_bn:
        skip = {}
        if hasattr(model, 'no_weight_decay'):
            skip = model.no_weight_decay()
        parameters = add_weight_decay(model, weight_decay, skip)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
        opt_args['eps'] = args.opt_eps
    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
        opt_args['betas'] = args.opt_betas
    if hasattr(args, 'opt_args') and args.opt_args is not None:
        opt_args.update(args.opt_args)

    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters,
                              momentum=args.momentum,
                              nesterov=True,
                              **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters,
                              momentum=args.momentum,
                              nesterov=False,
                              **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamp':
        # ================================
        # optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)

        print(' ')
        print('Gradient centralization is enabled for AdamP optimizer.')
        print(' ')

        optimizer = AdamP(parameters,
                          wd_ratio=0.01,
                          nesterov=True,
                          use_gc=True,
                          gc_conv_only=True,
                          gc_loc=False,
                          **opt_args)
        # ================================
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters,
                         momentum=args.momentum,
                         nesterov=True,
                         **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        if not args.lr:
            opt_args['lr'] = None
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters,
                                  alpha=0.9,
                                  momentum=args.momentum,
                                  **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters,
                              alpha=0.9,
                              momentum=args.momentum,
                              **opt_args)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters, **opt_args)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters,
                             momentum=args.momentum,
                             nesterov=True,
                             **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters,
                             momentum=args.momentum,
                             nesterov=False,
                             **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        raise ValueError(f"Invalid optimizer: {opt_lower}")

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
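
A usage sketch for the factory above: the opt string is split on "_", so a "lookahead_" prefix wraps the chosen base optimizer, while the fused* names require apex and CUDA. The args fields and the SimpleNamespace stand-in are assumptions about the calling convention, and the call relies on the same repo-level helpers (add_weight_decay, RAdam, Lookahead) the function itself uses:

from types import SimpleNamespace
import torch.nn as nn

args = SimpleNamespace(opt="lookahead_radam", lr=1e-3, weight_decay=1e-4,
                       momentum=0.9, opt_eps=None, opt_betas=None, opt_args=None)
optimizer = create_optimizer(args, nn.Linear(16, 4))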
Example #8
def create_optimizer(args, model, filter_bias_and_bn=True):
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if 'adamw' in opt_lower or 'radam' in opt_lower:
        # Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay
        # I don't believe they follow the paper or original Torch7 impl which schedules weight
        # decay based on the ratio of current_lr/initial_lr
        weight_decay /= args.lr
    if weight_decay and filter_bias_and_bn:
        print("has weight decay and filter bias")
        parameters = add_weight_decay(model, weight_decay)
        weight_decay = 0.
    else:
        print("Comes here to unfrozen params inside optim")

        parameters = unfrozen_params(model)

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        optimizer = optim.SGD(parameters,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=weight_decay,
                              nesterov=True)
    elif opt_lower == 'momentum':
        optimizer = optim.SGD(parameters,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=weight_decay,
                              nesterov=False)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters,
                               lr=args.lr,
                               weight_decay=weight_decay,
                               eps=args.opt_eps)
    elif opt_lower == 'adamw':
        optimizer = AdamW(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters,
                                   lr=args.lr,
                                   weight_decay=weight_decay,
                                   eps=args.opt_eps)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters,
                                  lr=args.lr,
                                  alpha=0.9,
                                  eps=args.opt_eps,
                                  momentum=args.momentum,
                                  weight_decay=weight_decay)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters,
                              lr=args.lr,
                              alpha=0.9,
                              eps=args.opt_eps,
                              momentum=args.momentum,
                              weight_decay=weight_decay)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters,
                             lr=args.lr,
                             weight_decay=weight_decay,
                             eps=args.opt_eps)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters,
                               lr=args.lr,
                               weight_decay=weight_decay,
                               eps=args.opt_eps)
    elif opt_lower == 'fusedsgd':
        optimizer = FusedSGD(parameters,
                             lr=args.lr,
                             momentum=args.momentum,
                             weight_decay=weight_decay,
                             nesterov=True)
    elif opt_lower == 'fusedmomentum':
        print("my optimizer")
        optimizer = FusedSGD(parameters,
                             lr=args.lr,
                             momentum=args.momentum,
                             weight_decay=weight_decay,
                             nesterov=False)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters,
                              lr=args.lr,
                              adam_w_mode=False,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters,
                              lr=args.lr,
                              adam_w_mode=True,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters,
                              lr=args.lr,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusednovograd':
        optimizer = FusedNovoGrad(parameters,
                                  lr=args.lr,
                                  betas=(0.95, 0.98),
                                  weight_decay=weight_decay,
                                  eps=args.opt_eps)
    else:
        raise ValueError(f"Invalid optimizer: {opt_lower}")

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
Example #9
def train(args, teacher_args):
    """Train FCL-taco2 model."""
    set_deterministic_pytorch(args)
    # args.use_fe_condition = True
    # # pre-occupy GPU
    # buff = torch.randn(int(1e9)).cuda()
    # del buff
    
    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())

    # reverse input and output dimension
    idim = int(valid_json[utts[0]]["output"][0]["shape"][1])
    odim = int(valid_json[utts[0]]["input"][0]["shape"][1])
    logging.info("#input dims: " + str(idim))
    logging.info("#output dims: " + str(odim))

    # get extra input and output dimension
    if args.use_speaker_embedding:
        args.spk_embed_dim = int(valid_json[utts[0]]["input"][1]["shape"][0])
    else:
        args.spk_embed_dim = None
    if args.use_second_target:
        args.spc_dim = int(valid_json[utts[0]]["input"][1]["shape"][1])
    else:
        args.spc_dim = None

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to" + model_conf)
        f.write(
            json.dumps(
                (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    # specify model architecture
    if args.enc_init is not None or args.dec_init is not None:
        model = load_trained_modules(idim, odim, args, TTSInterface)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args, args, teacher_args=teacher_args)
    
    #print('\n\nteacher_args:', teacher_args.embed_dim, '\n\n')
    teacher_model_class = dynamic_import(teacher_args.model_module)
    teacher_model = teacher_model_class(idim, odim, teacher_args, teacher_args)
    #teacher_model = teacher_model.to('cuda')
    if teacher_args.amp_checkpoint is None:
        raise ValueError('please provide the teacher-model-amp-checkpoint')
    else:
        logging.info("teacher-model resumed from %s" % teacher_args.amp_checkpoint)
        teacher_checkpoint = torch.load(teacher_args.amp_checkpoint)
        teacher_model.load_state_dict(teacher_checkpoint['model'])

    # print('tts_wds:', model.base_plot_keys)
    assert isinstance(model, TTSInterface)
    logging.info(model)
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        # model = torch.nn.DataParallel(model, device_ids=[4,5,6,7])
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)"
                % (args.batch_size, args.batch_size * args.ngpu)
            )
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)
    teacher_model = teacher_model.to(device)
    for param in teacher_model.parameters(): # fix teacher model params
        param.requires_grad = False

    # freeze modules, if specified
    if args.freeze_mods:
        if hasattr(model, "module"):
            freeze_mods = ["module." + x for x in args.freeze_mods]
        else:
            freeze_mods = args.freeze_mods

        for mod, param in model.named_parameters():
            if any(mod.startswith(key) for key in freeze_mods):
                logging.info(f"{mod} is frozen not to be updated.")
                param.requires_grad = False

        model_params = filter(lambda x: x.requires_grad, model.parameters())
    else:
        model_params = model.parameters()

    # Setup an optimizer
    if args.opt == "adam":
        optimizer = torch.optim.Adam(
            model_params, args.lr, eps=args.eps, weight_decay=args.weight_decay
        )
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(
            model_params, args.adim, args.transformer_warmup_steps, args.transformer_lr
        )
    elif args.opt == 'lamb':
        from apex.optimizers import FusedLAMB
        kw = dict(lr=0.1, betas=(0.9, 0.98), eps=1e-9, weight_decay=1e-6)
        # use model_params so that any parameters frozen above are excluded
        optimizer = FusedLAMB(model_params, **kw)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)
    
    if args.use_amp:
        opt_level = 'O1'
        model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
    
    if args.amp_checkpoint is not None:
        logging.info("resumed from %s" % args.amp_checkpoint)
        checkpoint = torch.load(args.amp_checkpoint)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        amp.load_state_dict(checkpoint['amp'])
        
    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # read json data
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    
    num_batches = len(train_json.keys()) // args.batch_size
    
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    if use_sortagrad:
        args.batch_sort_key = "input"
        
    print(f'\n\n batch_sort_key: {args.batch_sort_key} \n\n')
    
    # make minibatch list (variable length)
    train_batchset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0,
    )
    valid_batchset = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0,
    )
    
    
    from io_utils_fcl import LoadInputsAndTargets
    
    load_tr = LoadInputsAndTargets(
        mode="tts",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
        pad_eos=args.pad_eos,
    )

    load_cv = LoadInputsAndTargets(
        mode="tts",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
        pad_eos=args.pad_eos,
    )

    converter = CustomConverter(reduction_factor=args.reduction_factor,
                                use_fe_condition=args.use_fe_condition,
                                append_position=args.append_position,
                                )
    # hack to make the batch_size argument 1
    # the actual batch size is included in the list
    train_iter = {
        "main": ChainerDataLoader(
            dataset=TransformDataset(
                train_batchset, lambda data: converter([load_tr(data)])
            ),
            batch_size=1,
            num_workers=args.num_iter_processes,
            shuffle=not use_sortagrad,
            collate_fn=lambda x: x[0],
        )
    }
    valid_iter = {
        "main": ChainerDataLoader(
            dataset=TransformDataset(
                valid_batchset, lambda data: converter([load_cv(data)])
            ),
            batch_size=1,
            shuffle=False,
            collate_fn=lambda x: x[0],
            num_workers=args.num_iter_processes,
        )
    }

    # Set up a trainer
    updater = CustomUpdater(
        teacher_model, model, args.grad_clip, train_iter, optimizer, device, args.accum_grad, args.use_amp, num_batches, args.outdir
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # set intervals
    eval_interval = (args.eval_interval_epochs, "epoch")
    save_interval = (args.save_interval_epochs, "epoch")
    report_interval = (args.report_interval_iters, "iteration")

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(teacher_model, model, valid_iter, reporter, device), trigger=eval_interval
    )

    # Save snapshot for each epoch
    trainer.extend(torch_snapshot(), trigger=save_interval)

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger(
            "validation/main/loss", trigger=eval_interval
        ),
    )


    # Make a plot for training and validation values
    if hasattr(model, "module"):
        base_plot_keys = model.module.base_plot_keys
    else:
        base_plot_keys = model.base_plot_keys
    plot_keys = []
    for key in base_plot_keys:
        plot_key = ["main/" + key, "validation/main/" + key]
        trainer.extend(
            extensions.PlotReport(plot_key, "epoch", file_name=key + ".png"),
            trigger=eval_interval,
        )
        plot_keys += plot_key
    trainer.extend(
        extensions.PlotReport(plot_keys, "epoch", file_name="all_loss.png"),
        trigger=eval_interval,
    )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=report_interval))
    report_keys = ["epoch", "iteration", "elapsed_time"] + plot_keys
    trainer.extend(extensions.PrintReport(report_keys), trigger=report_interval)
    trainer.extend(extensions.ProgressBar(), trigger=report_interval)

    set_early_stop(trainer, args)
    # if args.tensorboard_dir is not None and args.tensorboard_dir != "":
    #     writer = SummaryWriter(args.tensorboard_dir)
    #     trainer.extend(TensorboardLogger(writer, att_reporter), trigger=report_interval)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Example #10
def create_optimizer_v2(model_or_params,
                        opt: str = 'sgd',
                        lr: Optional[float] = None,
                        weight_decay: float = 0.,
                        momentum: float = 0.9,
                        filter_bias_and_bn: bool = True,
                        layer_decay: Optional[float] = None,
                        param_group_fn: Optional[Callable] = None,
                        **kwargs):
    """ Create an optimizer.

    TODO currently the model is passed in and all parameters are selected for optimization.
    For more general use an interface that allows selection of parameters to optimize and lr groups, one of:
      * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion
      * expose the parameters interface and leave it up to caller

    Args:
        model_or_params (nn.Module): model containing parameters to optimize
        opt: name of optimizer to create
        lr: initial learning rate
        weight_decay: weight decay to apply in optimizer
        momentum:  momentum for momentum based optimizers (others may use betas via kwargs)
        filter_bias_and_bn:  filter out bias, bn and other 1d params from weight decay
        **kwargs: extra optimizer specific kwargs to pass through

    Returns:
        Optimizer
    """
    if isinstance(model_or_params, nn.Module):
        # a model was passed in, extract parameters and add weight decays to appropriate layers
        no_weight_decay = {}
        if hasattr(model_or_params, 'no_weight_decay'):
            no_weight_decay = model_or_params.no_weight_decay()

        if param_group_fn:
            parameters = param_group_fn(model_or_params)
        elif layer_decay is not None:
            parameters = param_groups_layer_decay(
                model_or_params,
                weight_decay=weight_decay,
                layer_decay=layer_decay,
                no_weight_decay_list=no_weight_decay)
            weight_decay = 0.
        elif weight_decay and filter_bias_and_bn:
            parameters = param_groups_weight_decay(model_or_params,
                                                   weight_decay,
                                                   no_weight_decay)
            weight_decay = 0.
        else:
            parameters = model_or_params.parameters()
    else:
        # iterable of parameters or param groups passed in
        parameters = model_or_params

    opt_lower = opt.lower()
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(weight_decay=weight_decay, **kwargs)
    if lr is not None:
        opt_args.setdefault('lr', lr)

    # basic SGD & related
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters,
                              momentum=momentum,
                              nesterov=True,
                              **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters,
                              momentum=momentum,
                              nesterov=False,
                              **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters,
                         momentum=momentum,
                         nesterov=True,
                         **opt_args)

    # adaptive
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'nadam':
        try:
            # NOTE PyTorch >= 1.10 should have native NAdam
            optimizer = optim.Nadam(parameters, **opt_args)
        except AttributeError:
            optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamax':
        optimizer = optim.Adamax(parameters, **opt_args)
    elif opt_lower == 'adabelief':
        optimizer = AdaBelief(parameters, rectify=False, **opt_args)
    elif opt_lower == 'radabelief':
        optimizer = AdaBelief(parameters, rectify=True, **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adagrad':
        opt_args.setdefault('eps', 1e-8)
        optimizer = optim.Adagrad(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'lamb':
        optimizer = Lamb(parameters, **opt_args)
    elif opt_lower == 'lambc':
        optimizer = Lamb(parameters, trust_clip=True, **opt_args)
    elif opt_lower == 'larc':
        optimizer = Lars(parameters,
                         momentum=momentum,
                         trust_clip=True,
                         **opt_args)
    elif opt_lower == 'lars':
        optimizer = Lars(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'nlarc':
        optimizer = Lars(parameters,
                         momentum=momentum,
                         trust_clip=True,
                         nesterov=True,
                         **opt_args)
    elif opt_lower == 'nlars':
        optimizer = Lars(parameters,
                         momentum=momentum,
                         nesterov=True,
                         **opt_args)
    elif opt_lower == 'madgrad':
        optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'madgradw':
        optimizer = MADGRAD(parameters,
                            momentum=momentum,
                            decoupled_decay=True,
                            **opt_args)
    elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters,
                                  alpha=0.9,
                                  momentum=momentum,
                                  **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters,
                              alpha=0.9,
                              momentum=momentum,
                              **opt_args)

    # second order
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)

    # NVIDIA fused optimizers, require APEX to be installed
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters,
                             momentum=momentum,
                             nesterov=True,
                             **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters,
                             momentum=momentum,
                             nesterov=False,
                             **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)

    else:
        raise ValueError(f"Invalid optimizer: {opt_lower}")

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
Example #11
def main():
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    distributed_run = args.world_size > 1

    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)

    if args.local_rank == 0:
        if not os.path.exists(args.output):
            os.makedirs(args.output)

    log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
    tb_subsets = ['train', 'val']
    if args.ema_decay > 0.0:
        tb_subsets.append('val_ema')

    logger.init(log_fpath,
                args.output,
                enabled=(args.local_rank == 0),
                tb_subsets=tb_subsets)
    logger.parameters(vars(args), tb_subset='train')

    parser = models.parse_model_args('FastPitch', parser)
    args, unk_args = parser.parse_known_args()
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, args.world_size, args.local_rank)

    device = torch.device('cuda' if args.cuda else 'cpu')
    model_config = models.get_model_config('FastPitch', args)
    model = models.get_model('FastPitch', model_config, device)

    # Store pitch mean/std as params to translate from Hz during inference
    with open(args.pitch_mean_std_file, 'r') as f:
        stats = json.load(f)
    model.pitch_mean[0] = stats['mean']
    model.pitch_std[0] = stats['std']

    kw = dict(lr=args.learning_rate,
              betas=(0.9, 0.98),
              eps=1e-9,
              weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    else:
        raise ValueError(f'Unknown optimizer: {args.optimizer}')

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    #if args.amp:
    #model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    if args.ema_decay > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if distributed_run:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank,
                                        find_unused_parameters=True)

    start_epoch = [1]
    start_iter = [0]

    assert args.checkpoint_path is None or args.resume is False, (
        "Specify a single checkpoint source")
    if args.checkpoint_path is not None:
        ch_fpath = args.checkpoint_path
    elif args.resume:
        ch_fpath = last_checkpoint(args.output)
    else:
        ch_fpath = None

    if ch_fpath is not None:
        load_checkpoint(args.local_rank, model, ema_model, optimizer,
                        start_epoch, start_iter, model_config, args.amp,
                        ch_fpath, args.world_size)

    start_epoch = start_epoch[0]
    total_iter = start_iter[0]

    criterion = loss_functions.get_loss_function(
        'FastPitch',
        dur_predictor_loss_scale=args.dur_predictor_loss_scale,
        pitch_predictor_loss_scale=args.pitch_predictor_loss_scale)

    collate_fn = data_functions.get_collate_function('FastPitch')
    trainset = data_functions.get_data_loader('FastPitch', args.dataset_path,
                                              args.training_files, args)
    valset = data_functions.get_data_loader('FastPitch', args.dataset_path,
                                            args.validation_files, args)
    if distributed_run:
        train_sampler, shuffle = DistributedSampler(trainset), False
    else:
        train_sampler, shuffle = None, True

    train_loader = DataLoader(trainset,
                              num_workers=16,
                              shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              pin_memory=False,
                              drop_last=True,
                              collate_fn=collate_fn)

    batch_to_gpu = data_functions.get_batch_to_gpu('FastPitch')

    model.train()

    torch.cuda.synchronize()
    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start_time = time.perf_counter()

        epoch_loss = 0.0
        epoch_mel_loss = 0.0
        epoch_num_frames = 0
        epoch_frames_per_sec = 0.0

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        accumulated_steps = 0
        iter_loss = 0
        iter_num_frames = 0
        iter_meta = {}

        epoch_iter = 0
        num_iters = len(train_loader) // args.gradient_accumulation_steps
        for batch in train_loader:

            if accumulated_steps == 0:
                if epoch_iter == num_iters:
                    break
                total_iter += 1
                epoch_iter += 1
                iter_start_time = time.perf_counter()

                adjust_learning_rate(total_iter, optimizer, args.learning_rate,
                                     args.warmup_steps)

                model.zero_grad()

            x, y, num_frames = batch_to_gpu(batch)

            #AMP upstream autocast
            with torch.cuda.amp.autocast(enabled=args.amp):
                y_pred = model(x, use_gt_durations=True)
                loss, meta = criterion(y_pred, y)

                loss /= args.gradient_accumulation_steps
            meta = {
                k: v / args.gradient_accumulation_steps
                for k, v in meta.items()
            }

            if args.amp:
                #with amp.scale_loss(loss, optimizer) as scaled_loss:
                #scaled_loss.backward()
                scaler.scale(loss).backward()
            else:
                loss.backward()

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
                meta = {
                    k: reduce_tensor(v, args.world_size)
                    for k, v in meta.items()
                }
            else:
                reduced_loss = loss.item()
                reduced_num_frames = num_frames.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            accumulated_steps += 1
            iter_loss += reduced_loss
            iter_num_frames += reduced_num_frames
            iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}

            if accumulated_steps % args.gradient_accumulation_steps == 0:

                logger.log_grads_tb(total_iter, model)
                if args.amp:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.grad_clip_thresh)
                    scaler.step(optimizer)
                    scaler.update()
                    #optimizer.zero_grad(set_to_none=True)
                    optimizer.zero_grad()
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.grad_clip_thresh)

                    optimizer.step()
                apply_ema_decay(model, ema_model, args.ema_decay)

                iter_time = time.perf_counter() - iter_start_time
                iter_mel_loss = iter_meta['mel_loss'].item()
                epoch_frames_per_sec += iter_num_frames / iter_time
                epoch_loss += iter_loss
                epoch_num_frames += iter_num_frames
                epoch_mel_loss += iter_mel_loss

                logger.log(
                    (epoch, epoch_iter, num_iters),
                    tb_total_steps=total_iter,
                    subset='train',
                    data=OrderedDict([
                        ('loss', iter_loss), ('mel_loss', iter_mel_loss),
                        ('frames/s', iter_num_frames / iter_time),
                        ('took', iter_time),
                        ('lrate', optimizer.param_groups[0]['lr'])
                    ]),
                )

                accumulated_steps = 0
                iter_loss = 0
                iter_num_frames = 0
                iter_meta = {}

        # Finished epoch
        epoch_time = time.perf_counter() - epoch_start_time

        logger.log(
            (epoch, ),
            tb_total_steps=None,
            subset='train_avg',
            data=OrderedDict([('loss', epoch_loss / epoch_iter),
                              ('mel_loss', epoch_mel_loss / epoch_iter),
                              ('frames/s', epoch_num_frames / epoch_time),
                              ('took', epoch_time)]),
        )

        validate(model,
                 epoch,
                 total_iter,
                 criterion,
                 valset,
                 args.batch_size,
                 collate_fn,
                 distributed_run,
                 batch_to_gpu,
                 use_gt_durations=True)

        if args.ema_decay > 0:
            validate(ema_model,
                     epoch,
                     total_iter,
                     criterion,
                     valset,
                     args.batch_size,
                     collate_fn,
                     distributed_run,
                     batch_to_gpu,
                     use_gt_durations=True,
                     ema=True)

        if (epoch > 0 and args.epochs_per_checkpoint > 0
                and (epoch % args.epochs_per_checkpoint == 0)
                and args.local_rank == 0):

            checkpoint_path = os.path.join(args.output,
                                           f"FastPitch_checkpoint_{epoch}.pt")
            save_checkpoint(args.local_rank, model, ema_model, optimizer,
                            scaler, epoch, total_iter, model_config, args.amp,
                            checkpoint_path)
        logger.flush()

    # Finished training
    logger.log(
        (),
        tb_total_steps=None,
        subset='train_avg',
        data=OrderedDict([('loss', epoch_loss / epoch_iter),
                          ('mel_loss', epoch_mel_loss / epoch_iter),
                          ('frames/s', epoch_num_frames / epoch_time),
                          ('took', epoch_time)]),
    )
    validate(model,
             None,
             total_iter,
             criterion,
             valset,
             args.batch_size,
             collate_fn,
             distributed_run,
             batch_to_gpu,
             use_gt_durations=True)

    if (epoch > 0 and args.epochs_per_checkpoint > 0
            and (epoch % args.epochs_per_checkpoint != 0)
            and args.local_rank == 0):
        checkpoint_path = os.path.join(args.output,
                                       f"FastPitch_checkpoint_{epoch}.pt")
        save_checkpoint(args.local_rank, model, ema_model, optimizer, scaler,
                        epoch, total_iter, model_config, args.amp,
                        checkpoint_path)
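
A condensed, self-contained sketch of the AMP plus gradient-accumulation pattern used in the training loop above (toy model and data; GradScaler only takes effect on CUDA, so it is disabled here to keep the sketch CPU-runnable):

import torch

model = torch.nn.Linear(8, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=False)  # set enabled=True with CUDA + AMP
accum = 2  # gradient accumulation steps

for step, x in enumerate(torch.randn(4, 2, 8)):
    with torch.cuda.amp.autocast(enabled=False):    # enabled=True under AMP
        loss = model(x).mean() / accum              # scale loss by accumulation steps
    scaler.scale(loss).backward()
    if (step + 1) % accum == 0:
        scaler.unscale_(opt)                        # unscale before gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()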
Example #12
def run(config, args):

    if 'LOCAL_RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
    else:
        local_rank = args.rank
        world_size = args.world_size
    distributed_run = world_size > 1

    torch.manual_seed(args.seed + local_rank)
    np.random.seed(args.seed + local_rank)

    #    if local_rank == 0:
    #        if not os.path.exists(args.output):
    #            os.makedirs(args.output)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False

    if distributed_run:
        init_distributed(args, world_size, local_rank)

    device = torch.device('cuda' if args.cuda else 'cpu')

    if local_rank == 0:
        print("start training")
        print("args", args)
        print("config", config)

    #############################################
    # model
    if local_rank == 0:
        print("load model")
    model = WaveGrad(config).cuda()

    my_schedule = model.set_new_noise_schedule
    compute_loss = model.compute_loss
    # if local_rank == 0:
    #     print(model)

    # optimizer amp config
    if local_rank == 0:
        print("configure optimizer and amp")
    kw = dict(lr=args.learning_rate,
              betas=(0.9, 0.98),
              eps=1e-9,
              weight_decay=args.weight_decay)

    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    elif args.optimizer == 'pytorch':
        optimizer = torch.optim.Adam(model.parameters(), **kw)
    else:
        raise ValueError(f'Unknown optimizer: {args.optimizer}')

    if args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    if distributed_run:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank,
                                        find_unused_parameters=True)
    start_epoch = [1]
    start_iter = [0]

    ################
    #load checkpoint
    if args.checkpoint_path is not None:
        ch_fpath = args.checkpoint_path
        load_checkpoint(local_rank, model, optimizer, start_epoch, start_iter,
                        config, args.amp, ch_fpath, world_size)

    # NOTE: the training loop that defines `epoch` is elided in this excerpt
    iteration = epoch * args.rank * args.batch_size
    if local_rank == 0:
        if (epoch % args.epochs_per_checkpoint == 0):
            ch_path = os.path.join(args.output,
                                   "WaveGrad_ch_{:d}.pt".format(epoch))
            save_checkpoint(local_rank, model, optimizer, epoch, iteration,
                            config, args.amp, ch_path)

            ch_path = os.path.join(args.output,
                                   "WaveGrad_model_ch_{:d}.pt".format(epoch))
            save_checkpoint_modelonly(local_rank, model, epoch, iteration,
                                      config, ch_path)
Example #13
def create_optimizer_v2(
        model: nn.Module,
        optimizer_name: str = 'sgd',
        learning_rate: Optional[float] = None,
        weight_decay: float = 0.,
        momentum: float = 0.9,
        filter_bias_and_bn: bool = True,
        **kwargs):
    """ Create an optimizer.

    TODO currently the model is passed in and all parameters are selected for optimization.
    For more general use an interface that allows selection of parameters to optimize and lr groups, one of:
      * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion
      * expose the parameters interface and leave it up to caller

    Args:
        model (nn.Module): model containing parameters to optimize
        optimizer_name: name of optimizer to create
        learning_rate: initial learning rate
        weight_decay: weight decay to apply in optimizer
        momentum:  momentum for momentum based optimizers (others may use betas via kwargs)
        filter_bias_and_bn:  filter out bias, bn and other 1d params from weight decay
        **kwargs: extra optimizer specific kwargs to pass through

    Returns:
        Optimizer
    """
    opt_lower = optimizer_name.lower()
    if weight_decay and filter_bias_and_bn:
        skip = {}
        if hasattr(model, 'no_weight_decay'):
            skip = model.no_weight_decay()
        parameters = add_weight_decay(model, weight_decay, skip)
        weight_decay = 0.
    else:
        parameters = model.parameters()
    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(lr=learning_rate, weight_decay=weight_decay, **kwargs)
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args) 
    elif opt_lower == 'adabelief':
        optimizer = AdaBelief(parameters, rectify=False, print_change_log=False, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamp':        
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        if not learning_rate:
            opt_args['lr'] = None
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters, **opt_args)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        raise ValueError(f"Invalid optimizer: {opt_lower}")

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
Example #14
def prepare_model_and_optimizer(args, device):

    # Prepare model
    config = BertConfig.from_json_file(args.config_file)

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)
    model = BertForPreTraining(config)

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if args.resume_step == -1 and not args.init_checkpoint:
            model_names = [
                f for f in os.listdir(args.output_dir) if f.endswith(".pt")
            ]
            args.resume_step = max([
                int(x.split('.pt')[0].split('_')[1].strip())
                for x in model_names
            ])

        global_step = args.resume_step if not args.init_checkpoint else 0

        if not args.init_checkpoint:
            checkpoint = torch.load(os.path.join(
                args.output_dir, "ckpt_{}.pt".format(global_step)),
                                    map_location="cpu")
        else:
            checkpoint = torch.load(args.init_checkpoint, map_location="cpu")

        model.load_state_dict(checkpoint['model'], strict=False)
        if args.phase2:
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = FusedLAMB(optimizer_grouped_parameters, lr=args.learning_rate)
    lr_scheduler = PolyWarmUpScheduler(optimizer,
                                       warmup=args.warmup_proportion,
                                       total_steps=args.max_steps)
    if args.fp16:

        if args.loss_scale == 0:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O2",
                                              loss_scale="dynamic")
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O2",
                                              loss_scale=args.loss_scale)
        amp._amp_state.loss_scalers[0]._loss_scale = 2**20

    if args.resume_from_checkpoint:
        if args.phase2 or args.init_checkpoint:
            keys = list(checkpoint['optimizer']['state'].keys())
            # Override hyperparameters from the previous checkpoint
            for key in keys:
                checkpoint['optimizer']['state'][key]['step'] = global_step
            for group in checkpoint['optimizer']['param_groups']:
                group['step'] = global_step
                group['t_total'] = args.max_steps
                group['warmup'] = args.warmup_proportion
                group['lr'] = args.learning_rate
        optimizer.load_state_dict(checkpoint['optimizer'])  # , strict=False)

        # Restore AMP master parameters
        if args.fp16:
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(checkpoint['optimizer'])
            for param, saved_param in zip(amp.master_params(optimizer),
                                          checkpoint['master params']):
                param.data.copy_(saved_param.data)

    if args.local_rank != -1:
        if not args.allreduce_post_accumulation:
            model = DDP(
                model,
                message_size=250000000,
                gradient_predivide_factor=torch.distributed.get_world_size())
        else:
            flat_dist_call([param.data for param in model.parameters()],
                           torch.distributed.broadcast, (0, ))
    elif args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    return model, optimizer, lr_scheduler, checkpoint, global_step
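The grouping above excludes biases and LayerNorm parameters from weight decay before handing the groups to FusedLAMB. A standalone sketch of that idiom, usable with any nn.Module (the 0.01 decay value mirrors the BERT example and is only illustrative):

import torch.nn as nn

def group_no_decay_parameters(model: nn.Module, weight_decay: float = 0.01):
    # Parameters whose names match these substrings are exempt from weight decay.
    no_decay = ('bias', 'gamma', 'beta', 'LayerNorm')
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    return [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]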
def create_optimizer_param(args, parameters):
    opt_lower = args.opt.lower()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(lr=args.lr, weight_decay=args.weight_decay)
    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
        opt_args['eps'] = args.opt_eps
    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
        opt_args['betas'] = args.opt_betas
    if hasattr(args, 'opt_args') and args.opt_args is not None:
        opt_args.update(args.opt_args)

    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters,
                              momentum=args.momentum,
                              nesterov=True,
                              **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters,
                              momentum=args.momentum,
                              nesterov=False,
                              **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters,
                         momentum=args.momentum,
                         nesterov=True,
                         **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        if not args.lr:
            opt_args['lr'] = None
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters,
                                  alpha=0.9,
                                  momentum=args.momentum,
                                  **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters,
                              alpha=0.9,
                              momentum=args.momentum,
                              **opt_args)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters, **opt_args)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters,
                             momentum=args.momentum,
                             nesterov=True,
                             **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters,
                             momentum=args.momentum,
                             nesterov=False,
                             **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        assert False, "Invalid optimizer"
        raise ValueError

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
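A hedged usage sketch for `create_optimizer_param`; the namespace below is hypothetical and only carries the attributes the function reads:

from types import SimpleNamespace

import torch.nn as nn

_model = nn.Linear(8, 2)
_args = SimpleNamespace(opt='adamw', lr=1e-3, weight_decay=0.01, momentum=0.9,
                        opt_eps=1e-8, opt_betas=None, opt_args=None)
_optimizer = create_optimizer_param(_args, _model.parameters())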
Example #16
0
def train(model: nn.Module, loss_fn: _Loss, train_dataloader: DataLoader,
          val_dataloader: DataLoader, callbacks: List[BaseCallback],
          logger: Logger, args):
    device = torch.cuda.current_device()
    model.to(device=device)
    local_rank = get_local_rank()
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    if dist.is_initialized():
        model = DistributedDataParallel(model,
                                        device_ids=[local_rank],
                                        output_device=local_rank)
        model._set_static_graph()

    model.train()
    grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(),
                              lr=args.learning_rate,
                              betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(),
                              lr=args.learning_rate,
                              betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    epoch_start = load_state(model, optimizer, args.load_ckpt_path,
                             callbacks) if args.load_ckpt_path else 0

    for callback in callbacks:
        callback.on_fit_start(optimizer, args)

    for epoch_idx in range(epoch_start, args.epochs):
        if isinstance(train_dataloader.sampler, DistributedSampler):
            train_dataloader.sampler.set_epoch(epoch_idx)

        loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx,
                           grad_scaler, optimizer, local_rank, callbacks, args)
        if dist.is_initialized():
            loss = torch.tensor(loss, dtype=torch.float, device=device)
            torch.distributed.all_reduce(loss)
            loss = (loss / world_size).item()

        logging.info(f'Train loss: {loss}')
        logger.log_metrics({'train loss': loss}, epoch_idx)

        for callback in callbacks:
            callback.on_epoch_end()

        if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \
                and (epoch_idx + 1) % args.ckpt_interval == 0:
            save_state(model, optimizer, epoch_idx, args.save_ckpt_path,
                       callbacks)

        if not args.benchmark and ((args.eval_interval > 0 and
                                    (epoch_idx + 1) % args.eval_interval == 0)
                                   or epoch_idx + 1 == args.epochs):
            evaluate(model, val_dataloader, callbacks, args)
            model.train()

            for callback in callbacks:
                callback.on_validation_end(epoch_idx)

    if args.save_ckpt_path is not None and not args.benchmark:
        save_state(model, optimizer, args.epochs, args.save_ckpt_path,
                   callbacks)

    for callback in callbacks:
        callback.on_fit_end()
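The loop above averages the epoch loss across ranks with an all-reduce followed by division by the world size. A minimal sketch of that reduction, assuming the default process group is already initialized:

import torch
import torch.distributed as dist

def average_across_ranks(value: float, device: torch.device) -> float:
    # Sum the scalar over all ranks, then divide by the number of ranks.
    t = torch.tensor(value, dtype=torch.float, device=device)
    dist.all_reduce(t)
    return (t / dist.get_world_size()).item()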
Example #17
0
def main():
    args = parse_args()

    assert (torch.cuda.is_available())
    assert args.prediction_frequency % args.log_frequency == 0

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    # set up distributed training
    multi_gpu = int(os.environ.get('WORLD_SIZE', 1)) > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')
        world_size = dist.get_world_size()
        print_once(f'Distributed training with {world_size} GPUs\n')
    else:
        world_size = 1

    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    random.seed(args.seed + args.local_rank)

    init_log(args)

    cfg = config.load(args.model_config)
    config.apply_config_overrides(cfg, args)

    symbols = helpers.add_ctc_blank(cfg['labels'])

    assert args.grad_accumulation >= 1
    batch_size = args.gpu_batch_size

    print_once('Setting up datasets...')
    train_dataset_kw, train_features_kw = config.input(cfg, 'train')
    val_dataset_kw, val_features_kw = config.input(cfg, 'val')

    use_dali = args.dali_device in ('cpu', 'gpu')
    if use_dali:
        assert train_dataset_kw['ignore_offline_speed_perturbation'], \
            "DALI doesn't support offline speed perturbation"

        # pad_to_max_duration is not supported by DALI - use simple padders instead
        if train_features_kw['pad_to_max_duration']:
            train_feat_proc = BaseFeatures(
                pad_align=train_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=train_features_kw['max_duration'],
                sample_rate=train_features_kw['sample_rate'],
                window_size=train_features_kw['window_size'],
                window_stride=train_features_kw['window_stride'])
            train_features_kw['pad_to_max_duration'] = False
        else:
            train_feat_proc = None

        if val_features_kw['pad_to_max_duration']:
            val_feat_proc = BaseFeatures(
                pad_align=val_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=val_features_kw['max_duration'],
                sample_rate=val_features_kw['sample_rate'],
                window_size=val_features_kw['window_size'],
                window_stride=val_features_kw['window_stride'])
            val_features_kw['pad_to_max_duration'] = False
        else:
            val_feat_proc = None

        train_loader = DaliDataLoader(
            gpu_id=args.local_rank,
            dataset_path=args.dataset_dir,
            config_data=train_dataset_kw,
            config_features=train_features_kw,
            json_names=args.train_manifests,
            batch_size=batch_size,
            grad_accumulation_steps=args.grad_accumulation,
            pipeline_type="train",
            device_type=args.dali_device,
            symbols=symbols)

        val_loader = DaliDataLoader(gpu_id=args.local_rank,
                                    dataset_path=args.dataset_dir,
                                    config_data=val_dataset_kw,
                                    config_features=val_features_kw,
                                    json_names=args.val_manifests,
                                    batch_size=batch_size,
                                    pipeline_type="val",
                                    device_type=args.dali_device,
                                    symbols=symbols)
    else:
        train_dataset_kw, train_features_kw = config.input(cfg, 'train')
        train_dataset = AudioDataset(args.dataset_dir, args.train_manifests,
                                     symbols, **train_dataset_kw)
        train_loader = get_data_loader(train_dataset,
                                       batch_size,
                                       multi_gpu=multi_gpu,
                                       shuffle=True,
                                       num_workers=4)
        train_feat_proc = FilterbankFeatures(**train_features_kw)

        val_dataset_kw, val_features_kw = config.input(cfg, 'val')
        val_dataset = AudioDataset(args.dataset_dir, args.val_manifests,
                                   symbols, **val_dataset_kw)
        val_loader = get_data_loader(val_dataset,
                                     batch_size,
                                     multi_gpu=multi_gpu,
                                     shuffle=False,
                                     num_workers=4,
                                     drop_last=False)
        val_feat_proc = FilterbankFeatures(**val_features_kw)

        dur = train_dataset.duration / 3600
        dur_f = train_dataset.duration_filtered / 3600
        nsampl = len(train_dataset)
        print_once(f'Training samples: {nsampl} ({dur:.1f}h, '
                   f'filtered {dur_f:.1f}h)')

    if train_feat_proc is not None:
        train_feat_proc.cuda()
    if val_feat_proc is not None:
        val_feat_proc.cuda()

    steps_per_epoch = len(train_loader) // args.grad_accumulation

    # set up the model
    model = QuartzNet(encoder_kw=config.encoder(cfg),
                      decoder_kw=config.decoder(cfg, n_classes=len(symbols)))
    model.cuda()
    ctc_loss = CTCLossNM(n_classes=len(symbols))
    greedy_decoder = GreedyCTCDecoder()

    print_once(f'Model size: {num_weights(model) / 10**6:.1f}M params\n')

    # optimization
    kw = {'lr': args.lr, 'weight_decay': args.weight_decay}
    if args.optimizer == "novograd":
        optimizer = Novograd(model.parameters(), **kw)
    elif args.optimizer == "adamw":
        optimizer = AdamW(model.parameters(), **kw)
    elif args.optimizer == 'lamb98':
        optimizer = FusedLAMB(model.parameters(),
                              betas=(0.9, 0.98),
                              eps=1e-9,
                              **kw)
    elif args.optimizer == 'fused_novograd':
        optimizer = FusedNovoGrad(model.parameters(),
                                  betas=(0.95, 0),
                                  bias_correction=False,
                                  reg_inside_moment=True,
                                  grad_averaging=False,
                                  **kw)
    else:
        raise ValueError(f'Invalid optimizer "{args.optimizer}"')

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    adjust_lr = lambda step, epoch, optimizer: lr_policy(
        step,
        epoch,
        args.lr,
        optimizer,
        steps_per_epoch=steps_per_epoch,
        warmup_epochs=args.warmup_epochs,
        hold_epochs=args.hold_epochs,
        num_epochs=args.epochs,
        policy=args.lr_policy,
        min_lr=args.min_lr,
        exp_gamma=args.lr_exp_gamma)

    if args.ema > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if multi_gpu:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)
    if args.pyprof:
        pyprof.init(enable_function_stack=True)

    # load checkpoint
    meta = {'best_wer': 10**6, 'start_epoch': 0}
    checkpointer = Checkpointer(args.output_dir, 'QuartzNet',
                                args.keep_milestones)
    if args.resume:
        args.ckpt = checkpointer.last_checkpoint() or args.ckpt

    if args.ckpt is not None:
        checkpointer.load(args.ckpt, model, ema_model, optimizer, scaler, meta)

    start_epoch = meta['start_epoch']
    best_wer = meta['best_wer']
    epoch = 1
    step = start_epoch * steps_per_epoch + 1

    if args.pyprof:
        torch.autograd.profiler.emit_nvtx().__enter__()
        profiler.start()

    # training loop
    model.train()
    if args.ema > 0.0:
        mt_ema_params = init_multi_tensor_ema(model, ema_model)
    # ema_model_weight_list, model_weight_list, overflow_buf_for_ema = ema_

    # pre-allocate
    if args.pre_allocate_range is not None:
        n_feats = train_features_kw['n_filt']
        pad_align = train_features_kw['pad_align']
        a, b = args.pre_allocate_range
        for n_frames in range(a, b + pad_align, pad_align):
            print_once(
                f'Pre-allocation ({batch_size}x{n_feats}x{n_frames})...')

            feat = torch.randn(batch_size, n_feats, n_frames, device='cuda')
            feat_lens = torch.ones(batch_size, device='cuda').fill_(n_frames)
            txt = torch.randint(high=len(symbols) - 1,
                                size=(batch_size, 100),
                                device='cuda')
            txt_lens = torch.ones(batch_size, device='cuda').fill_(100)
            with torch.cuda.amp.autocast(enabled=args.amp):
                log_probs, enc_lens = model(feat, feat_lens)
                del feat
                loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
            loss.backward()
            model.zero_grad()
    torch.cuda.empty_cache()

    bmark_stats = BenchmarkStats()

    for epoch in range(start_epoch + 1, args.epochs + 1):
        if multi_gpu and not use_dali:
            train_loader.sampler.set_epoch(epoch)

        epoch_utts = 0
        epoch_loss = 0
        accumulated_batches = 0
        epoch_start_time = time.time()
        epoch_eval_time = 0

        for batch in train_loader:

            if accumulated_batches == 0:
                step_loss = 0
                step_utts = 0
                step_start_time = time.time()

            if use_dali:
                # with DALI, the data is already on GPU
                feat, feat_lens, txt, txt_lens = batch
                if train_feat_proc is not None:
                    feat, feat_lens = train_feat_proc(feat, feat_lens)
            else:
                batch = [t.cuda(non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feat, feat_lens = train_feat_proc(audio, audio_lens)

            # Use no_sync() to skip gradient all-reduce on intermediate accumulation steps
            if (multi_gpu
                    and accumulated_batches + 1 < args.grad_accumulation):
                ctx = model.no_sync()
            else:
                ctx = empty_context()

            with ctx:
                with torch.cuda.amp.autocast(enabled=args.amp):
                    log_probs, enc_lens = model(feat, feat_lens)

                    loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
                    loss /= args.grad_accumulation

                if multi_gpu:
                    reduced_loss = reduce_tensor(loss.data, world_size)
                else:
                    reduced_loss = loss

                if torch.isnan(reduced_loss).any():
                    print_once('WARNING: loss is NaN; skipping update')
                    continue
                else:
                    step_loss += reduced_loss.item()
                    step_utts += batch[0].size(0) * world_size
                    epoch_utts += batch[0].size(0) * world_size
                    accumulated_batches += 1

                    scaler.scale(loss).backward()

            if accumulated_batches % args.grad_accumulation == 0:
                epoch_loss += step_loss
                scaler.step(optimizer)
                scaler.update()

                adjust_lr(step, epoch, optimizer)
                optimizer.zero_grad()

                if args.ema > 0.0:
                    apply_multi_tensor_ema(args.ema, *mt_ema_params)

                if step % args.log_frequency == 0:
                    preds = greedy_decoder(log_probs)
                    wer, pred_utt, ref = greedy_wer(preds, txt, txt_lens,
                                                    symbols)

                    if step % args.prediction_frequency == 0:
                        print_once(f'  Decoded:   {pred_utt[:90]}')
                        print_once(f'  Reference: {ref[:90]}')

                    step_time = time.time() - step_start_time
                    log(
                        (epoch, step % steps_per_epoch
                         or steps_per_epoch, steps_per_epoch), step, 'train', {
                             'loss': step_loss,
                             'wer': 100.0 * wer,
                             'throughput': step_utts / step_time,
                             'took': step_time,
                             'lrate': optimizer.param_groups[0]['lr']
                         })

                step_start_time = time.time()

                if step % args.eval_frequency == 0:
                    tik = time.time()
                    wer = evaluate(epoch, step, val_loader, val_feat_proc,
                                   symbols, model, ema_model, ctc_loss,
                                   greedy_decoder, args.amp, use_dali)

                    if wer < best_wer and epoch >= args.save_best_from:
                        checkpointer.save(model,
                                          ema_model,
                                          optimizer,
                                          scaler,
                                          epoch,
                                          step,
                                          best_wer,
                                          is_best=True)
                        best_wer = wer
                    epoch_eval_time += time.time() - tik

                step += 1
                accumulated_batches = 0
                # end of step

            # The DALI iterator needs to be exhausted;
            # if not using DALI, simulate drop_last=True with grad accumulation
            if not use_dali and step > steps_per_epoch * epoch:
                break

        epoch_time = time.time() - epoch_start_time
        epoch_loss /= steps_per_epoch
        log(
            (epoch, ), None, 'train_avg', {
                'throughput': epoch_utts / epoch_time,
                'took': epoch_time,
                'loss': epoch_loss
            })
        bmark_stats.update(epoch_utts, epoch_time, epoch_loss)

        if epoch % args.save_frequency == 0 or epoch in args.keep_milestones:
            checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                              best_wer)

        if 0 < args.epochs_this_job <= epoch - start_epoch:
            print_once(f'Finished after {args.epochs_this_job} epochs.')
            break
        # end of epoch

    if args.pyprof:
        profiler.stop()
        torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)

    log((), None, 'train_avg', bmark_stats.get(args.benchmark_epochs_num))

    if epoch == args.epochs:
        evaluate(epoch, step, val_loader, val_feat_proc, symbols, model,
                 ema_model, ctc_loss, greedy_decoder, args.amp, use_dali)

        checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                          best_wer)
    flush_log()
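The inner loop above combines AMP, gradient accumulation, and a GradScaler: the loss is divided by `args.grad_accumulation`, gradients accumulate over several micro-batches, and `scaler.step` runs once per effective batch. A standalone sketch of that pattern with an illustrative model and synthetic data (CUDA assumed, as in the example):

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(32, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler(enabled=True)
grad_accumulation = 4

for i in range(8):  # 8 micro-batches -> 2 optimizer steps
    x = torch.randn(16, 32, device='cuda')
    y = torch.randint(0, 10, (16,), device='cuda')
    with torch.cuda.amp.autocast(enabled=True):
        loss = F.cross_entropy(model(x), y) / grad_accumulation
    scaler.scale(loss).backward()
    if (i + 1) % grad_accumulation == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()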
Example #18
0
        loaded = load_checkpoint('model_weights_best_epoch*pt')
        if not loaded:
            # if best doesn't exist, take the latest
            loaded = load_checkpoint('model_weights_epoch*pt')

    model = Bruno(config)
    if config.from_snapshot is not None:
        state_dicts = torch.load(config.from_snapshot)
        model.load_state_dict(state_dicts['model'])

        logger.info(f'Model ckpt {config.from_snapshot} loaded.')

    model = model.to(device)

    # opt = torch.optim.Adam(model.parameters(), lr=config.lr)
    opt = FusedLAMB(model.parameters(), lr=config.lr)

    if config.from_snapshot is not None:
        state_dicts = torch.load(config.from_snapshot)
        opt.load_state_dict(state_dicts['opt'])

    if config.lr_policy == 'exp' or config.lr_policy is None:
        lr = torch.optim.lr_scheduler.ExponentialLR(opt, config.lr_decay)
    elif config.lr_policy == 'cyclic':
        lr = torch.optim.lr_scheduler.CyclicLR(
            opt,
            0,
            config.lr,
            step_size_up=steps_per_epoch * 2,
            scale_fn=partial(scale_fn, decay=config.lr_decay),
            cycle_momentum=False)
Example #19
0
def create_optimizer(args, model, filter_bias_and_bn=True):
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if weight_decay and filter_bias_and_bn:
        skip = {}
        if hasattr(model, 'no_weight_decay'):
            skip = model.no_weight_decay
        parameters = add_weight_decay(model, weight_decay, skip)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]

    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    if hasattr(args,
               'opt_eps') and args.opt_eps is not None and opt_lower not in [
                   'sgd', 'momentum', 'fusedmomentum', 'fusedsgd'
               ]:
        opt_args['eps'] = args.opt_eps
    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
        opt_args['betas'] = args.opt_betas

    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        optimizer = optim.SGD(parameters,
                              momentum=args.momentum,
                              nesterov=True,
                              **opt_args)
    elif opt_lower == 'momentum':
        optimizer = optim.SGD(parameters,
                              momentum=args.momentum,
                              nesterov=False,
                              **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'fusedsgd':
        optimizer = FusedSGD(parameters,
                             momentum=args.momentum,
                             nesterov=True,
                             **opt_args)
    elif opt_lower == 'fusedmomentum':
        optimizer = FusedSGD(parameters,
                             momentum=args.momentum,
                             nesterov=False,
                             **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        assert False, "Invalid optimizer"
        raise ValueError

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
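The `'fused' in opt_lower` check above depends on a module-level `has_apex` flag. A minimal sketch of how that flag is typically set at import time (the flag name mirrors the code above; the exact import list is an assumption):

try:
    from apex.optimizers import FusedAdam, FusedLAMB, FusedNovoGrad, FusedSGD
    has_apex = True
except ImportError:
    has_apex = False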
def main():
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    if 'LOCAL_RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
    else:
        local_rank = args.rank
        world_size = args.world_size
    distributed_run = world_size > 1

    torch.manual_seed(args.seed + local_rank)
    np.random.seed(args.seed + local_rank)

    if local_rank == 0:
        if not os.path.exists(args.output):
            os.makedirs(args.output)

        log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
        log_fpath = unique_dllogger_fpath(log_fpath)
        init_dllogger(log_fpath)
    else:
        init_dllogger(dummy=True)

    for k, v in vars(args).items():
        DLLogger.log("PARAMETER", {k: v})

    parser = models.parse_model_args('FastPitch', parser)
    args, unk_args = parser.parse_known_args()
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, world_size, local_rank)

    device = torch.device('cuda' if args.cuda else 'cpu')
    model_config = models.get_model_config('FastPitch', args)
    model = models.get_model('FastPitch', model_config, device)

    # Store pitch mean/std as params to translate from Hz during inference
    fpath = common.utils.stats_filename(args.dataset_path, args.training_files,
                                        'pitch_char')
    with open(args.pitch_mean_std_file, 'r') as f:
        stats = json.load(f)
    model.pitch_mean[0] = stats['mean']
    model.pitch_std[0] = stats['std']

    kw = dict(lr=args.learning_rate,
              betas=(0.9, 0.98),
              eps=1e-9,
              weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    else:
        raise ValueError(f'Invalid optimizer "{args.optimizer}"')

    if args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    if args.ema_decay > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if distributed_run:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank,
                                        find_unused_parameters=True)

    start_epoch = [1]
    start_iter = [0]

    assert args.checkpoint_path is None or args.resume is False, (
        "Specify a single checkpoint source")
    if args.checkpoint_path is not None:
        ch_fpath = args.checkpoint_path
    elif args.resume:
        ch_fpath = last_checkpoint(args.output)
    else:
        ch_fpath = None

    if ch_fpath is not None:
        load_checkpoint(local_rank, model, ema_model, optimizer, start_epoch,
                        start_iter, model_config, args.amp, ch_fpath,
                        world_size)

    start_epoch = start_epoch[0]
    total_iter = start_iter[0]

    criterion = loss_functions.get_loss_function(
        'FastPitch',
        dur_predictor_loss_scale=args.dur_predictor_loss_scale,
        pitch_predictor_loss_scale=args.pitch_predictor_loss_scale)

    collate_fn = data_functions.get_collate_function('FastPitch')
    trainset = data_functions.get_data_loader('FastPitch', args.dataset_path,
                                              args.training_files, args)
    valset = data_functions.get_data_loader('FastPitch', args.dataset_path,
                                            args.validation_files, args)
    if distributed_run:
        train_sampler, shuffle = DistributedSampler(trainset), False
    else:
        train_sampler, shuffle = None, True

    train_loader = DataLoader(trainset,
                              num_workers=16,
                              shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              pin_memory=False,
                              drop_last=True,
                              collate_fn=collate_fn)

    batch_to_gpu = data_functions.get_batch_to_gpu('FastPitch')

    model.train()

    train_tblogger = TBLogger(local_rank, args.output, 'train')
    val_tblogger = TBLogger(local_rank, args.output, 'val', dummies=True)
    if args.ema_decay > 0:
        val_ema_tblogger = TBLogger(local_rank, args.output, 'val_ema')

    val_loss = 0.0
    torch.cuda.synchronize()
    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start_time = time.time()

        epoch_loss = 0.0
        epoch_mel_loss = 0.0
        epoch_num_frames = 0
        epoch_frames_per_sec = 0.0

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        accumulated_steps = 0
        iter_loss = 0
        iter_num_frames = 0
        iter_meta = {}

        epoch_iter = 0
        num_iters = len(train_loader) // args.gradient_accumulation_steps
        for batch in train_loader:
            if accumulated_steps == 0:
                if epoch_iter == num_iters:
                    break
                total_iter += 1
                epoch_iter += 1
                iter_start_time = time.time()
                start = time.perf_counter()

                old_lr = optimizer.param_groups[0]['lr']
                adjust_learning_rate(total_iter, optimizer, args.learning_rate,
                                     args.warmup_steps)
                new_lr = optimizer.param_groups[0]['lr']

                if new_lr != old_lr:
                    dllog_lrate_change = f'{old_lr:.2E} -> {new_lr:.2E}'
                    train_tblogger.log_value(total_iter, 'lrate', new_lr)
                else:
                    dllog_lrate_change = None

                model.zero_grad()

            x, y, num_frames = batch_to_gpu(batch)
            y_pred = model(x, use_gt_durations=True)
            loss, meta = criterion(y_pred, y)

            loss /= args.gradient_accumulation_steps
            meta = {
                k: v / args.gradient_accumulation_steps
                for k, v in meta.items()
            }

            if args.amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, world_size).item()
                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
                meta = {
                    k: reduce_tensor(v, world_size)
                    for k, v in meta.items()
                }
            else:
                reduced_loss = loss.item()
                reduced_num_frames = num_frames.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            accumulated_steps += 1
            iter_loss += reduced_loss
            iter_num_frames += reduced_num_frames
            iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}

            if accumulated_steps % args.gradient_accumulation_steps == 0:

                train_tblogger.log_grads(total_iter, model)
                if args.amp:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.grad_clip_thresh)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.grad_clip_thresh)

                optimizer.step()
                apply_ema_decay(model, ema_model, args.ema_decay)

                iter_stop_time = time.time()
                iter_time = iter_stop_time - iter_start_time
                frames_per_sec = iter_num_frames / iter_time
                epoch_frames_per_sec += frames_per_sec
                epoch_loss += iter_loss
                epoch_num_frames += iter_num_frames
                iter_mel_loss = iter_meta['mel_loss'].item()
                epoch_mel_loss += iter_mel_loss

                DLLogger.log(
                    (epoch, epoch_iter, num_iters),
                    OrderedDict([('train_loss', iter_loss),
                                 ('train_mel_loss', iter_mel_loss),
                                 ('train_frames/s', frames_per_sec),
                                 ('took', iter_time),
                                 ('lrate_change', dllog_lrate_change)]))
                train_tblogger.log_meta(total_iter, iter_meta)

                accumulated_steps = 0
                iter_loss = 0
                iter_num_frames = 0
                iter_meta = {}

        # Finished epoch
        epoch_stop_time = time.time()
        epoch_time = epoch_stop_time - epoch_start_time

        DLLogger.log((epoch, ),
                     data=OrderedDict([
                         ('avg_train_loss', epoch_loss / epoch_iter),
                         ('avg_train_mel_loss', epoch_mel_loss / epoch_iter),
                         ('avg_train_frames/s', epoch_num_frames / epoch_time),
                         ('took', epoch_time)
                     ]))

        tik = time.time()
        val_loss, meta, num_frames = validate(model,
                                              criterion,
                                              valset,
                                              args.batch_size,
                                              world_size,
                                              collate_fn,
                                              distributed_run,
                                              local_rank,
                                              batch_to_gpu,
                                              use_gt_durations=True)
        tok = time.time()

        DLLogger.log((epoch, ),
                     data=OrderedDict([
                         ('val_loss', val_loss),
                         ('val_mel_loss', meta['mel_loss'].item()),
                         ('val_frames/s', num_frames / (tok - tik)),
                         ('took', tok - tik),
                     ]))
        val_tblogger.log_meta(total_iter, meta)

        if args.ema_decay > 0:
            tik_e = time.time()
            val_loss_e, meta_e, num_frames_e = validate(ema_model,
                                                        criterion,
                                                        valset,
                                                        args.batch_size,
                                                        world_size,
                                                        collate_fn,
                                                        distributed_run,
                                                        local_rank,
                                                        batch_to_gpu,
                                                        use_gt_durations=True)
            tok_e = time.time()

            DLLogger.log(
                (epoch, ),
                data=OrderedDict([
                    ('val_ema_loss', val_loss_e),
                    ('val_ema_mel_loss', meta_e['mel_loss'].item()),
                    ('val_ema_frames/s', num_frames_e / (tok_e - tik_e)),
                    ('took', tok_e - tik_e),
                ]))
            val_ema_tblogger.log_meta(total_iter, meta)

        if (epoch > 0 and args.epochs_per_checkpoint > 0
                and (epoch % args.epochs_per_checkpoint == 0)
                and local_rank == 0):

            checkpoint_path = os.path.join(args.output,
                                           f"FastPitch_checkpoint_{epoch}.pt")
            save_checkpoint(local_rank, model, ema_model, optimizer, epoch,
                            total_iter, model_config, args.amp,
                            checkpoint_path)
        if local_rank == 0:
            DLLogger.flush()

    # Finished training
    DLLogger.log((),
                 data=OrderedDict([
                     ('avg_train_loss', epoch_loss / epoch_iter),
                     ('avg_train_mel_loss', epoch_mel_loss / epoch_iter),
                     ('avg_train_frames/s', epoch_num_frames / epoch_time),
                 ]))
    DLLogger.log((),
                 data=OrderedDict([
                     ('val_loss', val_loss),
                     ('val_mel_loss', meta['mel_loss'].item()),
                     ('val_frames/s', num_frames / (tok - tik)),
                 ]))
    if local_rank == 0:
        DLLogger.flush()
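The FastPitch loop above maintains an EMA copy of the model via `apply_ema_decay`. A hedged sketch of a standard parameter EMA update of that shape (an assumption, not necessarily the exact helper used here):

import torch
import torch.nn as nn

@torch.no_grad()
def ema_update(model: nn.Module, ema_model: nn.Module, decay: float) -> None:
    # ema_p <- decay * ema_p + (1 - decay) * p
    for p, ema_p in zip(model.parameters(), ema_model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)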
Example #21
0
def prepare_model_and_optimizer(args, device):

    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    modeling.ACT2FN["bias_gelu"] = modeling.bias_gelu_training
    model = modeling.BertForPreTraining(config)

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if args.resume_step == -1 and not args.init_checkpoint:
            model_names = [
                f for f in os.listdir(args.output_dir) if f.endswith(".pt")
            ]
            args.resume_step = max([
                int(x.split('.pt')[0].split('_')[1].strip())
                for x in model_names
            ])

        global_step = args.resume_step if not args.init_checkpoint else 0

        if not args.init_checkpoint:
            checkpoint = torch.load(os.path.join(
                args.output_dir, "ckpt_{}.pt".format(global_step)),
                                    map_location="cpu")
        else:
            checkpoint = torch.load(args.init_checkpoint, map_location="cpu")

        model.load_state_dict(checkpoint['model'], strict=False)

        if args.phase2 and not args.init_checkpoint:
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)
    # BERT modeling  uses weight sharing between word embedding and prediction decoder.
    # So make sure the storage is pointing properly even after model is moved to device.
    if args.use_habana:
        model.cls.predictions.decoder.weight = model.bert.embeddings.word_embeddings.weight

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    if args.use_habana:
        if args.use_fused_lamb:
            try:
                from hb_custom import FusedLamb
            except ImportError:
                raise ImportError("Please install hbopt.")
            optimizer = FusedLamb(optimizer_grouped_parameters,
                                  lr=args.learning_rate)
        else:
            optimizer = NVLAMB(optimizer_grouped_parameters,
                               lr=args.learning_rate)
    else:
        if torch.cuda.is_available():
            optimizer = FusedLAMB(optimizer_grouped_parameters,
                                  lr=args.learning_rate)
        else:
            optimizer = NVLAMB(optimizer_grouped_parameters,
                               lr=args.learning_rate)

    lr_scheduler = PolyWarmUpScheduler(optimizer,
                                       warmup=args.warmup_proportion,
                                       total_steps=args.max_steps)
    if args.fp16:

        if args.loss_scale == 0:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O2",
                                              loss_scale="dynamic",
                                              cast_model_outputs=torch.float16)
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O2",
                                              loss_scale=args.loss_scale,
                                              cast_model_outputs=torch.float16)
        amp._amp_state.loss_scalers[0]._loss_scale = args.init_loss_scale

    model.checkpoint_activations(args.checkpoint_activations)

    if args.resume_from_checkpoint:
        if args.phase2 or args.init_checkpoint:
            keys = list(checkpoint['optimizer']['state'].keys())
            # Override hyperparameters from the previous checkpoint
            for key in keys:
                checkpoint['optimizer']['state'][key]['step'] = global_step
            for group in checkpoint['optimizer']['param_groups']:
                group['step'] = global_step
                group['t_total'] = args.max_steps
                group['warmup'] = args.warmup_proportion
                group['lr'] = args.learning_rate
        optimizer.load_state_dict(checkpoint['optimizer'])  # , strict=False)

        # Restore AMP master parameters
        if args.fp16:
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(checkpoint['optimizer'])
            for param, saved_param in zip(amp.master_params(optimizer),
                                          checkpoint['master params']):
                param.data.copy_(saved_param.data)

    if args.local_rank != -1:
        if not args.allreduce_post_accumulation:
            if not args.use_jit_trace:
                if args.use_habana:
                    model = DDP(model)
                else:
                    model = DDP(model,
                                message_size=250000000,
                                gradient_predivide_factor=get_world_size())
        else:
            flat_dist_call([param.data for param in model.parameters()],
                           torch.distributed.broadcast, (0, ))
    elif args.n_pu > 1:
        model = torch.nn.DataParallel(model)

    criterion = BertPretrainingCriterion(config.vocab_size)

    return model, optimizer, lr_scheduler, checkpoint, global_step, criterion
def get_optimizer(optimizer_name: str,
                  parameters,
                  learning_rate: float,
                  weight_decay=1e-5,
                  eps=1e-5,
                  **kwargs) -> Optimizer:
    from torch.optim import SGD, Adam, RMSprop, AdamW
    from torch_optimizer import RAdam, Lamb, DiffGrad, NovoGrad, Ranger

    if optimizer_name.lower() == "sgd":
        return SGD(parameters,
                   learning_rate,
                   momentum=0.9,
                   nesterov=True,
                   weight_decay=weight_decay,
                   **kwargs)

    if optimizer_name.lower() == "adam":
        return Adam(parameters,
                    learning_rate,
                    weight_decay=weight_decay,
                    eps=eps,
                    **kwargs)  # As Jeremy suggests

    if optimizer_name.lower() == "rms":
        return RMSprop(parameters,
                       learning_rate,
                       weight_decay=weight_decay,
                       **kwargs)

    if optimizer_name.lower() == "adamw":
        return AdamW(parameters,
                     learning_rate,
                     weight_decay=weight_decay,
                     eps=eps,
                     **kwargs)

    if optimizer_name.lower() == "radam":
        return RAdam(parameters,
                     learning_rate,
                     weight_decay=weight_decay,
                     eps=eps,
                     **kwargs)  # As Jeremy suggests

    # Optimizers from torch-optimizer
    if optimizer_name.lower() == "ranger":
        return Ranger(parameters,
                      learning_rate,
                      eps=eps,
                      weight_decay=weight_decay,
                      **kwargs)

    if optimizer_name.lower() == "lamb":
        return Lamb(parameters,
                    learning_rate,
                    eps=eps,
                    weight_decay=weight_decay,
                    **kwargs)

    if optimizer_name.lower() == "diffgrad":
        return DiffGrad(parameters,
                        learning_rate,
                        eps=eps,
                        weight_decay=weight_decay,
                        **kwargs)

    if optimizer_name.lower() == "novograd":
        return NovoGrad(parameters,
                        learning_rate,
                        eps=eps,
                        weight_decay=weight_decay,
                        **kwargs)

    # Optimizers from Apex (Fused version is faster on GPU with tensor cores)
    if optimizer_name.lower() == "fused_lamb":
        from apex.optimizers import FusedLAMB

        return FusedLAMB(parameters,
                         learning_rate,
                         eps=eps,
                         weight_decay=weight_decay,
                         **kwargs)

    if optimizer_name.lower() == "fused_sgd":
        from apex.optimizers import FusedSGD

        return FusedSGD(parameters,
                        learning_rate,
                        momentum=0.9,
                        nesterov=True,
                        weight_decay=weight_decay,
                        **kwargs)

    if optimizer_name.lower() == "fused_adam":
        from apex.optimizers import FusedAdam

        return FusedAdam(parameters,
                         learning_rate,
                         eps=eps,
                         weight_decay=weight_decay,
                         adam_w_mode=True,
                         **kwargs)

    raise ValueError("Unsupported optimizer name " + optimizer_name)
Example #23
0
def create_optimizer(args, model, filter_bias_and_bn=True, freeze_stage=""):
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if 'adamw' in opt_lower or 'radam' in opt_lower:
        # Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay
        # I don't believe they follow the paper or original Torch7 impl which schedules weight
        # decay based on the ratio of current_lr/initial_lr
        weight_decay /= args.lr
    if weight_decay and filter_bias_and_bn:
        if freeze_stage == "stage1":
            stage1_train_attn(model, layer_names=['fc'])
            print('stage1, Freeze layer successfully')
        if freeze_stage == "stage2":
            stage1_train_attn(model,
                              layer_names=['layer3', 'layer4', 'se', 'fc'])
            stage2_train_layer4(model)
            print('stage2, Freeze layer successfully')
        # Apply weight decay only to the layers that remain unfrozen
        parameters = add_weight_decay(model, weight_decay)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    for name, param in model.named_parameters():
        print(name, param.requires_grad)

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        optimizer = optim.SGD(parameters,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=weight_decay,
                              nesterov=True)
    elif opt_lower == 'momentum':
        optimizer = optim.SGD(parameters,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=weight_decay,
                              nesterov=False)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters,
                               lr=args.lr,
                               weight_decay=weight_decay,
                               eps=args.opt_eps)
    elif opt_lower == 'adamw':
        optimizer = AdamW(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps,
                          delta=0.1,
                          wd_ratio=0.01,
                          nesterov=True)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters,
                         lr=args.lr,
                         momentum=args.momentum,
                         weight_decay=weight_decay,
                         eps=args.opt_eps,
                         nesterov=True)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters,
                                   lr=args.lr,
                                   weight_decay=weight_decay,
                                   eps=args.opt_eps)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters,
                                  lr=args.lr,
                                  alpha=0.9,
                                  eps=args.opt_eps,
                                  momentum=args.momentum,
                                  weight_decay=weight_decay)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters,
                              lr=args.lr,
                              alpha=0.9,
                              eps=args.opt_eps,
                              momentum=args.momentum,
                              weight_decay=weight_decay)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters,
                             lr=args.lr,
                             weight_decay=weight_decay,
                             eps=args.opt_eps)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters,
                               lr=args.lr,
                               weight_decay=weight_decay,
                               eps=args.opt_eps)
    elif opt_lower == 'fusedsgd':
        optimizer = FusedSGD(parameters,
                             lr=args.lr,
                             momentum=args.momentum,
                             weight_decay=weight_decay,
                             nesterov=True)
    elif opt_lower == 'fusedmomentum':
        optimizer = FusedSGD(parameters,
                             lr=args.lr,
                             momentum=args.momentum,
                             weight_decay=weight_decay,
                             nesterov=False)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters,
                              lr=args.lr,
                              adam_w_mode=False,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters,
                              lr=args.lr,
                              adam_w_mode=True,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters,
                              lr=args.lr,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusednovograd':
        optimizer = FusedNovoGrad(parameters,
                                  lr=args.lr,
                                  betas=(0.95, 0.98),
                                  weight_decay=weight_decay,
                                  eps=args.opt_eps)
    else:
        raise ValueError(f"Invalid optimizer: {opt_lower}")

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
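The opt string convention above (an optional 'lookahead' prefix joined by '_' to the base optimizer name) can be shown with a minimal, self-contained sketch; pick_optimizer below is a hypothetical helper written only for illustration, not code from this repository.

from torch import nn, optim

def pick_optimizer(opt: str, parameters, lr: float = 1e-3):
    # 'lookahead_adamw' -> ['lookahead', 'adamw']; the last token picks the base
    # optimizer, a leading 'lookahead' would wrap it afterwards.
    opt_split = opt.lower().split('_')
    base = opt_split[-1]
    if base in ('sgd', 'nesterov'):
        optimizer = optim.SGD(parameters, lr=lr, momentum=0.9, nesterov=True)
    elif base == 'adamw':
        optimizer = optim.AdamW(parameters, lr=lr)
    else:
        raise ValueError(f"Invalid optimizer: {opt}")
    # if len(opt_split) > 1 and opt_split[0] == 'lookahead':
    #     optimizer = Lookahead(optimizer)  # needs a Lookahead implementation
    return optimizer

model = nn.Linear(4, 2)
print(type(pick_optimizer('lookahead_adamw', model.parameters())).__name__)  # AdamW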
Example #24
0
def get_optimizer(
    model: nn.Module,
    optimizer_name: str,
    learning_rate: float,
    weight_decay: float = 1e-5,
    no_weight_decay_on_bias: bool = False,
    eps: float = 1e-5,
    **kwargs,
) -> Optimizer:
    """
    Construct an Optimizer for given model
    Args:
        model: Model to optimize. Only parameters that require_grad will be used
        optimizer_name: Name of the optimizer. Case-insensitive
        learning_rate: Target learning rate (regardless of the scheduler)
        weight_decay: Target weight decay
        no_weight_decay_on_bias: Whether to disable weight decay on bias parameters
        eps: Default epsilon for Adam-like optimizers.
        **kwargs: Additional parameters for optimizer

    Returns:
        Constructed Optimizer instance
    """
    from torch.optim import ASGD, SGD, Adam, RMSprop, AdamW
    from torch_optimizer import RAdam, Lamb, DiffGrad, NovoGrad, Ranger

    # Optimizer parameter groups
    default_pg, biases_pg = [], []

    for k, v in model.named_parameters():
        if v.requires_grad:
            if str.endswith(k, ".bias"):
                biases_pg.append(v)  # biases
            else:
                default_pg.append(v)  # all else

    if no_weight_decay_on_bias:
        parameters = default_pg
    else:
        parameters = default_pg + biases_pg

    optimizer: Optimizer = None

    if optimizer_name.lower() == "sgd":
        optimizer = SGD(
            parameters,
            lr=learning_rate,
            momentum=0.9,
            nesterov=True,
            weight_decay=weight_decay,
            **kwargs,
        )
    elif optimizer_name.lower() == "asgd":
        optimizer = ASGD(
            parameters,
            lr=learning_rate,
            weight_decay=weight_decay,
            **kwargs,
        )
    elif optimizer_name.lower() == "adam":
        optimizer = Adam(
            parameters,
            lr=learning_rate,
            weight_decay=weight_decay,
            eps=eps,
            **kwargs,
        )
    elif optimizer_name.lower() == "rms":
        optimizer = RMSprop(parameters,
                            learning_rate,
                            weight_decay=weight_decay,
                            **kwargs)
    elif optimizer_name.lower() == "adamw":
        optimizer = AdamW(
            parameters,
            lr=learning_rate,
            weight_decay=weight_decay,
            eps=eps,
            **kwargs,
        )
    elif optimizer_name.lower() == "radam":
        optimizer = RAdam(
            parameters,
            lr=learning_rate,
            weight_decay=weight_decay,
            eps=eps,
            **kwargs,
        )
    elif optimizer_name.lower() == "ranger":
        optimizer = Ranger(
            parameters,
            lr=learning_rate,
            eps=eps,
            weight_decay=weight_decay,
            **kwargs,
        )
    elif optimizer_name.lower() == "lamb":
        optimizer = Lamb(
            parameters,
            lr=learning_rate,
            eps=eps,
            weight_decay=weight_decay,
            **kwargs,
        )
    elif optimizer_name.lower() == "diffgrad":
        optimizer = DiffGrad(
            parameters,
            lr=learning_rate,
            eps=eps,
            weight_decay=weight_decay,
            **kwargs,
        )
    elif optimizer_name.lower() == "novograd":
        optimizer = NovoGrad(
            parameters,
            lr=learning_rate,
            eps=eps,
            weight_decay=weight_decay,
            **kwargs,
        )
    elif optimizer_name.lower() == "fused_lamb":
        from apex.optimizers import FusedLAMB

        optimizer = FusedLAMB(parameters,
                              learning_rate,
                              eps=eps,
                              weight_decay=weight_decay,
                              **kwargs)
    elif optimizer_name.lower() == "fused_sgd":
        from apex.optimizers import FusedSGD

        optimizer = FusedSGD(parameters,
                             learning_rate,
                             momentum=0.9,
                             nesterov=True,
                             weight_decay=weight_decay,
                             **kwargs)
    elif optimizer_name.lower() == "fused_adam":
        from apex.optimizers import FusedAdam

        optimizer = FusedAdam(parameters,
                              learning_rate,
                              eps=eps,
                              weight_decay=weight_decay,
                              adam_w_mode=True,
                              **kwargs)
    else:
        raise KeyError(f"Cannot get optimizer by name {optimizer_name}")

    # Currently either no_wd or per-group lr
    if no_weight_decay_on_bias:
        optimizer.add_param_group({"params": biases_pg, "weight_decay": 0})

    return optimizer
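As a minimal usage sketch of the bias handling above (assuming only PyTorch; the toy model and hyperparameters are purely illustrative), the biases can be attached as a second parameter group with weight decay disabled:

from torch import nn
from torch.optim import SGD

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

default_pg = [p for n, p in model.named_parameters() if not n.endswith('.bias')]
biases_pg = [p for n, p in model.named_parameters() if n.endswith('.bias')]

optimizer = SGD(default_pg, lr=0.1, momentum=0.9, weight_decay=1e-4)
optimizer.add_param_group({'params': biases_pg, 'weight_decay': 0.0})

for i, group in enumerate(optimizer.param_groups):
    print(i, len(group['params']), group['weight_decay'])  # weights decay, biases do not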
Example #25
0
def prepare_optimizers(args, model, checkpoint, global_steps):
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    if args.lr_decay == 'poly':
        Scheduler = PolyWarmUpScheduler
    elif args.lr_decay == 'linear':
        Scheduler = LinearWarmUpScheduler
    else:
        raise ValueError('Unknown lr decay "{}"'.format(args.lr_decay))

    optimizer = FusedLAMB(optimizer_grouped_parameters, lr=args.learning_rate)

    if checkpoint is not None:
        if args.resume_step >= args.previous_phase_end_step:
            keys = list(checkpoint['optimizer']['state'].keys())
            # Override hyperparameters from previous checkpoint
            for key in keys:
                checkpoint['optimizer']['state'][key]['step'] = global_steps
            for i, item in enumerate(checkpoint['optimizer']['param_groups']):
                checkpoint['optimizer']['param_groups'][i][
                    'step'] = global_steps
                checkpoint['optimizer']['param_groups'][i][
                    't_total'] = args.max_steps
                checkpoint['optimizer']['param_groups'][i][
                    'warmup'] = args.warmup_proportion
                checkpoint['optimizer']['param_groups'][i][
                    'lr'] = args.learning_rate
        optimizer.load_state_dict(checkpoint['optimizer'])

    lr_schedulers = [
        Scheduler(optimizer,
                  warmup=args.warmup_proportion,
                  total_steps=args.max_steps)
    ]

    scaler = None
    if args.fp16:
        scaler = GradScaler()
        if checkpoint is not None and 'scaler' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler'])

    preconditioner = None
    if args.kfac:
        preconditioner = kfac.KFAC(
            model,
            lr=args.learning_rate,
            factor_decay=args.kfac_stat_decay,
            damping=args.kfac_damping,
            kl_clip=args.kfac_kl_clip,
            factor_update_freq=args.kfac_factor_interval,
            inv_update_freq=args.kfac_inv_interval,
            # Skip TrainingHeads, which contains the decoder: a Linear module
            # with shape (seq_len, vocab_size) that is too large to invert
            skip_layers=args.kfac_skip_layers,
            # BERT calls KFAC very infrequently so no need to optimize for
            # communication. Optimize for memory instead.
            comm_method=kfac.CommMethod.HYBRID_OPT,
            grad_worker_fraction=0.5,
            inv_dtype=torch.float16,
            # Compute the factors and update the running averages during the
            # forward/backward pass because we use gradient accumulation but
            # do not accumulate the input/output data
            accumulate_data=False,
            compute_factor_in_hook=True,
            distribute_layer_factors=False,
            grad_scaler=scaler,
        )

        lrs = Scheduler(preconditioner,
                        warmup=args.warmup_proportion,
                        total_steps=args.max_steps)
        lr_schedulers.append(lrs)

        if checkpoint is not None and 'preconditioner' in checkpoint:
            preconditioner.load_state_dict(checkpoint['preconditioner'])

        if is_main_process():
            logger.info(preconditioner)

    return optimizer, preconditioner, lr_schedulers, scaler
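PolyWarmUpScheduler and LinearWarmUpScheduler are project-specific classes; as a rough stand-in (an assumption about their behaviour, not their actual code), a warmup-then-polynomial-decay schedule can be sketched with torch's built-in LambdaLR:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

def poly_warmup_lambda(warmup_steps: int, total_steps: int, power: float = 1.0):
    def fn(step: int) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)                 # linear warmup
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return max(0.0, 1.0 - progress) ** power               # polynomial decay
    return fn

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = SGD(params, lr=1e-3)
scheduler = LambdaLR(optimizer, lr_lambda=poly_warmup_lambda(10, 100))
for _ in range(5):
    optimizer.step()
    scheduler.step()
print(optimizer.param_groups[0]['lr'])  # halfway through warmup -> 5e-4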
def create_optimizer(args,
                     model,
                     filter_bias_and_bn=True,
                     classification_layer_name=None):
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if 'adamw' in opt_lower or 'radam' in opt_lower:
        # Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay
        # I don't believe they follow the paper or original Torch7 impl which schedules weight
        # decay based on the ratio of current_lr/initial_lr
        weight_decay /= args.lr

    if weight_decay and filter_bias_and_bn:  # batch norm and bias params
        if classification_layer_name is not None:
            parameters = set_lr_per_params(args, model,
                                           classification_layer_name,
                                           weight_decay)
        else:
            parameters = add_weight_decay(model, weight_decay)
        weight_decay = 0.  # reset to 0
    else:
        if classification_layer_name is not None:
            parameters = set_lr_per_params(args,
                                           model,
                                           classification_layer_name,
                                           weight_decay=0)
        else:
            parameters = model.parameters()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        optimizer = optim.SGD(parameters,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=weight_decay,
                              nesterov=True)
    elif opt_lower == 'momentum':
        optimizer = optim.SGD(parameters,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=weight_decay,
                              nesterov=False)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters,
                               lr=args.lr,
                               weight_decay=weight_decay,
                               eps=args.opt_eps)
    elif opt_lower == 'adamw':
        optimizer = AdamW(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters,
                          lr=args.lr,
                          weight_decay=weight_decay,
                          eps=args.opt_eps,
                          delta=0.1,
                          wd_ratio=0.01,
                          nesterov=True)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters,
                         lr=args.lr,
                         momentum=args.momentum,
                         weight_decay=weight_decay,
                         eps=args.opt_eps,
                         nesterov=True)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters,
                                   lr=args.lr,
                                   weight_decay=weight_decay,
                                   eps=args.opt_eps)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters,
                                  lr=args.lr,
                                  alpha=0.9,
                                  eps=args.opt_eps,
                                  momentum=args.momentum,
                                  weight_decay=weight_decay)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters,
                              lr=args.lr,
                              alpha=0.9,
                              eps=args.opt_eps,
                              momentum=args.momentum,
                              weight_decay=weight_decay)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters,
                             lr=args.lr,
                             weight_decay=weight_decay,
                             eps=args.opt_eps)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters,
                               lr=args.lr,
                               weight_decay=weight_decay,
                               eps=args.opt_eps)
    elif opt_lower == 'fusedsgd':
        optimizer = FusedSGD(parameters,
                             lr=args.lr,
                             momentum=args.momentum,
                             weight_decay=weight_decay,
                             nesterov=True)
    elif opt_lower == 'fusedmomentum':
        optimizer = FusedSGD(parameters,
                             lr=args.lr,
                             momentum=args.momentum,
                             weight_decay=weight_decay,
                             nesterov=False)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters,
                              lr=args.lr,
                              adam_w_mode=False,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters,
                              lr=args.lr,
                              adam_w_mode=True,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters,
                              lr=args.lr,
                              weight_decay=weight_decay,
                              eps=args.opt_eps)
    elif opt_lower == 'fusednovograd':
        optimizer = FusedNovoGrad(parameters,
                                  lr=args.lr,
                                  betas=(0.95, 0.98),
                                  weight_decay=weight_decay,
                                  eps=args.opt_eps)
    else:
        raise ValueError(f"Invalid optimizer: {opt_lower}")

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
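A tiny illustration (an assumption about the intent of the weight-decay comment in create_optimizer, not code from the repository) of why weight decay is divided by the learning rate for AdamW/RAdam: those implementations apply decay as p -= lr * wd * p, so passing wd / lr keeps the effective per-step decay at the intended value for the initial learning rate.

lr = 0.05
intended_decay = 1e-4            # decay we actually want applied per step
wd_passed = intended_decay / lr  # what create_optimizer passes to the optimizer
print(lr * wd_passed)            # 0.0001 -> effective decay matches the intent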
def main():
    parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments))
    training_args, dataset_args, collaboration_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
    )

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed before initializing model.
    set_seed(training_args.seed)

    config = AlbertConfig.from_pretrained(dataset_args.config_path, cache_dir=dataset_args.cache_dir)

    tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)

    # find latest checkpoint in output_dir
    output_dir = Path(training_args.output_dir)
    logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
    latest_checkpoint_dir = max(output_dir.glob('checkpoint*'), default=None, key=os.path.getctime)

    if latest_checkpoint_dir is not None:
        logger.info(f'Loading model from {latest_checkpoint_dir}')
        model = AlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
    else:
        logger.info(f'Training from scratch')
        model = AlbertForPreTraining(config)
        model.resize_token_embeddings(len(tokenizer))

    tokenized_dataset_path = Path(dataset_args.dataset_path)

    tokenized_datasets = load_from_disk(tokenized_dataset_path)

    # Data collator
    # This one will take care of randomly masking the tokens.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)

    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": training_args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    optimizer = FusedLAMB(
        optimizer_grouped_parameters,
        lr=training_args.learning_rate,
        betas=(training_args.adam_beta1, training_args.adam_beta2),
        eps=training_args.adam_epsilon,
    )

    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
    )

    trainer = CollaborativeTrainer(
        model=model, args=training_args, collaboration_args=collaboration_args,
        train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
        eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        optimizers=(optimizer, lr_scheduler)
    )

    # Training
    if training_args.do_train:
        trainer.train(model_path=latest_checkpoint_dir)
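A small trace (assuming the transformers library is installed; plain torch.optim.AdamW stands in for FusedLAMB so it runs without apex) of how the linear warmup schedule used above evolves the learning rate:

import torch
from transformers import get_linear_schedule_with_warmup

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-3)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=4, num_training_steps=10)

lrs = []
for _ in range(10):
    optimizer.step()
    scheduler.step()
    lrs.append(optimizer.param_groups[0]['lr'])
print(lrs)  # ramps up over the 4 warmup steps, then decays linearly to 0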
def prepare_model_and_optimizer(args, device):
    global_step = 0
    args.resume_step = 0
    checkpoint = None

    config = BertConfig.from_json_file(args.bert_config_path)
    config.fused_mha = args.fused_mha
    config.fused_gelu_bias = args.fused_gelu_bias
    config.dense_seq_output = args.dense_seq_output
    config.unpad = args.unpad
    config.pad = args.pad
    config.fuse_qkv = not args.disable_fuse_qkv
    config.fuse_scale = not args.disable_fuse_scale
    config.fuse_mask = not args.disable_fuse_mask
    config.fuse_dropout = args.enable_fuse_dropout
    config.apex_softmax = not args.disable_apex_softmax
    config.enable_stream = args.enable_stream
    if config.fuse_mask == True: config.apex_softmax = True
    if config.pad == False: config.enable_stream = True
    if config.unpad == True: config.fused_mha = False

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    # Load from PyTorch checkpoint - either given as init_checkpoint, or picked up from output_dir if found
    if args.init_checkpoint is not None or found_resume_checkpoint(args):
        # Prepare model

        model = BertForPreTraining(config)
        if args.init_checkpoint is None: # finding checkpoint in output_dir
            checkpoint_str = "phase2_ckpt_*.pt" if args.phase2 else "phase1_ckpt_*.pt"
            model_names = [f for f in glob.glob(os.path.join(args.output_dir, checkpoint_str))]
            global_step = max([int(x.split('.pt')[0].split('_')[-1].strip()) for x in model_names])
            args.resume_step = global_step #used for throughput computation

            resume_init_checkpoint = os.path.join(args.output_dir, checkpoint_str.replace("*", str(global_step)))
            print("Setting init checkpoint to %s - which is the latest in %s" %(resume_init_checkpoint, args.output_dir))
            checkpoint=torch.load(resume_init_checkpoint, map_location="cpu")
        else:
            checkpoint=torch.load(args.init_checkpoint, map_location="cpu")["model"]

        # Fused MHA requires a remapping of checkpoint parameters
        if config.fused_mha:
            checkpoint_remapped = remap_attn_parameters(checkpoint)
            model.load_state_dict(checkpoint_remapped, strict=False)
        else:
            model.load_state_dict(checkpoint, strict=True)
    else: #Load from TF Checkpoint
        model = BertForPreTraining.from_pretrained(args.init_tf_checkpoint, from_tf=True, config=config)


    model.to(device)
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay_rate},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

    mlperf_logger.log_event(key=mlperf_logger.constants.OPT_BASE_LR,
                            value=args.learning_rate, sync=False)
    optimizer = FusedLAMB(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          betas=(args.opt_lamb_beta_1, args.opt_lamb_beta_2))
    mlperf_logger.log_event(key='opt_epsilon', value=optimizer.defaults['eps'],
                            sync=False)
    b1, b2 = optimizer.defaults['betas']
    mlperf_logger.log_event(key='opt_lamb_beta_1', value=b1, sync=False)
    mlperf_logger.log_event(key='opt_lamb_beta_2', value=b2, sync=False)
    mlperf_logger.log_event(key='opt_lamb_weight_decay_rate',
                            value=optimizer.defaults['weight_decay'],
                            sync=False)

    if args.warmup_steps == 0:
        warmup_steps = int(args.max_steps * args.warmup_proportion)
        warmup_start = 0
    else:
        warmup_steps = args.warmup_steps
        warmup_start = args.start_warmup_step
    lr_scheduler = LinearWarmupPolyDecayScheduler(optimizer, start_warmup_steps=warmup_start, warmup_steps=warmup_steps,
                                                  total_steps=args.max_steps, end_learning_rate=0.0, degree=1.0)

    if args.fp16:
        if args.loss_scale == 0:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale="dynamic")
        else:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=args.loss_scale)
        amp._amp_state.loss_scalers[0]._loss_scale = float(os.getenv("INIT_LOSS_SCALE", 2**20))


    if found_resume_checkpoint(args):
        optimizer.load_state_dict(checkpoint['optimizer']) #restores m,v states (only if resuming checkpoint, not for init_checkpoint and init_tf_checkpoint for now)

        # Restore AMP master parameters          
        if args.fp16:
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(checkpoint['optimizer'])
            for param, saved_param in zip(amp.master_params(optimizer), checkpoint['master params']):
                param.data.copy_(saved_param.data)

    if args.local_rank != -1:
        if not args.allreduce_post_accumulation:
            model = DDP(model, message_size=250000000, gradient_predivide_factor=torch.distributed.get_world_size())
        else:
            flat_dist_call([param.data for param in model.parameters()], torch.distributed.broadcast, (0,) )

    return model, optimizer, lr_scheduler, checkpoint, global_step
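The vocab-size padding above rounds up to a multiple of 8 so the embedding and decoder matrices stay tensor-core friendly; the same arithmetic as a standalone helper (a sketch, not part of the repository):

def pad_to_multiple(value: int, multiple: int = 8) -> int:
    remainder = value % multiple
    return value if remainder == 0 else value + multiple - remainder

print(pad_to_multiple(30522))  # 30528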
Example #29
0
def main():
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    if args.p_arpabet > 0.0:
        cmudict.initialize(args.cmudict_path, keep_ambiguous=True)

    distributed_run = args.world_size > 1

    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)

    if args.local_rank == 0:
        if not os.path.exists(args.output):
            os.makedirs(args.output)

    log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
    tb_subsets = ['train', 'val']
    if args.ema_decay > 0.0:
        tb_subsets.append('val_ema')

    logger.init(log_fpath,
                args.output,
                enabled=(args.local_rank == 0),
                tb_subsets=tb_subsets)
    logger.parameters(vars(args), tb_subset='train')

    parser = models.parse_model_args('FastPitch', parser)
    args, unk_args = parser.parse_known_args()
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, args.world_size, args.local_rank)

    device = torch.device('cuda' if args.cuda else 'cpu')
    model_config = models.get_model_config('FastPitch', args)
    model = models.get_model('FastPitch', model_config, device)

    attention_kl_loss = AttentionBinarizationLoss()

    # Store pitch mean/std as params to translate from Hz during inference
    model.pitch_mean[0] = args.pitch_mean
    model.pitch_std[0] = args.pitch_std

    kw = dict(lr=args.learning_rate,
              betas=(0.9, 0.98),
              eps=1e-9,
              weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    else:
        raise ValueError(f'Unknown optimizer: {args.optimizer}')

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    if args.ema_decay > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if distributed_run:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank,
                                        find_unused_parameters=True)

    start_epoch = [1]
    start_iter = [0]

    assert args.checkpoint_path is None or args.resume is False, (
        "Specify a single checkpoint source")
    if args.checkpoint_path is not None:
        ch_fpath = args.checkpoint_path
    elif args.resume:
        ch_fpath = last_checkpoint(args.output)
    else:
        ch_fpath = None

    if ch_fpath is not None:
        load_checkpoint(args, model, ema_model, optimizer, scaler, start_epoch,
                        start_iter, model_config, ch_fpath)

    start_epoch = start_epoch[0]
    total_iter = start_iter[0]

    criterion = FastPitchLoss(
        dur_predictor_loss_scale=args.dur_predictor_loss_scale,
        pitch_predictor_loss_scale=args.pitch_predictor_loss_scale,
        attn_loss_scale=args.attn_loss_scale)

    collate_fn = TTSCollate()

    if args.local_rank == 0:
        prepare_tmp(args.pitch_online_dir)

    trainset = TTSDataset(audiopaths_and_text=args.training_files,
                          **vars(args))
    valset = TTSDataset(audiopaths_and_text=args.validation_files,
                        **vars(args))

    if distributed_run:
        train_sampler, shuffle = DistributedSampler(trainset), False
    else:
        train_sampler, shuffle = None, True

    # 4 workers are optimal on DGX-1 (from epoch 2 onwards)
    train_loader = DataLoader(trainset,
                              num_workers=4,
                              shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              pin_memory=True,
                              persistent_workers=True,
                              drop_last=True,
                              collate_fn=collate_fn)

    if args.ema_decay:
        mt_ema_params = init_multi_tensor_ema(model, ema_model)

    model.train()

    bmark_stats = BenchmarkStats()

    torch.cuda.synchronize()
    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start_time = time.perf_counter()

        epoch_loss = 0.0
        epoch_mel_loss = 0.0
        epoch_num_frames = 0
        epoch_frames_per_sec = 0.0

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        accumulated_steps = 0
        iter_loss = 0
        iter_num_frames = 0
        iter_meta = {}
        iter_start_time = time.perf_counter()

        epoch_iter = 0
        num_iters = len(train_loader) // args.grad_accumulation
        for batch in train_loader:

            if accumulated_steps == 0:
                if epoch_iter == num_iters:
                    break
                total_iter += 1
                epoch_iter += 1

                adjust_learning_rate(total_iter, optimizer, args.learning_rate,
                                     args.warmup_steps)

                model.zero_grad(set_to_none=True)

            x, y, num_frames = batch_to_gpu(batch)

            with torch.cuda.amp.autocast(enabled=args.amp):
                y_pred = model(x)
                loss, meta = criterion(y_pred, y)

                if (args.kl_loss_start_epoch is not None
                        and epoch >= args.kl_loss_start_epoch):

                    if args.kl_loss_start_epoch == epoch and epoch_iter == 1:
                        print('Begin hard_attn loss')

                    _, _, _, _, _, _, _, _, attn_soft, attn_hard, _, _ = y_pred
                    binarization_loss = attention_kl_loss(attn_hard, attn_soft)
                    kl_weight = min(
                        (epoch - args.kl_loss_start_epoch) /
                        args.kl_loss_warmup_epochs, 1.0) * args.kl_loss_weight
                    meta['kl_loss'] = binarization_loss.clone().detach(
                    ) * kl_weight
                    loss += kl_weight * binarization_loss

                else:
                    meta['kl_loss'] = torch.zeros_like(loss)
                    kl_weight = 0
                    binarization_loss = 0

                loss /= args.grad_accumulation

            meta = {k: v / args.grad_accumulation for k, v in meta.items()}

            if args.amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
                meta = {
                    k: reduce_tensor(v, args.world_size)
                    for k, v in meta.items()
                }
            else:
                reduced_loss = loss.item()
                reduced_num_frames = num_frames.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            accumulated_steps += 1
            iter_loss += reduced_loss
            iter_num_frames += reduced_num_frames
            iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}

            if accumulated_steps % args.grad_accumulation == 0:

                logger.log_grads_tb(total_iter, model)
                if args.amp:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.grad_clip_thresh)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.grad_clip_thresh)
                    optimizer.step()

                if args.ema_decay > 0.0:
                    apply_multi_tensor_ema(args.ema_decay, *mt_ema_params)

                iter_mel_loss = iter_meta['mel_loss'].item()
                iter_kl_loss = iter_meta['kl_loss'].item()
                iter_time = time.perf_counter() - iter_start_time
                epoch_frames_per_sec += iter_num_frames / iter_time
                epoch_loss += iter_loss
                epoch_num_frames += iter_num_frames
                epoch_mel_loss += iter_mel_loss

                log(
                    (epoch, epoch_iter, num_iters),
                    tb_total_steps=total_iter,
                    subset='train',
                    data=OrderedDict([
                        ('loss', iter_loss), ('mel_loss', iter_mel_loss),
                        ('kl_loss', iter_kl_loss), ('kl_weight', kl_weight),
                        ('frames/s', iter_num_frames / iter_time),
                        ('took', iter_time),
                        ('lrate', optimizer.param_groups[0]['lr'])
                    ]),
                )

                accumulated_steps = 0
                iter_loss = 0
                iter_num_frames = 0
                iter_meta = {}
                iter_start_time = time.perf_counter()

        # Finished epoch
        epoch_loss /= epoch_iter
        epoch_mel_loss /= epoch_iter
        epoch_time = time.perf_counter() - epoch_start_time

        log(
            (epoch, ),
            tb_total_steps=None,
            subset='train_avg',
            data=OrderedDict([('loss', epoch_loss),
                              ('mel_loss', epoch_mel_loss),
                              ('frames/s', epoch_num_frames / epoch_time),
                              ('took', epoch_time)]),
        )
        bmark_stats.update(epoch_num_frames, epoch_loss, epoch_mel_loss,
                           epoch_time)

        validate(model, epoch, total_iter, criterion, valset, args.batch_size,
                 collate_fn, distributed_run, batch_to_gpu)

        if args.ema_decay > 0:
            validate(ema_model,
                     epoch,
                     total_iter,
                     criterion,
                     valset,
                     args.batch_size,
                     collate_fn,
                     distributed_run,
                     batch_to_gpu,
                     ema=True)

        maybe_save_checkpoint(args, model, ema_model, optimizer, scaler, epoch,
                              total_iter, model_config)
        logger.flush()

    # Finished training
    if len(bmark_stats) > 0:
        log((),
            tb_total_steps=None,
            subset='train_avg',
            data=bmark_stats.get(args.benchmark_epochs_num))

    validate(model, None, total_iter, criterion, valset, args.batch_size,
             collate_fn, distributed_run, batch_to_gpu)
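The core accumulate / unscale / clip / step pattern from the training loop above, reduced to a self-contained sketch (assuming only PyTorch; the toy model, batch size, and thresholds are illustrative, and CUDA is optional):

import torch
from torch import nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
amp_enabled = device == 'cuda'
model = nn.Linear(16, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)
grad_accumulation, clip_thresh = 4, 1.0

model.zero_grad(set_to_none=True)
for step in range(8):
    x = torch.randn(32, 16, device=device)
    y = torch.randn(32, 1, device=device)
    with torch.cuda.amp.autocast(enabled=amp_enabled):
        loss = nn.functional.mse_loss(model(x), y) / grad_accumulation
    scaler.scale(loss).backward()
    if (step + 1) % grad_accumulation == 0:
        scaler.unscale_(optimizer)  # so clipping sees unscaled gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_thresh)
        scaler.step(optimizer)
        scaler.update()
        model.zero_grad(set_to_none=True)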
Example #30
0
def run(config, args):

    if 'LOCAL_RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
    else:
        local_rank = args.rank
        world_size = args.world_size
    distributed_run = world_size > 1

    torch.manual_seed(args.seed + local_rank)
    np.random.seed(args.seed + local_rank)

    #    if local_rank == 0:
    #        if not os.path.exists(args.output):
    #            os.makedirs(args.output)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False

    if distributed_run:
        init_distributed(args, world_size, local_rank)

    device = torch.device('cuda' if args.cuda else 'cpu')

    if local_rank == 0:
        print("start training")
        print("args", args)
        print("config", config)

    #############################################
    # model
    if local_rank == 0:
        print("load model")
    model = WaveGrad(config).cuda()

    # optimizer amp config
    if local_rank == 0:
        print("configure optimizer and amp")
    kw = dict(lr=args.learning_rate,
              betas=(0.9, 0.98),
              eps=1e-9,
              weight_decay=args.weight_decay)

    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    elif args.optimizer == 'pytorch':
        optimizer = torch.optim.Adam(model.parameters(), **kw)
    else:
        raise ValueError(f'Unknown optimizer: {args.optimizer}')

    if args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    if distributed_run:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank,
                                        find_unused_parameters=True)

    start_epoch = [1]
    start_iter = [0]

    ################
    #load checkpoint
    if args.checkpoint_path is not None:
        ch_fpath = args.checkpoint_path
        load_checkpoint(local_rank, model, optimizer, start_epoch, start_iter,
                        config, args.amp, ch_fpath, world_size)

    start_epoch = start_epoch[0]
    total_iter = start_iter[0]

    # dataloader
    ##########################################################
    if local_rank == 0:
        print("load dataset")

    if local_rank == 0:
        print("prepare train dataset")
    train_dataset = AudioDataset(config, training=True)

    # distributed sampler
    if distributed_run:
        train_sampler, shuffle = DistributedSampler(train_dataset), False
    else:
        train_sampler, shuffle = None, True

    train_loader = DataLoader(train_dataset,
                              num_workers=1,
                              shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              pin_memory=False,
                              drop_last=True)

    # ground truth samples

    if local_rank == 0:
        print("prepare test_dataset")
    test_dataset = AudioDataset(config, training=False)
    test_loader = DataLoader(test_dataset, batch_size=1)
    test_batch = test_dataset.sample_test_batch(
        config.training_config.n_samples_to_test)

    # Log ground truth test batch
    if local_rank == 0:
        print("save truth wave and mel")
    mel_fn = MelSpectrogramFixed(sample_rate=config.data_config.sample_rate,
                                 n_fft=config.data_config.n_fft,
                                 win_length=config.data_config.win_length,
                                 hop_length=config.data_config.hop_length,
                                 f_min=config.data_config.f_min,
                                 f_max=config.data_config.f_max,
                                 n_mels=config.data_config.n_mels,
                                 window_fn=torch.hann_window).cuda()

    audios = {
        f'audio_{index}/gt': audio
        for index, audio in enumerate(test_batch)
    }
    specs = {
        f'mel_{index}/gt': mel_fn(audio.cuda()).cpu().squeeze()
        for index, audio in enumerate(test_batch)
    }

    ####### loop start
    #epoch
    iteration = 0
    model.train()
    val_loss = 0.0
    torch.cuda.synchronize()

    if local_rank == 0:
        print("epoch start")
    for epoch in range(start_epoch, args.epochs + 1):
        tic_epoch = time.time()
        epoch_loss = 0.0

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        accumulated_steps = 0
        iter_loss = 0
        epoch_iter = 0
        #iteration = 0
        num_iters = len(train_loader) // args.gradient_accumulation_steps

        model.module.set_new_noise_schedule(  # 1000 default
            init=torch.linspace,
            init_kwargs={
                'steps': config.training_config.training_noise_schedule.n_iter,
                'start': config.training_config.training_noise_schedule.betas_range[0],
                'end': config.training_config.training_noise_schedule.betas_range[1]
            }
        )

        for i, batch in enumerate(train_loader):
            tic_iter = time.time()

            old_lr = optimizer.param_groups[0]['lr']
            adjust_learning_rate(iteration, optimizer, args.learning_rate,
                                 args.warmup_steps)
            new_lr = optimizer.param_groups[0]['lr']

            model.zero_grad()
            batch = batch.cuda()
            mels = mel_fn(batch)

            # Training step
            model.zero_grad()
            loss = model.module.compute_loss(mels, batch)

            if args.amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, world_size).item()
            else:
                reduced_loss = loss.item()
            # if np.isnan(reduced_loss):
            #     raise Exception("loss is NaN")

            iter_loss += reduced_loss

            if args.amp:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), args.grad_clip_thresh)
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.grad_clip_thresh)

            optimizer.step()

            toc_iter = time.time()
            dur_iter = toc_iter - tic_iter
            epoch_loss += iter_loss
            iter_size = len(train_loader)
            dur_epoch_est = iter_size * dur_iter
            if local_rank == 0:
                print(
                    "\nepoch {:4d} | iter {:>12d}  {:>3d}/{:3d} | {:3.2f}s/iter est {:4.2f}s/epoch | losses {:>12.6f} {:>12.6f} LR {:e}--> {:e}"
                    .format(epoch, iteration, i, iter_size, dur_iter,
                            dur_epoch_est, iter_loss, grad_norm, old_lr,
                            new_lr),
                    end='')
            iter_loss = 0
            iteration += 1

        # Finished epoch
        toc_epoch = time.time()
        dur_epoch = toc_epoch - tic_epoch
        if local_rank == 0:
            print("for {}item,   {:4.2f}s/epoch  ".format(
                iter_size, dur_epoch))

        # Test step
        if epoch % config.training_config.test_interval == 0:
            model.module.set_new_noise_schedule(  # 50 for default
                init=torch.linspace,
                init_kwargs={
                    'steps': config.training_config.test_noise_schedule.n_iter,
                    'start': config.training_config.test_noise_schedule.betas_range[0],
                    'end': config.training_config.test_noise_schedule.betas_range[1]
                } )

        if (epoch % args.epochs_per_checkpoint == 0):
            ch_path = os.path.join(args.output,
                                   "WaveGrad_ch_{:d}.pt".format(epoch))
            save_checkpoint(local_rank, model, optimizer, epoch, iteration,
                            config, args.amp, ch_path)