def main():
    '''
    Usage:
    python train.py -data_pkl m30k_deen_shr.pkl -log m30k_deen_shr -embs_share_weight -proj_share_weight -label_smoothing -save_model trained -b 256 -warmup 128000
    '''
    global C
    global shapes
    global Beta
    parser = argparse.ArgumentParser()

    parser.add_argument('-data_pkl',
                        default=None)  # all-in-1 data pickle or bpe field
    # store_true flags avoid the argparse `type=bool` pitfall, where any
    # non-empty string (including "False") is parsed as True
    parser.add_argument('-srn', action='store_true')
    parser.add_argument('-optimize_c', action='store_true')
    parser.add_argument('-Beta', type=float, default=1.0)
    parser.add_argument("-lr", type=float, default=1e-1)
    parser.add_argument("-scheduler_mode", type=str, default=None)
    parser.add_argument("-scheduler_factor", type=float, default=0.5)
    parser.add_argument('-train_path', default=None)  # bpe encoded data
    parser.add_argument('-val_path', default=None)  # bpe encoded data

    parser.add_argument('-epoch', type=int, default=10)
    parser.add_argument('-b', '--batch_size', type=int, default=2048)

    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-d_inner_hid', type=int, default=2048)
    parser.add_argument('-d_k', type=int, default=64)
    parser.add_argument('-d_v', type=int, default=64)

    parser.add_argument('-n_head', type=int, default=8)
    parser.add_argument('-n_layers', type=int, default=6)
    parser.add_argument('-warmup', '--n_warmup_steps', type=int, default=4000)

    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-embs_share_weight', action='store_true')
    parser.add_argument('-proj_share_weight', action='store_true')

    parser.add_argument('-log', default=None)
    parser.add_argument('-save_model', default=None)
    parser.add_argument('-save_mode',
                        type=str,
                        choices=['all', 'best'],
                        default='best')

    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-label_smoothing', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    opt.d_word_vec = opt.d_model
    Beta = opt.Beta

    if not opt.log and not opt.save_model:
        print('No experiment result will be saved.')
        raise ValueError('Neither -log nor -save_model was given.')

    if opt.batch_size < 2048 and opt.n_warmup_steps <= 4000:
        print('[Warning] The warmup steps may not be enough.\n'
              '(sz_b, warmup) = (2048, 4000) is the official setting.\n'
              'Using a smaller batch size without a longer warmup may cause '
              'the warmup stage to end with only little data trained.')

    device = torch.device('cuda' if opt.cuda else 'cpu')

    #========= Loading Dataset =========#

    if all((opt.train_path, opt.val_path)):
        training_data, validation_data = prepare_dataloaders_from_bpe_files(
            opt, device)
    elif opt.data_pkl:
        training_data, validation_data = prepare_dataloaders(opt, device)
    else:
        raise ValueError(
            'No data given; provide -data_pkl or both -train_path and -val_path.')

    print(opt)

    transformer = Transformer(opt.src_vocab_size,
                              opt.trg_vocab_size,
                              src_pad_idx=opt.src_pad_idx,
                              trg_pad_idx=opt.trg_pad_idx,
                              trg_emb_prj_weight_sharing=opt.proj_share_weight,
                              emb_src_trg_weight_sharing=opt.embs_share_weight,
                              d_k=opt.d_k,
                              d_v=opt.d_v,
                              d_model=opt.d_model,
                              d_word_vec=opt.d_word_vec,
                              d_inner=opt.d_inner_hid,
                              n_layers=opt.n_layers,
                              n_head=opt.n_head,
                              dropout=opt.dropout).to(device)
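    # Optionally migrate the model's linear/conv layers to their SRN counterparts
    # (SRNLinear / SRNConv2d) before training.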
    if opt.srn:
        transformer = migrate_to_srn(transformer)
        transformer = transformer.to(device)
    if opt.optimize_c:
        srn_modules = [
            module for module in transformer.modules()
            if isinstance(module, (SRNLinear, SRNConv2d))
        ]
        sranks = []
        shapes = []

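        # Record each SRN layer's current stable rank and flattened 2-D weight shape
        # (stable_rank is assumed to compute ||W||_F^2 / ||W||_2^2).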
        for module in srn_modules:
            W = module.weight.detach()
            shape_w = W.shape
            W = W.view(shape_w[0], -1)
            sranks.append(stable_rank(W).item())
            shapes.append(W.shape)

        # A rule of thumb: initialize each target srank with the model's current srank.
        C = [
            Parameter((torch.ones(1, device=device) * sranks[i] /
                       min(shapes[i])).view(()))
            for i in range(len(srn_modules))
        ]
        # Attach each learnable coefficient to its SRN layer.
        for i, module in enumerate(srn_modules):
            module.c = C[i]
        criteria = criteria_
    else:
        criteria = cal_performance
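    # ScheduledOptim is assumed to apply a warmup learning-rate schedule on top of
    # Adam (overriding the initial lr); mode/factor/patience additionally allow a
    # plateau-style decay.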
    optimizer = ScheduledOptim(optim.Adam(transformer.parameters(),
                                          lr=1e-2,
                                          betas=(0.9, 0.98),
                                          eps=1e-09),
                               opt.lr,
                               opt.d_model,
                               opt.n_warmup_steps,
                               mode=opt.scheduler_mode,
                               factor=opt.scheduler_factor,
                               patience=3)

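    # Train with either the SRN criterion (when optimizing c) or the standard
    # cal_performance loss.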
    train(transformer,
          training_data,
          validation_data,
          optimizer,
          device,
          opt,
          loss=criteria)
    if opt.optimize_c:
        print("~~~~~~~~~~~~~C~~~~~~~~~~~~~")
        print(C)
        print("~~~~~~~~~~~~~~~~~~~~~~~~~~~")
    print("-----------Model-----------")
    print(transformer)
    print("---------------------------")
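    # Report the post-training stable rank of every weight matrix (flattened to 2-D).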
    with torch.no_grad():
        for pname, p in transformer.named_parameters():
            if len(p.shape) > 1:
                print("...Parameter ", pname, ", srank=",
                      stable_rank(p.view(p.shape[0], -1)).item())
def main():
    ''' Main function '''
    parser = argparse.ArgumentParser()
    parser.add_argument('-config', type=str, default='config/rnnt.yaml')
    parser.add_argument('-load_model', type=str, default=None)
    parser.add_argument('-fp16_allreduce',
                        action='store_true',
                        default=False,
                        help='use fp16 compression during allreduce')
    parser.add_argument('-batches_per_allreduce',
                        type=int,
                        default=1,
                        help='number of batches processed locally before '
                        'executing allreduce across workers; it multiplies '
                        'total batch size.')
    parser.add_argument(
        '-num_workers',
        type=int,
        default=0,
        help='how many subprocesses to use for data loading. '
        '0 means that the data will be loaded in the main process')
    parser.add_argument('-log', type=str, default='train.log')
    opt = parser.parse_args()

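    # Load the experiment configuration from YAML into an attribute-style dict.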
    with open(opt.config) as configfile:
        config = AttrDict(yaml.load(configfile, Loader=yaml.FullLoader))

    global global_step
    global_step = 0

    if hvd.rank() == 0:
        exp_name = config.data.name
        if not os.path.isdir(exp_name):
            os.mkdir(exp_name)
        logger = init_logger(exp_name + '/' + opt.log)
    else:
        logger = None

    if torch.cuda.is_available():
        torch.cuda.set_device(hvd.local_rank())
        torch.cuda.manual_seed(config.training.seed)
        torch.backends.cudnn.deterministic = True
    else:
        raise NotImplementedError

    #========= Build DataLoader =========#
    train_dataset = AudioDateset(config.data, 'train')
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.data.train.batch_size,
        sampler=train_sampler)

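    # The dataset's vocabulary size must match the model configuration.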
    assert train_dataset.vocab_size == config.model.vocab_size

    #========= Build A Model Or Load Pre-trained Model=========#
    model = Transformer(config.model)

    if hvd.rank() == 0:
        n_params, enc_params, dec_params = count_parameters(model)
        logger.info('# the number of parameters in the whole model: %d' %
                    n_params)
        logger.info('# the number of parameters in encoder: %d' % enc_params)
        logger.info('# the number of parameters in decoder: %d' % dec_params)

    model.cuda()

    # define an optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config.optim.lr,  # assumed config key for the base learning rate
                                 betas=(0.9, 0.98),
                                 eps=1e-9)

    # Horovod: (optional) compression algorithm.
    compression = hvd.Compression.fp16 if opt.fp16_allreduce else hvd.Compression.none

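    # Horovod: wrap the optimizer so gradients are averaged across workers via
    # allreduce before each update.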
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        compression=compression)

    # load a pretrained model if one is given
    if opt.load_model is not None and hvd.rank() == 0:
        checkpoint = torch.load(opt.load_model)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info('Loaded pretrained model and previous optimizer!')
    elif hvd.rank() == 0:
        init_parameters(model)
        logger.info('Initialized all parameters!')

    # Horovod: broadcast parameters & optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # define loss function
    crit = nn.CrossEntropyLoss(ignore_index=0)

    # create a visualizer
    if config.training.visualization and hvd.rank() == 0:
        visualizer = SummaryWriter(exp_name + '/log')
        logger.info('Created a visualizer.')
    else:
        visualizer = None

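    # Epoch loop: every rank trains on its data shard; rank 0 saves a checkpoint
    # after each epoch.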
    for epoch in range(config.training.epoches):
        train(epoch, model, crit, optimizer, train_loader, train_sampler,
              logger, visualizer, config)

        if hvd.rank() == 0:
            save_model(epoch, model, optimizer, config, logger)

    if hvd.rank() == 0:
        logger.info('Training process finished')