def main():
    '''
    Usage:
    python train.py -data_pkl m30k_deen_shr.pkl -log m30k_deen_shr -embs_share_weight -proj_share_weight -label_smoothing -save_model trained -b 256 -warmup 128000
    '''
    global C
    global shapes
    global Beta

    parser = argparse.ArgumentParser()

    parser.add_argument('-data_pkl', default=None)    # all-in-1 data pickle or bpe field

    # NOTE: the original `type=bool` parses any non-empty string (including
    # "False") as True, so these are exposed as store_true flags instead.
    parser.add_argument('-srn', action='store_true')
    parser.add_argument('-optimize_c', action='store_true')
    parser.add_argument('-Beta', type=float, default=1.0)
    parser.add_argument('-lr', type=float, default=1e-1)
    parser.add_argument('-scheduler_mode', type=str, default=None)
    parser.add_argument('-scheduler_factor', type=float, default=0.5)

    parser.add_argument('-train_path', default=None)  # bpe encoded data
    parser.add_argument('-val_path', default=None)    # bpe encoded data

    parser.add_argument('-epoch', type=int, default=10)
    parser.add_argument('-b', '--batch_size', type=int, default=2048)

    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-d_inner_hid', type=int, default=2048)
    parser.add_argument('-d_k', type=int, default=64)
    parser.add_argument('-d_v', type=int, default=64)

    parser.add_argument('-n_head', type=int, default=8)
    parser.add_argument('-n_layers', type=int, default=6)
    parser.add_argument('-warmup', '--n_warmup_steps', type=int, default=4000)

    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-embs_share_weight', action='store_true')
    parser.add_argument('-proj_share_weight', action='store_true')

    parser.add_argument('-log', default=None)
    parser.add_argument('-save_model', default=None)
    parser.add_argument('-save_mode', type=str, choices=['all', 'best'], default='best')

    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-label_smoothing', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    opt.d_word_vec = opt.d_model

    Beta = opt.Beta

    if not opt.log and not opt.save_model:
        # A bare `raise` here would fail with "No active exception to re-raise".
        raise ValueError('No experiment result will be saved: '
                         'pass -log and/or -save_model.')

    if opt.batch_size < 2048 and opt.n_warmup_steps <= 4000:
        print('[Warning] The warmup steps may not be enough.\n'
              '(sz_b, warmup) = (2048, 4000) is the official setting.\n'
              'Using a smaller batch without a longer warmup may end the '
              'warmup stage with only little data trained.')

    device = torch.device('cuda' if opt.cuda else 'cpu')

    #========= Loading Dataset =========#

    if all((opt.train_path, opt.val_path)):
        training_data, validation_data = prepare_dataloaders_from_bpe_files(opt, device)
    elif opt.data_pkl:
        training_data, validation_data = prepare_dataloaders(opt, device)
    else:
        raise ValueError('Provide either -train_path/-val_path or -data_pkl.')

    print(opt)

    transformer = Transformer(
        opt.src_vocab_size,
        opt.trg_vocab_size,
        src_pad_idx=opt.src_pad_idx,
        trg_pad_idx=opt.trg_pad_idx,
        trg_emb_prj_weight_sharing=opt.proj_share_weight,
        emb_src_trg_weight_sharing=opt.embs_share_weight,
        d_k=opt.d_k,
        d_v=opt.d_v,
        d_model=opt.d_model,
        d_word_vec=opt.d_word_vec,
        d_inner=opt.d_inner_hid,
        n_layers=opt.n_layers,
        n_head=opt.n_head,
        dropout=opt.dropout).to(device)

    if opt.srn:
        transformer = migrate_to_srn(transformer)
        transformer = transformer.to(device)

    if opt.optimize_c:
        srn_modules = [
            module for module in transformer.modules()
            if isinstance(module, (SRNLinear, SRNConv2d))
        ]
        sranks = []
        shapes = []
        for module in srn_modules:
            W = module.weight.detach()
            W = W.view(W.shape[0], -1)
            sranks.append(stable_rank(W).item())
            shapes.append(W.shape)

        # A rule of thumb: initialize each target srank with the model's
        # current srank, normalized by min(shape). The parameters are created
        # directly on `device`; calling `C[i].to(device)` alone would be a
        # no-op, since its result was never assigned back.
        C = [
            Parameter((torch.ones(1, device=device) * sranks[i] / min(shapes[i])).view(()))
            for i in range(len(srn_modules))
        ]
        for i, module in enumerate(srn_modules):
            module.c = C[i]

        criteria = criteria_
    else:
        criteria = cal_performance

    # ScheduledOptim wraps Adam and drives the learning-rate schedule.
    optimizer = ScheduledOptim(
        optim.Adam(transformer.parameters(), lr=1e-2, betas=(0.9, 0.98), eps=1e-09),
        opt.lr, opt.d_model, opt.n_warmup_steps,
        mode=opt.scheduler_mode, factor=opt.scheduler_factor, patience=3)

    train(transformer, training_data, validation_data, optimizer, device, opt,
          loss=criteria)

    # C is only populated when -optimize_c is set; printing it
    # unconditionally would fail otherwise.
    if opt.optimize_c:
        print("~~~~~~~~~~~~~C~~~~~~~~~~~~~")
        print(C)
        print("~~~~~~~~~~~~~~~~~~~~~~~~~~~")

    print("-----------Model-----------")
    print(transformer)
    print("---------------------------")

    with torch.no_grad():
        for pname, p in transformer.named_parameters():
            if len(p.shape) > 1:
                print("...Parameter ", pname, ", srank=",
                      stable_rank(p.view(p.shape[0], -1)).item())
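
# `stable_rank` is defined elsewhere in this repo. A minimal sketch of the
# standard definition it is assumed to implement is given below (the name
# `stable_rank_sketch` is hypothetical): srank(W) = ||W||_F^2 / ||W||_2^2,
# the squared Frobenius norm over the squared spectral norm, which is always
# at most rank(W).
import torch

def stable_rank_sketch(W):
    # Singular values in descending order.
    s = torch.linalg.svdvals(W)
    # srank(W) = sum_i s_i^2 / s_1^2.
    return (s ** 2).sum() / (s[0] ** 2)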
def main():
    ''' Main function '''
    parser = argparse.ArgumentParser()
    parser.add_argument('-config', type=str, default='config/rnnt.yaml')
    parser.add_argument('-load_model', type=str, default=None)
    parser.add_argument('-fp16_allreduce', action='store_true', default=False,
                        help='use fp16 compression during allreduce')
    parser.add_argument('-batches_per_allreduce', type=int, default=1,
                        help='number of batches processed locally before '
                             'executing allreduce across workers; it multiplies '
                             'total batch size.')
    parser.add_argument('-num_workers', type=int, default=0,
                        help='how many subprocesses to use for data loading. '
                             '0 means that the data will be loaded in the main process')
    parser.add_argument('-log', type=str, default='train.log')
    opt = parser.parse_args()

    configfile = open(opt.config)
    # yaml.load without an explicit Loader is deprecated and unsafe.
    config = AttrDict(yaml.safe_load(configfile))

    global global_step
    global_step = 0

    if hvd.rank() == 0:
        exp_name = config.data.name
        if not os.path.isdir(exp_name):
            os.mkdir(exp_name)
        logger = init_logger(exp_name + '/' + opt.log)
    else:
        logger = None

    if torch.cuda.is_available():
        torch.cuda.set_device(hvd.local_rank())
        torch.cuda.manual_seed(config.training.seed)
        torch.backends.cudnn.deterministic = True
    else:
        raise NotImplementedError

    #========= Build DataLoader =========#
    train_dataset = AudioDateset(config.data, 'train')
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.data.train.batch_size,
        sampler=train_sampler,
        num_workers=opt.num_workers)

    assert train_dataset.vocab_size == config.model.vocab_size

    #========= Build A Model Or Load Pre-trained Model =========#
    model = Transformer(config.model)

    if hvd.rank() == 0:
        n_params, enc_params, dec_params = count_parameters(model)
        logger.info('# the number of parameters in the whole model: %d' % n_params)
        logger.info('# the number of parameters in encoder: %d' % enc_params)
        logger.info('# the number of parameters in decoder: %d' % dec_params)

    model.cuda()

    # Define an optimizer. The original code referenced an undefined `lr`;
    # the learning rate is assumed to live under config.training.lr.
    lr = config.training.lr
    optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                 betas=(0.9, 0.98), eps=1e-9)

    # Horovod: (optional) compression algorithm.
    compression = hvd.Compression.fp16 if opt.fp16_allreduce else hvd.Compression.none
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        compression=compression)

    # Load a pretrained model on rank 0 only; the weights and the optimizer
    # state are broadcast to the other workers below.
    if opt.load_model is not None and hvd.rank() == 0:
        checkpoint = torch.load(opt.load_model)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info('Loaded the pretrained model and the previous optimizer!')
    elif hvd.rank() == 0:
        init_parameters(model)
        logger.info('Initialized all parameters!')

    # Horovod: broadcast parameters & optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # Define the loss function.
    crit = nn.CrossEntropyLoss(ignore_index=0)

    # Create a visualizer.
    if config.training.visualization and hvd.rank() == 0:
        visualizer = SummaryWriter(exp_name + '/log')
        logger.info('Created a visualizer.')
    else:
        visualizer = None

    for epoch in range(config.training.epoches):
        train(epoch, model, crit, optimizer, train_loader, train_sampler,
              logger, visualizer, config)

        if hvd.rank() == 0:
            save_model(epoch, model, optimizer, config, logger)

    if hvd.rank() == 0:
        logger.info('Training process finished')
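
# Entry point (a minimal sketch): the excerpt never calls hvd.init(), yet
# main() queries hvd.rank(), so initialization is assumed to belong here
# unless it happens at module level in code not shown.
if __name__ == '__main__':
    hvd.init()
    main()

# Example launch across 4 workers with Horovod's launcher:
#   horovodrun -np 4 python train.py -config config/rnnt.yaml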