import argparse
import os
import time

import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

import dllogger as DLLogger
from dllogger import JSONStreamBackend, StdOutBackend, Verbosity

# Project-local modules from the surrounding repository.
import data_functions
import loss_functions
import models

# Helpers referenced below (parse_args, init_distributed, reduce_tensor,
# adjust_learning_rate, load_checkpoint, save_checkpoint,
# get_last_checkpoint_filename, validate, checkpoint_from_distributed,
# unwrap_distributed) are assumed to be defined elsewhere in this file/repo.


def load_and_setup_model(model_name, parser, checkpoint, fp16_run, cpu_run,
                         forward_is_infer=False):
    model_parser = models.model_parser(model_name, parser, add_help=False)
    model_args, _ = model_parser.parse_known_args()

    model_config = models.get_model_config(model_name, model_args)
    model = models.get_model(model_name, model_config, cpu_run=cpu_run,
                             forward_is_infer=forward_is_infer)

    if checkpoint is not None:
        if cpu_run:
            state_dict = torch.load(
                checkpoint, map_location=torch.device('cpu'))['state_dict']
        else:
            state_dict = torch.load(checkpoint)['state_dict']

        # Strip the 'module.' prefix left by DistributedDataParallel checkpoints.
        if checkpoint_from_distributed(state_dict):
            state_dict = unwrap_distributed(state_dict)

        model.load_state_dict(state_dict)

    # Fuse weight normalization into the weights for faster WaveGlow inference.
    if model_name == "WaveGlow":
        model = model.remove_weightnorm(model)

    model.eval()

    if fp16_run:
        model.half()

    return model
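# Usage sketch for the loader above. The checkpoint-path arguments
# `args.tacotron2` / `args.waveglow` and the `args.fp16` / `args.cpu` flags are
# hypothetical; the actual argparse names depend on the surrounding script:
#
#   tacotron2 = load_and_setup_model('Tacotron2', parser, args.tacotron2,
#                                    fp16_run=args.fp16, cpu_run=args.cpu,
#                                    forward_is_infer=True)
#   waveglow = load_and_setup_model('WaveGlow', parser, args.waveglow,
#                                   fp16_run=args.fp16, cpu_run=args.cpu,
#                                   forward_is_infer=True)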
def main():
    parser = argparse.ArgumentParser(description='PyTorch Tacotron 2 Training')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    if 'LOCAL_RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
    else:
        local_rank = args.rank
        world_size = args.world_size

    distributed_run = world_size > 1

    if local_rank == 0:
        log_file = os.path.join(args.output, args.log_file)
        DLLogger.init(backends=[JSONStreamBackend(Verbosity.DEFAULT, log_file),
                                StdOutBackend(Verbosity.VERBOSE)])
    else:
        DLLogger.init(backends=[])

    for k, v in vars(args).items():
        DLLogger.log(step="PARAMETER", data={k: v})
    DLLogger.log(step="PARAMETER", data={'model_name': 'Tacotron2_PyT'})

    model_name = args.model_name
    parser = models.model_parser(model_name, parser)
    args, _ = parser.parse_known_args()

    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, world_size, local_rank, args.group_name)

    torch.cuda.synchronize()
    run_start_time = time.perf_counter()

    model_config = models.get_model_config(model_name, args)
    model = models.get_model(
        model_name, model_config, cpu_run=False,
        uniform_initialize_bn_weight=not args.disable_uniform_initialize_bn_weight)

    if distributed_run:
        model = DDP(model, device_ids=[local_rank], output_device=local_rank)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    try:
        sigma = args.sigma
    except AttributeError:
        sigma = None

    start_epoch = [0]

    if args.resume_from_last:
        args.checkpoint_path = get_last_checkpoint_filename(args.output,
                                                            model_name)

    if args.checkpoint_path != "":
        load_checkpoint(model, optimizer, start_epoch, model_config,
                        args.amp, args.checkpoint_path, local_rank)

    start_epoch = start_epoch[0]

    criterion = loss_functions.get_loss_function(model_name, sigma)

    try:
        n_frames_per_step = args.n_frames_per_step
    except AttributeError:
        n_frames_per_step = None

    collate_fn = data_functions.get_collate_function(
        model_name, n_frames_per_step)
    trainset = data_functions.get_data_loader(
        model_name, args.dataset_path, args.training_files, args)

    if distributed_run:
        train_sampler = DistributedSampler(trainset)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True

    train_loader = DataLoader(trainset, num_workers=1, shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=args.batch_size, pin_memory=False,
                              drop_last=True, collate_fn=collate_fn)

    valset = data_functions.get_data_loader(
        model_name, args.dataset_path, args.validation_files, args)

    batch_to_gpu = data_functions.get_batch_to_gpu(model_name)

    iteration = 0
    train_epoch_items_per_sec = 0.0
    val_loss = 0.0
    num_iters = 0

    model.train()

    for epoch in range(start_epoch, args.epochs):
        torch.cuda.synchronize()
        epoch_start_time = time.perf_counter()

        # used to calculate avg items/sec over epoch
        reduced_num_items_epoch = 0

        train_epoch_items_per_sec = 0.0

        num_iters = 0
        reduced_loss = 0

        # if overflow at the last iteration then do not save checkpoint
        overflow = False

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        for i, batch in enumerate(train_loader):
            torch.cuda.synchronize()
            iter_start_time = time.perf_counter()
            DLLogger.log(step=(epoch, i),
                         data={'glob_iter/iters_per_epoch':
                               str(iteration) + "/" + str(len(train_loader))})

            adjust_learning_rate(iteration, epoch, optimizer,
                                 args.learning_rate, args.anneal_steps,
                                 args.anneal_factor, local_rank)

            model.zero_grad()
            x, y, num_items = batch_to_gpu(batch)

            # AMP upstream autocast
            with torch.cuda.amp.autocast(enabled=args.amp):
                y_pred = model(x)
                loss = criterion(y_pred, y)

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, world_size).item()
                reduced_num_items = reduce_tensor(num_items.data, 1).item()
            else:
                reduced_loss = loss.item()
                reduced_num_items = num_items.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            DLLogger.log(step=(epoch, i), data={'train_loss': reduced_loss})

            num_iters += 1

            # accumulate number of items processed in this epoch
            reduced_num_items_epoch += reduced_num_items

            if args.amp:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.grad_clip_thresh)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
            else:
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.grad_clip_thresh)
                optimizer.step()

            torch.cuda.synchronize()
            iter_stop_time = time.perf_counter()
            iter_time = iter_stop_time - iter_start_time
            items_per_sec = reduced_num_items / iter_time
            train_epoch_items_per_sec += items_per_sec

            DLLogger.log(step=(epoch, i),
                         data={'train_items_per_sec': items_per_sec})
            DLLogger.log(step=(epoch, i), data={'train_iter_time': iter_time})
            iteration += 1

        torch.cuda.synchronize()
        epoch_stop_time = time.perf_counter()
        epoch_time = epoch_stop_time - epoch_start_time

        DLLogger.log(step=(epoch,),
                     data={'train_items_per_sec':
                           (train_epoch_items_per_sec / num_iters
                            if num_iters > 0 else 0.0)})
        DLLogger.log(step=(epoch,), data={'train_loss': reduced_loss})
        DLLogger.log(step=(epoch,), data={'train_epoch_time': epoch_time})

        val_loss, val_items_per_sec = validate(model, criterion, valset, epoch,
                                               iteration, args.batch_size,
                                               world_size, collate_fn,
                                               distributed_run, local_rank,
                                               batch_to_gpu)

        if (epoch % args.epochs_per_checkpoint == 0) and args.bench_class == "":
            save_checkpoint(model, optimizer, scaler, epoch, model_config,
                            args.amp, args.output, args.model_name,
                            local_rank, world_size)
        if local_rank == 0:
            DLLogger.flush()

    torch.cuda.synchronize()
    run_stop_time = time.perf_counter()
    run_time = run_stop_time - run_start_time
    DLLogger.log(step=tuple(), data={'run_time': run_time})
    DLLogger.log(step=tuple(), data={'val_loss': val_loss})
    DLLogger.log(step=tuple(),
                 data={'train_items_per_sec':
                       (train_epoch_items_per_sec / num_iters
                        if num_iters > 0 else 0.0)})
    DLLogger.log(step=tuple(), data={'val_items_per_sec': val_items_per_sec})

    if local_rank == 0:
        DLLogger.flush()
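

if __name__ == '__main__':
    # Hypothetical entry point (not shown in the original excerpt): main()
    # reads LOCAL_RANK / WORLD_SIZE from the environment, which a launcher
    # such as `torchrun --nproc_per_node=<num_gpus> <this_script> <args>`
    # sets for multi-GPU runs; invoked directly, it falls back to args.rank
    # and args.world_size for the single-process case.
    main()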