def main(rank, args): # Distributed setup if args.distributed: setup_distributed(rank, args.world_size) not_main_rank = args.distributed and rank != 0 logging.info("Start time: %s", datetime.now()) # Explicitly set seed to make sure models created in separate processes # start from same random weights and biases torch.manual_seed(args.seed) # Empty CUDA cache torch.cuda.empty_cache() # Change backend for flac files torchaudio.set_audio_backend("soundfile") # Transforms melkwargs = { "n_fft": args.win_length, "n_mels": args.n_bins, "hop_length": args.hop_length, } sample_rate_original = 16000 if args.type == "mfcc": transforms = torch.nn.Sequential( torchaudio.transforms.MFCC( sample_rate=sample_rate_original, n_mfcc=args.n_bins, melkwargs=melkwargs, ), ) num_features = args.n_bins elif args.type == "waveform": transforms = torch.nn.Sequential(UnsqueezeFirst()) num_features = 1 else: raise ValueError("Model type not supported") if args.normalize: transforms = torch.nn.Sequential(transforms, Normalize()) augmentations = torch.nn.Sequential() if args.freq_mask: augmentations = torch.nn.Sequential( augmentations, torchaudio.transforms.FrequencyMasking( freq_mask_param=args.freq_mask), ) if args.time_mask: augmentations = torch.nn.Sequential( augmentations, torchaudio.transforms.TimeMasking(time_mask_param=args.time_mask), ) # Text preprocessing char_blank = "*" char_space = " " char_apostrophe = "'" labels = char_blank + char_space + char_apostrophe + string.ascii_lowercase language_model = LanguageModel(labels, char_blank, char_space) # Dataset training, validation = split_process_librispeech( [args.dataset_train, args.dataset_valid], [transforms, transforms], language_model, root=args.dataset_root, folder_in_archive=args.dataset_folder_in_archive, ) # Decoder if args.decoder == "greedy": decoder = GreedyDecoder() else: raise ValueError("Selected decoder not supported") # Model model = Wav2Letter( num_classes=language_model.length, input_type=args.type, num_features=num_features, ) if args.jit: model = torch.jit.script(model) if args.distributed: n = torch.cuda.device_count() // args.world_size devices = list(range(rank * n, (rank + 1) * n)) model = model.to(devices[0]) model = torch.nn.parallel.DistributedDataParallel(model, device_ids=devices) else: devices = ["cuda" if torch.cuda.is_available() else "cpu"] model = model.to(devices[0], non_blocking=True) model = torch.nn.DataParallel(model) n = count_parameters(model) logging.info("Number of parameters: %s", n) # Optimizer if args.optimizer == "adadelta": optimizer = Adadelta( model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, eps=args.eps, rho=args.rho, ) elif args.optimizer == "sgd": optimizer = SGD( model.parameters(), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay, ) elif args.optimizer == "adam": optimizer = Adam( model.parameters(), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay, ) elif args.optimizer == "adamw": optimizer = AdamW( model.parameters(), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay, ) else: raise ValueError("Selected optimizer not supported") if args.scheduler == "exponential": scheduler = ExponentialLR(optimizer, gamma=args.gamma) elif args.scheduler == "reduceonplateau": scheduler = ReduceLROnPlateau(optimizer, patience=10, threshold=1e-3) else: raise ValueError("Selected scheduler not supported") criterion = torch.nn.CTCLoss(blank=language_model.mapping[char_blank], zero_infinity=False) # Data Loader collate_fn_train = collate_factory(model_length_function, augmentations) collate_fn_valid = collate_factory(model_length_function) loader_training_params = { "num_workers": args.workers, "pin_memory": True, "shuffle": True, "drop_last": True, } loader_validation_params = loader_training_params.copy() loader_validation_params["shuffle"] = False loader_training = DataLoader( training, batch_size=args.batch_size, collate_fn=collate_fn_train, **loader_training_params, ) loader_validation = DataLoader( validation, batch_size=args.batch_size, collate_fn=collate_fn_valid, **loader_validation_params, ) # Setup checkpoint best_loss = 1.0 load_checkpoint = args.checkpoint and os.path.isfile(args.checkpoint) if args.distributed: torch.distributed.barrier() if load_checkpoint: logging.info("Checkpoint: loading %s", args.checkpoint) checkpoint = torch.load(args.checkpoint) args.start_epoch = checkpoint["epoch"] best_loss = checkpoint["best_loss"] model.load_state_dict(checkpoint["state_dict"]) optimizer.load_state_dict(checkpoint["optimizer"]) scheduler.load_state_dict(checkpoint["scheduler"]) logging.info("Checkpoint: loaded '%s' at epoch %s", args.checkpoint, checkpoint["epoch"]) else: logging.info("Checkpoint: not found") save_checkpoint( { "epoch": args.start_epoch, "state_dict": model.state_dict(), "best_loss": best_loss, "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict(), }, False, args.checkpoint, not_main_rank, ) if args.distributed: torch.distributed.barrier() torch.autograd.set_detect_anomaly(False) for epoch in range(args.start_epoch, args.epochs): logging.info("Epoch: %s", epoch) train_one_epoch( model, criterion, optimizer, scheduler, loader_training, decoder, language_model, devices[0], epoch, args.clip_grad, not_main_rank, not args.reduce_lr_valid, ) loss = evaluate( model, criterion, loader_validation, decoder, language_model, devices[0], epoch, not_main_rank, ) if args.reduce_lr_valid and isinstance(scheduler, ReduceLROnPlateau): scheduler.step(loss) is_best = loss < best_loss best_loss = min(loss, best_loss) save_checkpoint( { "epoch": epoch + 1, "state_dict": model.state_dict(), "best_loss": best_loss, "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict(), }, is_best, args.checkpoint, not_main_rank, ) logging.info("End time: %s", datetime.now()) if args.distributed: torch.distributed.destroy_process_group()
class MusCapsTrainer(BaseTrainer): def __init__(self, config, logger): super(BaseTrainer, self).__init__() self.config = config self.logger = logger self.device = torch.device(self.config.training.device) self.patience = self.config.training.patience self.lr = self.config.training.lr self.load_dataset() self.build_model() self.build_loss() self.build_optimizer() def load_dataset(self): self.logger.write("Loading dataset") dataset_name = self.config.dataset_config.dataset_name if dataset_name == "audiocaption": train_dataset = AudioCaptionDataset(self.config.dataset_config) val_dataset = AudioCaptionDataset(self.config.dataset_config, "val") else: raise ValueError( "{} dataset is not supported.".format(dataset_name)) self.vocab = train_dataset.vocab self.logger.save_vocab(self.vocab.token_freq) OmegaConf.update(self.config, "model_config.vocab_size", self.vocab.size) self.train_loader = DataLoader( dataset=train_dataset, batch_size=self.config.training.batch_size, shuffle=self.config.training.shuffle, num_workers=self.config.training.num_workers, pin_memory=self.config.training.pin_memory, collate_fn=custom_collate_fn, drop_last=True) self.val_loader = DataLoader( dataset=val_dataset, batch_size=self.config.training.batch_size, shuffle=self.config.training.shuffle, num_workers=self.config.training.num_workers, pin_memory=self.config.training.pin_memory, collate_fn=custom_collate_fn) self.logger.write("Number of training samples: {}".format( train_dataset.__len__())) def build_model(self): self.logger.write("Building model") model_name = self.config.model_config.model_name if model_name == "cnn_lstm_caption": self.model = CNNLSTMCaption(self.config.model_config, self.vocab, self.device, teacher_forcing=True) elif model_name == "cnn_attention_lstm": self.model = AttentionModel(self.config.model_config, self.vocab, self.device, teacher_forcing=True) else: raise ValueError("{} model is not supported.".format(model_name)) if self.model.audio_encoder.pretrained_version is not None and not self.model.finetune: for param in self.model.audio_encoder.feature_extractor.parameters( ): param.requires_grad = False self.model.to(self.device) def count_parameters(self): """ Count trainable parameters in model. """ return sum(p.numel() for p in self.model.parameters() if p.requires_grad) def build_loss(self): self.logger.write("Building loss") loss_name = self.config.model_config.loss if loss_name == "cross_entropy": self.loss = nn.CrossEntropyLoss(ignore_index=self.vocab.PAD_INDEX) else: raise ValueError("{} loss is not supported.".format(loss_name)) self.loss = self.loss.to(self.device) def build_optimizer(self): self.logger.write("Building optimizer") optimizer_name = self.config.training.optimizer if optimizer_name == "adam": self.optimizer = Adam(self.model.parameters(), lr=self.lr) elif optimizer_name == "adadelta": self.optimizer = Adadelta(self.model.parameters(), lr=self.lr) else: raise ValueError( "{} optimizer is not supported.".format(optimizer_name)) def train(self): if os.path.exists(self.logger.checkpoint_path): self.logger.write("Resumed training experiment with id {}".format( self.logger.experiment_id)) self.load_ckp(self.logger.checkpoint_path) else: self.logger.write("Started training experiment with id {}".format( self.logger.experiment_id)) self.start_epoch = 0 # Adaptive learning rate scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.5, patience=self.patience, verbose=True) k_patience = 0 best_val = np.Inf for epoch in range(self.start_epoch, self.config.training.epochs): epoch_start_time = time.time() train_loss = self.train_epoch(self.train_loader, self.device, is_training=True) val_loss = self.train_epoch_val(self.val_loader, self.device, is_training=False) # Decrease the learning rate after not improving in the validation set scheduler.step(val_loss) # check if val loss has been improving during patience period. If not, stop is_val_improving = scheduler.is_better(val_loss, best_val) if not is_val_improving: k_patience += 1 else: k_patience = 0 if k_patience > self.patience * 2: print("Early Stopping") break best_val = scheduler.best epoch_time = time.time() - epoch_start_time lr = self.optimizer.param_groups[0]['lr'] self.logger.update_training_log(epoch + 1, train_loss, val_loss, epoch_time, lr) checkpoint = { 'epoch': epoch + 1, 'state_dict': self.model.state_dict(), 'optimizer': self.optimizer.state_dict() } # save checkpoint in appropriate path (new or best) self.logger.save_checkpoint(state=checkpoint, is_best=is_val_improving) def load_ckp(self, checkpoint_path): checkpoint = torch.load(checkpoint_path) self.model.load_state_dict(checkpoint['state_dict']) self.optimizer.load_state_dict(checkpoint['optimizer']) self.start_epoch = checkpoint['epoch'] def train_epoch(self, data_loader, device, is_training): out_list = [] target_list = [] running_loss = 0.0 n_batches = 0 if is_training: self.model.train() if self.model.audio_encoder.pretrained_version is not None: for module in self.model.audio_encoder.feature_extractor.modules( ): if isinstance(module, nn.BatchNorm2d) or isinstance( module, nn.BatchNorm1d): module.eval() else: self.model.eval() for i, batch in enumerate(data_loader): audio, audio_len, x, x_len = batch target_list.append(x) audio = audio.float().to(device=device) x = x.long().to(device=device) audio_len.to(device=device) out = self.model(audio, audio_len, x, x_len) out_list.append(out) target = x[:, 1:] # target excluding sos token out = out.transpose(1, 2) loss = self.loss(out, target) if is_training: self.optimizer.zero_grad() loss.backward() if self.config.training.clip_gradients: clip_grad_norm_(self.model.parameters(), 12) self.optimizer.step() running_loss += loss.item() n_batches += 1 return running_loss / n_batches def train_epoch_val(self, data_loader, device, is_training=False): with torch.no_grad(): loss = self.train_epoch(data_loader, device, is_training=False) return loss