# N_EPOCHS = 10
# CLIP = 1

best_valid_loss = float('inf')

for epoch in range(args.epoches):
    start_time = time.time()

    train_loss = train(transformer_model, optimizer, criterion,
                       args.gradient_clip, device)
    valid_loss = evaluate(transformer_model, criterion, device)

    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # Save a checkpoint whenever the validation loss improves.
    if valid_loss <= best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(
            transformer_model.state_dict(),
            './models/transformer/transformer-model_{}.pt'.format(epoch + 1))

    logger.info('Epoch: {:02} | Time: {}m {}s'.format(
        epoch + 1, epoch_mins, epoch_secs))
    logger.info('\tTrain Loss: {:.3f} | Train PPL: {:7.3f}'.format(
        train_loss, math.exp(train_loss)))
    logger.info('\t Val. Loss: {:.3f} | Val. PPL: {:7.3f}'.format(
        valid_loss, math.exp(valid_loss)))
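
# The loop above calls an epoch_time helper. If it is not defined elsewhere
# in this file, a minimal sketch consistent with the call site (name and
# signature assumed from the usage above) could look like this:
def epoch_time(start_time, end_time):
    """Split an elapsed wall-clock interval into whole minutes and seconds."""
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs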
def main():
    ''' Main function '''
    parser = argparse.ArgumentParser()
    parser.add_argument('-config', type=str, default='config/rnnt.yaml')
    parser.add_argument('-load_model', type=str, default=None)
    parser.add_argument('-fp16_allreduce', action='store_true', default=False,
                        help='use fp16 compression during allreduce')
    parser.add_argument('-batches_per_allreduce', type=int, default=1,
                        help='number of batches processed locally before '
                        'executing allreduce across workers; it multiplies '
                        'total batch size.')
    parser.add_argument('-num_workers', type=int, default=0,
                        help='how many subprocesses to use for data loading. '
                        '0 means that the data will be loaded in the main process')
    parser.add_argument('-log', type=str, default='train.log')
    opt = parser.parse_args()

    configfile = open(opt.config)
    config = AttrDict(yaml.safe_load(configfile))

    # Horovod: initialize the library before querying rank/size.
    hvd.init()

    global global_step
    global_step = 0

    # Only rank 0 writes logs and creates the experiment directory.
    if hvd.rank() == 0:
        exp_name = config.data.name
        if not os.path.isdir(exp_name):
            os.mkdir(exp_name)
        logger = init_logger(exp_name + '/' + opt.log)
    else:
        logger = None

    if torch.cuda.is_available():
        # Horovod: pin each process to a single local GPU.
        torch.cuda.set_device(hvd.local_rank())
        torch.cuda.manual_seed(config.training.seed)
        torch.backends.cudnn.deterministic = True
    else:
        raise NotImplementedError('This script requires CUDA.')

    #========= Build DataLoader =========#
    train_dataset = AudioDateset(config.data, 'train')

    # Horovod: partition the training data across workers.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=hvd.size(), rank=hvd.rank())

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.data.train.batch_size,
        sampler=train_sampler)

    assert train_dataset.vocab_size == config.model.vocab_size

    #========= Build A Model Or Load Pre-trained Model =========#
    model = Transformer(config.model)

    if hvd.rank() == 0:
        n_params, enc_params, dec_params = count_parameters(model)
        logger.info('# the number of parameters in the whole model: %d' % n_params)
        logger.info('# the number of parameters in encoder: %d' % enc_params)
        logger.info('# the number of parameters in decoder: %d' % dec_params)

    model.cuda()

    # Define an optimizer.
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config.optim.lr,  # assumed config key for the learning rate
                                 betas=(0.9, 0.98),
                                 eps=1e-9)

    # Horovod: (optional) compression algorithm for allreduce.
    compression = hvd.Compression.fp16 if opt.fp16_allreduce else hvd.Compression.none

    # Horovod: wrap the optimizer so gradients are averaged across workers.
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        compression=compression)

    # Load a pretrained model on rank 0, or initialize parameters from scratch.
    if opt.load_model is not None and hvd.rank() == 0:
        checkpoint = torch.load(opt.load_model)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info('Loaded pretrained model and previous optimizer state!')
    elif hvd.rank() == 0:
        init_parameters(model)
        logger.info('Initialized all parameters!')

    # Horovod: broadcast parameters & optimizer state from rank 0 to all workers.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # Define the loss function; index 0 is the padding token.
    crit = nn.CrossEntropyLoss(ignore_index=0)

    # Create a visualizer on rank 0 only.
    if config.training.visualization and hvd.rank() == 0:
        visualizer = SummaryWriter(exp_name + '/log')
        logger.info('Created a visualizer.')
    else:
        visualizer = None

    for epoch in range(config.training.epoches):
        train(epoch, model, crit, optimizer, train_loader, train_sampler,
              logger, visualizer, config)

        if hvd.rank() == 0:
            save_model(epoch, model, optimizer, config, logger)

    if hvd.rank() == 0:
        logger.info('Training process finished')
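
# For reference, a minimal config/rnnt.yaml sketch covering the fields this
# script reads directly (key names taken from the attribute accesses above;
# all values are placeholders, and AudioDateset may require further keys
# under `data`):
#
#   data:
#     name: rnnt_exp
#     train:
#       batch_size: 16
#   model:
#     vocab_size: 1000
#   optim:
#     lr: 0.001          # assumed key, matching config.optim.lr above
#   training:
#     seed: 1
#     epoches: 40
#     visualization: True
#
# A typical multi-GPU launch with Horovod (script name assumed):
#   horovodrun -np 4 python train.py -config config/rnnt.yaml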