class AutoencoderTrainer:
    """First-stage trainer: a shared encoder / WaveNet decoder autoencoder plus
    a latent-space discriminator trained adversarially.

    One ``Dataset`` is built per domain path in ``args.data``. In distributed
    mode (``args.world_size > 1``) each rank trains only on the dataset
    matching its rank; otherwise batches cycle round-robin over all datasets.
    """

    def __init__(self, args):
        self.args = args
        self.data = [Dataset(args, domain_path) for domain_path in args.data]

        # Experiment directory + logger.
        self.expPath = args.checkpoint / 'Autoencoder' / args.exp_name
        if not self.expPath.exists():
            self.expPath.mkdir(parents=True, exist_ok=True)
        self.logger = train_logger(self.args, self.expPath)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Seed both CPU and CUDA RNGs for reproducibility.
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        # Modules.
        self.encoder = Encoder(args).to(self.device)
        self.decoder = WaveNet(args).to(self.device)
        self.discriminator = ZDiscriminator(args).to(self.device)

        # Distributed wrappers (one model replica per rank/GPU).
        if args.world_size > 1:
            dev = torch.cuda.current_device()
            ddp = torch.nn.parallel.DistributedDataParallel
            self.encoder = ddp(self.encoder, device_ids=[dev], output_device=dev)
            self.discriminator = ddp(self.discriminator, device_ids=[dev],
                                     output_device=dev)
            self.decoder = ddp(self.decoder, device_ids=[dev], output_device=dev)

        # Loss trackers (one reconstruction tracker per dataset/domain).
        self.reconstruction_loss = [
            LossManager(f'train reconstruction {i}')
            for i in range(len(self.data))
        ]
        self.discriminator_loss = LossManager('train discriminator')
        self.total_loss = LossManager('train total')
        self.reconstruction_val = [
            LossManager(f'validation reconstruction {i}')
            for i in range(len(self.data))
        ]
        self.discriminator_val = LossManager('validation discriminator')
        self.total_val = LossManager('validation total')

        # Optimizers: one for encoder+decoder jointly, one for the discriminator.
        self.autoenc_optimizer = optim.Adam(chain(self.encoder.parameters(),
                                                  self.decoder.parameters()),
                                            lr=args.lr)
        self.discriminator_optimizer = optim.Adam(
            self.discriminator.parameters(), lr=args.lr)

        # Resume training from the last checkpoint if requested.
        if args.resume:
            checkpoint_args_file = self.expPath / 'args.pth'
            # NOTE(review): nothing in this class currently writes args.pth
            # (the save in train() is commented out) -- confirm it is written
            # elsewhere, otherwise resume cannot find the last epoch.
            checkpoint_args = torch.load(checkpoint_args_file)
            last_epoch = checkpoint_args[-1]
            self.start_epoch = last_epoch + 1
            checkpoint_state_file = self.expPath / f'lastmodel_{last_epoch}.pth'
            # BUG FIX: was torch.load(args.checkpoint_state_file); the path is
            # the local variable computed above, not a field on args.
            states = torch.load(checkpoint_state_file)
            self.encoder.load_state_dict(states['encoder_state'])
            self.decoder.load_state_dict(states['decoder_state'])
            self.discriminator.load_state_dict(states['discriminator_state'])
            if args.load_optimizer:
                self.autoenc_optimizer.load_state_dict(
                    states['autoenc_optimizer_state'])
                self.discriminator_optimizer.load_state_dict(
                    states['discriminator_optimizer_state'])
            self.logger.info('Loaded checkpoint parameters')
        else:
            self.start_epoch = 0

        # Learning-rate schedule: exponential decay per epoch on the
        # autoencoder optimizer.
        # BUG FIX: was self.model_optimizer, which is never defined anywhere
        # in this class (would raise AttributeError on construction).
        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.autoenc_optimizer, args.lr_decay)
        self.lr_manager.last_epoch = self.start_epoch
        self.lr_manager.step()

    def _dataset_index(self, batch_num):
        """Index of the dataset to draw the next batch from.

        Distributed runs pin each rank to its own dataset; single-process runs
        cycle round-robin over all ``args.n_datasets`` datasets.
        """
        if self.args.world_size > 1:
            return self.args.rank
        return batch_num % self.args.n_datasets

    def _domain_target(self, x, dataset_no):
        """Per-sample domain labels for the discriminator's cross-entropy.

        BUG FIX: the original built this with ``.cuda()``, which crashes on
        CPU-only runs despite the ``self.device`` fallback in ``__init__``;
        build directly on the input's device instead.
        """
        return torch.full((x.size(0),), dataset_no,
                          dtype=torch.long, device=x.device)

    @staticmethod
    def _unwrapped(module):
        """Return the raw module, unwrapping DistributedDataParallel if present."""
        return module.module if hasattr(module, 'module') else module

    def train_epoch(self, epoch):
        """Run one adversarial training epoch.

        Alternates per batch: (1) train the discriminator to classify the
        latent code's domain, (2) train the autoencoder to reconstruct the
        input while *fooling* the discriminator (negated cross-entropy term).
        """
        self.encoder.train()
        self.decoder.train()
        self.discriminator.train()

        for lm in self.reconstruction_loss:
            lm.reset()
        self.discriminator_loss.reset()
        self.total_loss.reset()

        total_batches = self.args.epoch_length // self.args.batch_size
        with tqdm(total=total_batches, desc=f'Train epoch {epoch}') as train_enum:
            for batch_num in range(total_batches):
                dataset_no = self._dataset_index(batch_num)

                x, x_aug = next(self.data[dataset_no].train_iter)
                x = x.to(self.device).float()
                x_aug = x_aug.to(self.device).float()
                target = self._domain_target(x, dataset_no)

                # --- Train discriminator: classify the domain from z.
                z = self.encoder(x)
                z_logits = self.discriminator(z)
                discriminator_loss = F.cross_entropy(z_logits, target).mean()
                loss = discriminator_loss * self.args.d_weight

                self.discriminator_optimizer.zero_grad()
                loss.backward()
                if self.args.grad_clip is not None:
                    clip_grad_value_(self.discriminator.parameters(),
                                     self.args.grad_clip)
                self.discriminator_optimizer.step()

                # --- Train autoencoder: reconstruct x (from the augmented
                # input's latent code) while confusing the discriminator.
                z = self.encoder(x_aug)
                y = self.decoder(x, z)
                z_logits = self.discriminator(z)
                discriminator_loss = -F.cross_entropy(z_logits, target).mean()

                reconstruction_loss = cross_entropy_loss(y, x)
                self.reconstruction_loss[dataset_no].add(
                    reconstruction_loss.data.cpu().numpy().mean())
                loss = (reconstruction_loss.mean()
                        + self.args.d_weight * discriminator_loss)

                # BUG FIX: was self.model_optimizer (undefined).
                self.autoenc_optimizer.zero_grad()
                loss.backward()
                if self.args.grad_clip is not None:
                    clip_grad_value_(self.encoder.parameters(),
                                     self.args.grad_clip)
                    clip_grad_value_(self.decoder.parameters(),
                                     self.args.grad_clip)
                self.autoenc_optimizer.step()

                # BUG FIX: was self.loss_total; the tracker created in
                # __init__ is self.total_loss.
                self.total_loss.add(loss.data.item())
                train_enum.set_description(
                    f'Train (loss: {loss.data.item():.2f}) epoch {epoch}')
                train_enum.update()

    def validate_epoch(self, epoch):
        """Run one validation pass (1/10th of an epoch's batches), tracking
        reconstruction loss, discriminator accuracy, and a weighted total."""
        self.encoder.eval()
        self.decoder.eval()
        self.discriminator.eval()

        for lm in self.reconstruction_val:
            lm.reset()
        self.discriminator_val.reset()
        self.total_val.reset()

        total_batches = self.args.epoch_length // self.args.batch_size // 10
        with tqdm(total=total_batches) as valid_enum, torch.no_grad():
            for batch_num in range(total_batches):
                dataset_no = self._dataset_index(batch_num)

                x, x_aug = next(self.data[dataset_no].valid_iter)
                x = x.to(self.device).float()
                # BUG FIX: was `x_aug = x.to(self.device)` -- moved the wrong
                # tensor, silently discarding the augmented sample.
                x_aug = x_aug.to(self.device).float()

                z = self.encoder(x)
                y = self.decoder(x, z)
                z_logits = self.discriminator(z)

                # Discriminator accuracy on this batch.
                z_classification = torch.max(z_logits, dim=1)[1]
                z_accuracy = (z_classification == dataset_no).float().mean()
                self.discriminator_val.add(z_accuracy.data.item())

                discriminator_right = F.cross_entropy(
                    z_logits, self._domain_target(x, dataset_no)).mean()
                recon_loss = cross_entropy_loss(y, x)
                # BUG FIX: was self.evals_recon (undefined); the tracker
                # created in __init__ is self.reconstruction_val.
                self.reconstruction_val[dataset_no].add(
                    recon_loss.data.cpu().numpy().mean())
                # BUG FIX: was self.args.d_lambda; the discriminator weight
                # used everywhere else in this class is d_weight.
                total_loss = (discriminator_right.data.item() * self.args.d_weight
                              + recon_loss.mean().data.item())
                self.total_val.add(total_loss)
                valid_enum.set_description(
                    f'Test (loss: {total_loss:.2f}) epoch {epoch}')
                valid_enum.update()

    def train(self):
        """Full training loop: train + validate each epoch, checkpoint the
        best (by mean validation total loss) and the last model, decay LR."""
        best_loss = float('inf')
        for epoch in range(self.start_epoch, self.args.epochs):
            self.logger.info(
                f'Starting epoch, Rank {self.args.rank}, '
                f'Dataset: {self.args.data[self.args.rank]}')
            self.train_epoch(epoch)
            self.validate_epoch(epoch)

            train_losses = [self.reconstruction_loss, self.discriminator_loss]
            val_losses = [self.reconstruction_val, self.discriminator_val]
            # Lazy %-style placeholders are filled by the logger; the f-string
            # part only interpolates the rank.
            self.logger.info(
                f'Epoch %s Rank {self.args.rank} - Train loss: (%s), '
                f'Validation loss (%s)', epoch, train_losses, val_losses)

            # BUG FIX: was self.val_total (undefined); the tracker created in
            # __init__ is self.total_val.
            mean_loss = self.total_val.epoch_mean()
            if mean_loss < best_loss:
                self.save_model(f'bestmodel_{self.args.rank}.pth')
                best_loss = mean_loss

            if self.args.save_model:
                self.save_model(f'lastmodel_{epoch}_rank_{self.args.rank}.pth')
            else:
                self.save_model(f'lastmodel_{self.args.rank}.pth')

            # NOTE(review): the resume path in __init__ reads args.pth, but
            # this save is commented out -- confirm args.pth is written
            # elsewhere, otherwise --resume will fail.
            # if self.args.rank:
            #     torch.save([self.args, epoch], '%s/args.pth' % self.expPath)

            self.lr_manager.step()
            self.logger.debug('Ended epoch')

    def save_model(self, filename):
        """Serialize model and optimizer state dicts to ``expPath/filename``."""
        save_to = self.expPath / filename
        # BUG FIX: `.module` only exists on DDP-wrapped modules; unwrap
        # conditionally so single-process (world_size == 1) runs can save too.
        torch.save(
            {
                'encoder_state': self._unwrapped(self.encoder).state_dict(),
                'decoder_state': self._unwrapped(self.decoder).state_dict(),
                'discriminator_state':
                    self._unwrapped(self.discriminator).state_dict(),
                'autoenc_optimizer_state': self.autoenc_optimizer.state_dict(),
                # BUG FIX: key was 'd_optimizer_state', but the resume path in
                # __init__ loads 'discriminator_optimizer_state'.
                'discriminator_optimizer_state':
                    self.discriminator_optimizer.state_dict(),
                'dataset': self.args.rank,
            }, save_to)
        self.logger.debug(f'Saved model to {save_to}')
train_len = int(1) #my_dataset_size * 0.8) val_len = my_dataset_size - train_len train_set, val_set = torch.utils.data.random_split(my_dataset, [train_len, val_len]) train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=1, pin_memory=True) val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=1, pin_memory=True) ### Featurizer featurizer = MelSpectrogram(MelSpectrogramConfig(), device).to(device) ### Model model = WaveNet(hidden_ch=120, skip_ch=240, num_layers=30, mu=256) model = model.to(device) # wandb wandb.init(project='wavenet-pytorch') wandb.watch(model) print('num of model parameters', count_parameters(model)) ### Optimizer opt = torch.optim.Adam(model.parameters(), lr=3e-4) ### Encoder and decoder for mu-law mu_law_encoder = torchaudio.transforms.MuLawEncoding(quantization_channels=256).to(device) mu_law_decoder = torchaudio.transforms.MuLawDecoding(quantization_channels=256).to(device) ### Train loop for i in tqdm(range(NUM_EPOCHS)): for el in train_loader: wav = el['audio'].to(device) melspec = featurizer(wav)