def main():
    parser = argparse.ArgumentParser(description='Open Unmix Trainer')

    # which target do we want to train?
    parser.add_argument('--target', type=str, default='tabla',
                        help='target source (will be passed to the dataset)')

    # Dataset Parameters
    parser.add_argument('--dataset', type=str, default="aligned",
                        choices=['musdb', 'aligned', 'sourcefolder',
                                 'trackfolder_var', 'trackfolder_fix'],
                        help='Name of the dataset.')
    parser.add_argument('--root', type=str, help='root path of dataset',
                        default='../rec_data_final/')
    parser.add_argument('--output', type=str,
                        default="../new_models/model_tabla_mtl_ourmix_1",
                        help='provide output path base folder name')
    parser.add_argument('--model', type=str, help='Path to checkpoint folder')
    parser.add_argument('--onset-model', type=str,
                        help='Path to onset detection model weights',
                        default="/media/Sharedata/rohit/cnn-onset-det/models/apr4/saved_model_0_80mel-0-16000_1ch_44100.pt")

    # Training Parameters
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate, defaults to 1e-3')
    parser.add_argument('--patience', type=int, default=140,
                        help='maximum number of epochs to train (default: 140)')
    parser.add_argument('--lr-decay-patience', type=int, default=80,
                        help='lr decay patience for plateau scheduler')
    parser.add_argument('--lr-decay-gamma', type=float, default=0.3,
                        help='gamma of learning rate scheduler decay')
    parser.add_argument('--weight-decay', type=float, default=0.00001,
                        help='weight decay')
    parser.add_argument('--seed', type=int, default=42, metavar='S',
                        help='random seed (default: 42)')
    parser.add_argument('--gamma', type=float, default=0.0,
                        help='weighting of the different loss components')
    parser.add_argument('--finetune', type=int, default=0,
                        help='If 1 (true), the optimiser state from the checkpoint '
                             'is reset (required for BCE finetuning); if 0 (false), '
                             'training resumes from where it left off')
    parser.add_argument('--onset-thresh', type=float, default=0.3,
                        help='threshold above which an onset is said to occur')
    parser.add_argument('--binarise', type=int, default=0,
                        help='If 1 (true), the target novelty function is made '
                             'binary; if 0 (false), it is left as is')
    parser.add_argument('--onset-trainable', type=int, default=0,
                        help='If 1 (true), the onset CNN is also trained during the '
                             'finetuning stage; if 0 (false), it is kept fixed')

    # Model Parameters
    parser.add_argument('--seq-dur', type=float, default=6.0,
                        help='Sequence duration in seconds; '
                             'a value of <=0.0 will use full/variable length')
    parser.add_argument('--unidirectional', action='store_true', default=False,
                        help='Use unidirectional LSTM instead of bidirectional')
    parser.add_argument('--nfft', type=int, default=4096,
                        help='STFT fft size and window size')
    parser.add_argument('--nhop', type=int, default=1024,
                        help='STFT hop size')
    parser.add_argument('--n-mels', type=int, default=80,
                        help='Number of bins in mel spectrogram')
    parser.add_argument('--hidden-size', type=int, default=512,
                        help='hidden size parameter of dense bottleneck layers')
    parser.add_argument('--bandwidth', type=int, default=16000,
                        help='maximum model bandwidth in hertz')
    parser.add_argument('--nb-channels', type=int, default=2,
                        help='set number of channels for model (1, 2)')
    parser.add_argument('--nb-workers', type=int, default=4,
                        help='Number of workers for dataloader.')

    # Misc Parameters
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='less verbose during training')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')

    args, _ = parser.parse_known_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print("Using GPU:", use_cuda)
    print("Using Torchaudio: ", utils._torchaudio_available())
    dataloader_kwargs = {
        'num_workers': args.nb_workers,
        'pin_memory': True
    } if use_cuda else {}

    # record the current git commit with the saved parameters
    repo_dir = os.path.abspath(os.path.dirname(__file__))
    repo = Repo(repo_dir)
    commit = repo.head.commit.hexsha[:7]

    torch.manual_seed(args.seed)
    random.seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")
    torch.autograd.set_detect_anomaly(True)

    train_dataset, valid_dataset, args = data.load_datasets(parser, args)
    print("TRAIN DATASET", train_dataset)
    print("VALID DATASET", valid_dataset)

    # create output dir if it does not exist
    target_path = Path(args.output)
    target_path.mkdir(parents=True, exist_ok=True)

    train_sampler = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                **dataloader_kwargs)
    valid_sampler = torch.utils.data.DataLoader(valid_dataset,
                                                batch_size=1,
                                                **dataloader_kwargs)

    if args.model:
        scaler_mean = None
        scaler_std = None
    else:
        scaler_mean, scaler_std = get_statistics(args, train_dataset)

    max_bin = utils.bandwidth_to_max_bin(train_dataset.sample_rate,
                                         args.nfft, args.bandwidth)

    unmix = model_mtl.OpenUnmix_mtl(
        input_mean=scaler_mean,
        input_scale=scaler_std,
        nb_channels=args.nb_channels,
        hidden_size=args.hidden_size,
        n_fft=args.nfft,
        n_hop=args.nhop,
        max_bin=max_bin,
        sample_rate=train_dataset.sample_rate).to(device)

    # Read the trained onset detection network (the model through which the
    # target spectrogram is passed)
    detect_onset = model.onsetCNN().to(device)
    detect_onset.load_state_dict(
        torch.load(args.onset_model, map_location=device))

    # keep the onset detector's weights fixed
    for child in detect_onset.children():
        for param in child.parameters():
            param.requires_grad = False

    optimizer = torch.optim.Adam(unmix.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_gamma,
        patience=args.lr_decay_patience,
        cooldown=10)

    es = utils.EarlyStopping(patience=args.patience)

    # if a model is specified: resume training
    if args.model:
        model_path = Path(args.model).expanduser()
        with open(Path(model_path, args.target + '.json'), 'r') as stream:
            results = json.load(stream)

        target_model_path = Path(model_path, args.target + ".chkpnt")
        checkpoint = torch.load(target_model_path, map_location=device)
        unmix.load_state_dict(checkpoint['state_dict'])

        if args.finetune == 0:
            # pick up where training left off: restore optimiser and
            # scheduler state and train for another args.epochs epochs
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            t = tqdm.trange(results['epochs_trained'],
                            results['epochs_trained'] + args.epochs + 1,
                            disable=args.quiet)
            print("PICKUP WHERE LEFT OFF", args.finetune)
            train_losses = results['train_loss_history']
            train_mse_losses = results['train_mse_loss_history']
            train_bce_losses = results['train_bce_loss_history']
            valid_losses = results['valid_loss_history']
            valid_mse_losses = results['valid_mse_loss_history']
            valid_bce_losses = results['valid_bce_loss_history']
            train_times = results['train_time_history']
            best_epoch = results['best_epoch']
            es.best = results['best_loss']
            es.num_bad_epochs = results['num_bad_epochs']
        else:
            # finetuning: keep the weights but reset the optimiser state
            t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
            print("NOT PICKUP WHERE LEFT OFF", args.finetune)
            train_losses = []
            train_mse_losses = []
            train_bce_losses = []
            valid_losses = []
            valid_mse_losses = []
            valid_bce_losses = []
            train_times = []
            best_epoch = 0
    # else start from 0
    else:
        t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
        train_losses = []
        train_mse_losses = []
        train_bce_losses = []
        valid_losses = []
        valid_mse_losses = []
        valid_bce_losses = []
        train_times = []
        best_epoch = 0

    for epoch in t:
        t.set_description("Training Epoch")
        end = time.time()
        train_loss, train_mse_loss, train_bce_loss = train(
            args, unmix, device, train_sampler, optimizer,
            detect_onset=detect_onset)
        valid_loss, valid_mse_loss, valid_bce_loss = valid(
            args, unmix, device, valid_sampler, detect_onset=detect_onset)
        scheduler.step(valid_loss)

        train_losses.append(train_loss)
        train_mse_losses.append(train_mse_loss)
        train_bce_losses.append(train_bce_loss)
        valid_losses.append(valid_loss)
        valid_mse_losses.append(valid_mse_loss)
        valid_bce_losses.append(valid_bce_loss)

        t.set_postfix(train_loss=train_loss, val_loss=valid_loss)

        stop = es.step(valid_loss)
        if valid_loss == es.best:
            best_epoch = epoch

        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': unmix.state_dict(),
                'best_loss': es.best,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'onset_state_dict': detect_onset.state_dict()
            },
            is_best=valid_loss == es.best,
            path=target_path,
            target=args.target)

        # save params
        params = {
            'epochs_trained': epoch,
            'args': vars(args),
            'best_loss': es.best,
            'best_epoch': best_epoch,
            'train_loss_history': train_losses,
            'train_mse_loss_history': train_mse_losses,
            'train_bce_loss_history': train_bce_losses,
            'valid_loss_history': valid_losses,
            'valid_mse_loss_history': valid_mse_losses,
            'valid_bce_loss_history': valid_bce_losses,
            'train_time_history': train_times,
            'num_bad_epochs': es.num_bad_epochs,
            'commit': commit
        }

        with open(Path(target_path, args.target + '.json'), 'w') as outfile:
            outfile.write(json.dumps(params, indent=4, sort_keys=True))

        train_times.append(time.time() - end)

        if stop:
            print("Apply Early Stopping")
            break

    print("TRAINING DONE!!")

    # save the loss curves for the total, BCE and MSE loss components
    loss_curves = [
        (train_losses, "Training loss", "Training", "train_plot.pdf"),
        (valid_losses, "Validation loss", "Validation", "val_plot.pdf"),
        (train_bce_losses, "Training BCE loss", "Training", "train_bce_plot.pdf"),
        (valid_bce_losses, "Validation BCE loss", "Validation", "val_bce_plot.pdf"),
        (train_mse_losses, "Training MSE loss", "Training", "train_mse_plot.pdf"),
        (valid_mse_losses, "Validation MSE loss", "Validation", "val_mse_plot.pdf"),
    ]
    for losses, title, label, fname in loss_curves:
        plt.figure()
        plt.title(title)
        plt.plot(losses, label=label)
        plt.xlabel("Iterations")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(Path(target_path, fname))
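# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the trainer): train() and valid() are
# defined elsewhere in this repo and return (total, MSE, BCE) loss triples.
# One plausible way the --gamma flag could blend the two objectives is shown
# below; the helper name and formula are assumptions, not the repo's code.
def combined_loss(mse_loss, bce_loss, gamma):
    """Blend the separation (MSE) and onset (BCE) objectives.

    With the default gamma of 0.0 this reduces to the plain spectrogram MSE.
    """
    return (1.0 - gamma) * mse_loss + gamma * bce_loss
# ---------------------------------------------------------------------------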
def main():
    parser = argparse.ArgumentParser(description='Open Unmix Trainer')

    # which target do we want to train?
    parser.add_argument('--target', type=str, default='vocals',
                        help='target source (will be passed to the dataset)')

    # Dataset Parameters
    parser.add_argument('--dataset', type=str, default="aligned",
                        choices=['musdb', 'aligned', 'sourcefolder',
                                 'trackfolder_var', 'trackfolder_fix'],
                        help='Name of the dataset.')
    parser.add_argument('--root', type=str, help='root path of dataset',
                        default='../rec_data_new/')
    parser.add_argument('--output', type=str,
                        default="../out_unmix/model_new_data_aug_tl",
                        help='provide output path base folder name')
    parser.add_argument('--model', type=str, help='Path to checkpoint folder',
                        default='umxhq')

    # Training Parameters
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=0.0001,
                        help='learning rate, defaults to 1e-4')
    parser.add_argument('--patience', type=int, default=140,
                        help='maximum number of epochs to train (default: 140)')
    parser.add_argument('--lr-decay-patience', type=int, default=80,
                        help='lr decay patience for plateau scheduler')
    parser.add_argument('--lr-decay-gamma', type=float, default=0.3,
                        help='gamma of learning rate scheduler decay')
    parser.add_argument('--weight-decay', type=float, default=1e-10,
                        help='weight decay')
    parser.add_argument('--seed', type=int, default=42, metavar='S',
                        help='random seed (default: 42)')

    # Model Parameters
    parser.add_argument('--seq-dur', type=float, default=6.0,
                        help='Sequence duration in seconds; '
                             'a value of <=0.0 will use full/variable length')
    parser.add_argument('--unidirectional', action='store_true', default=False,
                        help='Use unidirectional LSTM instead of bidirectional')
    parser.add_argument('--nfft', type=int, default=4096,
                        help='STFT fft size and window size')
    parser.add_argument('--nhop', type=int, default=1024,
                        help='STFT hop size')
    parser.add_argument('--hidden-size', type=int, default=512,
                        help='hidden size parameter of dense bottleneck layers')
    parser.add_argument('--bandwidth', type=int, default=16000,
                        help='maximum model bandwidth in hertz')
    parser.add_argument('--nb-channels', type=int, default=2,
                        help='set number of channels for model (1, 2)')
    parser.add_argument('--nb-workers', type=int, default=4,
                        help='Number of workers for dataloader.')

    # Misc Parameters
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='less verbose during training')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')

    args, _ = parser.parse_known_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print("Using GPU:", use_cuda)
    print("Using Torchaudio: ", utils._torchaudio_available())
    dataloader_kwargs = {
        'num_workers': args.nb_workers,
        'pin_memory': True
    } if use_cuda else {}

    # record the current git commit with the saved parameters
    repo_dir = os.path.abspath(os.path.dirname(__file__))
    repo = Repo(repo_dir)
    commit = repo.head.commit.hexsha[:7]

    torch.manual_seed(args.seed)
    random.seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_dataset, valid_dataset, args = data.load_datasets(parser, args)
    print("TRAIN DATASET", train_dataset)
    print("VALID DATASET", valid_dataset)

    # create output dir if it does not exist
    target_path = Path(args.output)
    target_path.mkdir(parents=True, exist_ok=True)

    train_sampler = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                **dataloader_kwargs)
    valid_sampler = torch.utils.data.DataLoader(valid_dataset,
                                                batch_size=1,
                                                **dataloader_kwargs)

    scaler_mean, scaler_std = get_statistics(args, train_dataset)

    max_bin = utils.bandwidth_to_max_bin(train_dataset.sample_rate,
                                         args.nfft, args.bandwidth)

    unmix = model.OpenUnmix(input_mean=scaler_mean,
                            input_scale=scaler_std,
                            nb_channels=args.nb_channels,
                            hidden_size=args.hidden_size,
                            n_fft=args.nfft,
                            n_hop=args.nhop,
                            max_bin=max_bin,
                            sample_rate=train_dataset.sample_rate).to(device)

    # if a model is specified: warm-start from the pretrained umxhq weights
    if args.model:
        # suppress torch.hub's download progress bar on stderr
        err = io.StringIO()
        with redirect_stderr(err):
            unmix = torch.hub.load('sigsep/open-unmix-pytorch',
                                   'umxhq',
                                   target=args.target,
                                   device=device,
                                   pretrained=True)

    # the optimiser is created after the warm-start so that it tracks the
    # parameters of the loaded model
    optimizer = torch.optim.Adam(unmix.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_gamma,
        patience=args.lr_decay_patience,
        cooldown=10)

    es = utils.EarlyStopping(patience=args.patience)

    t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
    train_losses = []
    valid_losses = []
    train_times = []
    best_epoch = 0

    from matplotlib import pyplot as plt

    for epoch in t:
        t.set_description("Training Epoch")
        end = time.time()
        train_loss = train(args, unmix, device, train_sampler, optimizer)
        valid_loss = valid(args, unmix, device, valid_sampler)
        scheduler.step(valid_loss)
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        t.set_postfix(train_loss=train_loss, val_loss=valid_loss)

        stop = es.step(valid_loss)

        # display the loss curves after every epoch
        plt.figure(figsize=(16, 12))
        plt.subplot(2, 2, 1)
        plt.title("Training loss")
        plt.plot(train_losses, label="Training")
        plt.xlabel("Iterations")
        plt.ylabel("Loss")
        plt.legend()
        plt.show()

        plt.figure(figsize=(16, 12))
        plt.subplot(2, 2, 2)
        plt.title("Validation loss")
        plt.plot(valid_losses, label="Validation")
        plt.xlabel("Iterations")
        plt.ylabel("Loss")
        plt.legend()
        plt.show()

        if valid_loss == es.best:
            best_epoch = epoch

        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': unmix.state_dict(),
                'best_loss': es.best,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            },
            is_best=valid_loss == es.best,
            path=target_path,
            target=args.target)

        # save params
        params = {
            'epochs_trained': epoch,
            'args': vars(args),
            'best_loss': es.best,
            'best_epoch': best_epoch,
            'train_loss_history': train_losses,
            'valid_loss_history': valid_losses,
            'train_time_history': train_times,
            'num_bad_epochs': es.num_bad_epochs,
            'commit': commit
        }

        with open(Path(target_path, args.target + '.json'), 'w') as outfile:
            outfile.write(json.dumps(params, indent=4, sort_keys=True))

        train_times.append(time.time() - end)

        if stop:
            print("Apply Early Stopping")
            break
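# ---------------------------------------------------------------------------
# For reference, a minimal sketch consistent with how utils.EarlyStopping is
# used above: constructed with patience=, exposing .best and .num_bad_epochs,
# and .step() returning True once patience is exhausted. The real class lives
# in this repo's utils.py and may differ in detail.
class EarlyStoppingSketch:
    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience      # epochs to tolerate without improvement
        self.min_delta = min_delta    # minimum decrease that counts as better
        self.best = None              # best validation loss seen so far
        self.num_bad_epochs = 0       # epochs since the last improvement

    def step(self, metric):
        """Record one validation loss; return True if training should stop."""
        if self.best is None or metric < self.best - self.min_delta:
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs > self.patience
# ---------------------------------------------------------------------------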
def main():
    parser = argparse.ArgumentParser(description='Open Unmix Trainer')

    # which target do we want to train?
    parser.add_argument('--target', type=str, default='vocals',
                        help='target source (will be passed to the dataset)')

    # experiment tag which determines the output folder in trained_models,
    # the tensorboard run name, etc.
    parser.add_argument('--tag', type=str)

    # allow passing a comment about the experiment
    parser.add_argument('--comment', type=str,
                        help='comment about the experiment')

    # parse the tag early so it can be used in the --output default below
    args, _ = parser.parse_known_args()

    # Dataset Parameters
    parser.add_argument('--dataset', type=str, default="musdb",
                        choices=['musdb_lyrics', 'timit_music', 'blended',
                                 'nus', 'nus_train'],
                        help='Name of the dataset.')
    parser.add_argument('--root', type=str, help='root path of dataset')
    parser.add_argument('--output', type=str,
                        default="trained_models/{}/".format(args.tag),
                        help='provide output path base folder name')
    parser.add_argument('--wst-model', type=str,
                        help='Path to checkpoint folder for warm start')

    # Training Parameters
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate, defaults to 1e-3')
    parser.add_argument('--patience', type=int, default=140,
                        help='maximum number of epochs to train (default: 140)')
    parser.add_argument('--lr-decay-patience', type=int, default=80,
                        help='lr decay patience for plateau scheduler')
    parser.add_argument('--lr-decay-gamma', type=float, default=0.3,
                        help='gamma of learning rate scheduler decay')
    parser.add_argument('--weight-decay', type=float, default=0.00001,
                        help='weight decay')
    parser.add_argument('--seed', type=int, default=0, metavar='S',
                        help='random seed (default: 0)')
    parser.add_argument('--alignment-from', type=str, default=None)
    parser.add_argument('--fake-alignment', action='store_true', default=False)

    # Model Parameters
    parser.add_argument('--unidirectional', action='store_true', default=False,
                        help='Use unidirectional LSTM instead of bidirectional')
    parser.add_argument('--nfft', type=int, default=4096,
                        help='STFT fft size and window size')
    parser.add_argument('--nhop', type=int, default=1024,
                        help='STFT hop size')
    parser.add_argument('--hidden-size', type=int, default=512,
                        help='hidden size parameter of dense bottleneck layers')
    parser.add_argument('--bandwidth', type=int, default=16000,
                        help='maximum model bandwidth in hertz')
    parser.add_argument('--nb-channels', type=int, default=2,
                        help='set number of channels for model (1, 2)')
    parser.add_argument('--nb-workers', type=int, default=0,
                        help='Number of workers for dataloader.')
    parser.add_argument('--nb-audio-encoder-layers', type=int, default=2)
    parser.add_argument('--nb-layers', type=int, default=3)

    # name of the model class in model.py that should be used
    parser.add_argument('--architecture', type=str)

    # select attention type if applicable for the selected model
    parser.add_argument('--attention', type=str)

    # Misc Parameters
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='less verbose during training')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')

    args, _ = parser.parse_known_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print("Using GPU:", use_cuda)
    print("Using Torchaudio: ", utils._torchaudio_available())
    dataloader_kwargs = {
        'num_workers': args.nb_workers,
        'pin_memory': True
    } if use_cuda else {}

    writer = SummaryWriter(logdir=os.path.join('tensorboard', args.tag))

    # make the run reproducible
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device("cuda" if use_cuda else "cpu")

    train_dataset, valid_dataset, args = data.load_datasets(parser, args)

    # create output dir if it does not exist
    target_path = Path(args.output)
    target_path.mkdir(parents=True, exist_ok=True)

    train_sampler = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                collate_fn=data.collate_fn,
                                                drop_last=True,
                                                **dataloader_kwargs)
    valid_sampler = torch.utils.data.DataLoader(valid_dataset,
                                                batch_size=1,
                                                collate_fn=data.collate_fn,
                                                **dataloader_kwargs)

    if args.wst_model:
        scaler_mean = None
        scaler_std = None
    else:
        scaler_mean, scaler_std = get_statistics(args, train_dataset)

    max_bin = utils.bandwidth_to_max_bin(valid_dataset.sample_rate,
                                         args.nfft, args.bandwidth)

    train_args_dict = vars(args)
    train_args_dict['max_bin'] = int(max_bin)  # added to config
    train_args_dict['vocabulary_size'] = valid_dataset.vocabulary_size  # added to config

    # copy of args as a dictionary, with no influence on args itself
    train_params_dict = copy.deepcopy(vars(args))

    # added to the parameters for model loading but not to the config file
    train_params_dict['scaler_mean'] = scaler_mean
    train_params_dict['scaler_std'] = scaler_std

    model_class = model_utls.ModelLoader.get_model(args.architecture)
    model_to_train = model_class.from_config(train_params_dict)
    model_to_train.to(device)

    optimizer = torch.optim.Adam(model_to_train.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_gamma,
        patience=args.lr_decay_patience,
        cooldown=10)

    es = utils.EarlyStopping(patience=args.patience)

    # if a model is specified: resume training
    if args.wst_model:
        model_path = Path(os.path.join('trained_models', args.wst_model)).expanduser()
        with open(Path(model_path, args.target + '.json'), 'r') as stream:
            results = json.load(stream)

        target_model_path = Path(model_path, args.target + ".chkpnt")
        checkpoint = torch.load(target_model_path, map_location=device)
        model_to_train.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])

        # train for another args.epochs epochs
        t = tqdm.trange(results['epochs_trained'],
                        results['epochs_trained'] + args.epochs + 1,
                        disable=args.quiet)
        train_losses = results['train_loss_history']
        valid_losses = results['valid_loss_history']
        train_times = results['train_time_history']
        best_epoch = 0
    # else start from 0
    else:
        t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
        train_losses = []
        valid_losses = []
        train_times = []
        best_epoch = 0

    for epoch in t:
        t.set_description("Training Epoch")
        end = time.time()
        train_loss = train(args, model_to_train, device, train_sampler, optimizer)
        valid_loss = valid(args, model_to_train, device, valid_sampler)

        writer.add_scalar("Training_cost", train_loss, epoch)
        writer.add_scalar("Validation_cost", valid_loss, epoch)

        scheduler.step(valid_loss)
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        t.set_postfix(train_loss=train_loss, val_loss=valid_loss)

        stop = es.step(valid_loss)

        if valid_loss == es.best:
            best_epoch = epoch

        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model_to_train.state_dict(),
                'best_loss': es.best,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            },
            is_best=valid_loss == es.best,
            path=target_path,
            target=args.target)

        # save params
        params = {
            'epochs_trained': epoch,
            'args': vars(args),
            'best_loss': es.best,
            'best_epoch': best_epoch,
            'train_loss_history': train_losses,
            'valid_loss_history': valid_losses,
            'train_time_history': train_times,
            'num_bad_epochs': es.num_bad_epochs
        }

        with open(Path(target_path, args.target + '.json'), 'w') as outfile:
            outfile.write(json.dumps(params, indent=4, sort_keys=True))

        train_times.append(time.time() - end)

        if stop:
            print("Apply Early Stopping")
            break
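# ---------------------------------------------------------------------------
# Assuming each of the trainers above lives in its own script (the original
# listing does not show an entry point), the usual guard would be:
if __name__ == '__main__':
    main()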