def test_train(mock_checkpoint):
    """
    Check that the training loop runs without crashing when there is no model.
    """
    trainer = Trainer(use_cuda=USE_CUDA, wandb_name="my-model")
    trainer.setup_checkpoints("my-checkpoint", checkpoint_epochs=None)
    train_loader, test_loader = trainer.load_data_loaders(
        DummyDataset,
        batch_size=16,
        subsample=None,
        build_output=_build_output,
        length=128,
    )
    trainer.register_loss_fn(_get_mse_loss)
    trainer.register_metric_fn(_get_mse_metric, "Loss")
    trainer.input_shape = [1, 80, 256]
    trainer.target_shape = [1, 80, 256]
    trainer.output_shape = [1, 80, 256]
    net = trainer.load_net(
        DummyNet,
        input_shape=(16,) + INPUT_SHAPE,
        output_shape=(16,) + OUTPUT_SHAPE,
        use_cuda=USE_CUDA,
    )
    optimizer = trainer.load_optimizer(
        net, learning_rate=1e-4, adam_betas=[0.9, 0.99], weight_decay=1e-6
    )
    epochs = 3
    # No checkpoint should be saved before training starts.
    mock_checkpoint.save.assert_not_called()
    trainer.train(net, epochs, optimizer, train_loader, test_loader)
    # With checkpoint_epochs=None, a single checkpoint is saved at the end.
    mock_checkpoint.save.assert_called_once_with(
        net, "my-checkpoint", name="my-model", use_wandb=False
    )
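# `mock_checkpoint` is a pytest fixture injected by the test harness. A minimal
# sketch of how such a fixture might be defined, assuming checkpoint saving
# goes through a module-level `checkpoint` helper that the Trainer imports;
# the dotted patch target below is hypothetical and must match the real module.
import pytest
from unittest import mock

@pytest.fixture
def mock_checkpoint():
    with mock.patch("training.trainer.checkpoint") as mock_ckpt:
        yield mock_ckpt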
def train(num_epochs, use_cuda, batch_size, wandb_name, subsample, checkpoint_epochs):
    trainer = Trainer(use_cuda, wandb_name)
    trainer.setup_checkpoints(CHECKPOINT_NAME, checkpoint_epochs)
    trainer.setup_wandb(
        WANDB_PROJECT,
        wandb_name,
        config={
            "Batch Size": batch_size,
            "Epochs": num_epochs,
            "Adam Betas": ADAM_BETAS,
            "Learning Rate": [MIN_LR, MAX_LR],
            "Weight Decay": WEIGHT_DECAY,
            "Fine Tuning": False,
        },
    )
    train_loader, test_loader = trainer.load_data_loaders(Dataset, batch_size, subsample)
    trainer.register_loss_fn(get_ce_loss)
    trainer.register_metric_fn(get_ce_metric, "Loss")
    trainer.register_metric_fn(get_accuracy_metric, "Accuracy")
    trainer.input_shape = [1, 80, 256]
    trainer.output_shape = [15]
    net = trainer.load_net(SpectralSceneNet)
    optimizer = trainer.load_optimizer(
        net, learning_rate=MIN_LR, adam_betas=ADAM_BETAS, weight_decay=WEIGHT_DECAY
    )
    # One-cycle learning rate schedule, stepped once per batch.
    steps_per_epoch = len(trainer.train_set) // batch_size
    trainer.use_one_cycle_lr_scheduler(optimizer, steps_per_epoch, num_epochs, MAX_LR)
    trainer.train(net, num_epochs, optimizer, train_loader, test_loader)
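# `use_one_cycle_lr_scheduler` belongs to the project's Trainer. A minimal
# sketch of what it plausibly wraps, assuming it delegates to PyTorch's
# built-in OneCycleLR; the method body below is an assumption, not the
# project's actual implementation.
from torch.optim.lr_scheduler import OneCycleLR

def use_one_cycle_lr_scheduler(self, optimizer, steps_per_epoch, epochs, max_lr):
    # OneCycleLR ramps the learning rate up to max_lr and back down over the
    # whole run; it must be stepped once per batch, hence steps_per_epoch.
    self.lr_scheduler = OneCycleLR(
        optimizer,
        max_lr=max_lr,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
    )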
def train(num_epochs, use_cuda, batch_size, wandb_name, subsample, checkpoint_epochs):
    # NOTE: the batch_size argument is overridden by the module-level constant.
    batch_size = BATCH_SIZE
    trainer = Trainer(use_cuda, wandb_name)
    trainer.setup_checkpoints(CHECKPOINT_NAME, checkpoint_epochs)
    trainer.setup_wandb(
        WANDB_PROJECT,
        wandb_name,
        config={
            "Batch Size": batch_size,
            "Epochs": num_epochs,
            "Adam Betas": ADAM_BETAS,
            "Learning Rate": LEARNING_RATE,
            "Weight Decay": WEIGHT_DECAY,
            "Fine Tuning": False,
        },
    )
    train_loader, test_loader = trainer.load_data_loaders(Dataset, batch_size, subsample)
    trainer.register_loss_fn(get_ce_loss)
    trainer.register_metric_fn(get_ce_metric, "Loss")
    trainer.input_shape = [32767]
    net = trainer.load_net(SceneNet)
    optimizer = trainer.load_optimizer(
        net, learning_rate=LEARNING_RATE, adam_betas=ADAM_BETAS, weight_decay=WEIGHT_DECAY
    )
    trainer.train(net, num_epochs, optimizer, train_loader, test_loader)
    # Fine-tuning run: 1/10th the learning rate and weight decay, for 1/3 of the epochs.
    optimizer = trainer.load_optimizer(
        net,
        learning_rate=LEARNING_RATE / 10,
        adam_betas=ADAM_BETAS,
        weight_decay=WEIGHT_DECAY / 10,
    )
    num_epochs = num_epochs // 3
    trainer.train(net, num_epochs, optimizer, train_loader, test_loader)
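# The classification loss and metric helpers registered by the two scene-net
# scripts above are not shown in this section. A plausible sketch, assuming
# they wrap torch.nn.functional.cross_entropy and follow the
# (inputs, outputs, targets) signature the Trainer expects:
import torch
import torch.nn.functional as F

def get_ce_loss(inputs, outputs, targets):
    # Cross-entropy between class logits and integer class labels.
    return F.cross_entropy(outputs, targets)

def get_ce_metric(inputs, outputs, targets):
    # Same quantity, detached to a Python float for logging.
    return F.cross_entropy(outputs, targets).data.item()

def get_accuracy_metric(inputs, outputs, targets):
    # Fraction of samples whose argmax class matches the label.
    preds = torch.argmax(outputs, dim=1)
    return (preds == targets).float().mean().item()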
def train(num_epochs, use_cuda, batch_size, wandb_name, subsample, checkpoint_epochs):
    trainer = Trainer(use_cuda, wandb_name)
    trainer.setup_checkpoints(CHECKPOINT_NAME, checkpoint_epochs)
    trainer.setup_wandb(
        WANDB_PROJECT,
        wandb_name,
        config={
            "Batch Size": batch_size,
            "Epochs": num_epochs,
            "Adam Betas": ADAM_BETAS,
            "Learning Rate": LEARNING_RATE,
            "Weight Decay": WEIGHT_DECAY,
            "Fine Tuning": False,
        },
    )
    train_loader, test_loader = trainer.load_data_loaders(Dataset, batch_size, subsample)
    trainer.register_loss_fn(get_mse_loss)
    trainer.register_metric_fn(get_mse_metric, "Loss")
    trainer.input_shape = [2 ** 15]
    trainer.target_shape = [2 ** 15]
    trainer.output_shape = [2 ** 15]
    net = trainer.load_net(WaveUNet)
    opt_kwargs = {
        "adam_betas": ADAM_BETAS,
        "weight_decay": WEIGHT_DECAY,
    }
    # Stage 1: train on clean speech only, as an autoencoder with the skip
    # connections disabled.
    net.skips_enabled = False
    trainer.test_set.clean_only = True
    trainer.train_set.clean_only = True
    # Step the learning rate down manually: the autoencoder trains poorly
    # without skip connections, so anneal 1e-4 -> 1e-5 -> 1e-6, 5 epochs each.
    optimizer = trainer.load_optimizer(net, learning_rate=1e-4, **opt_kwargs)
    trainer.train(net, 5, optimizer, train_loader, test_loader)
    optimizer = trainer.load_optimizer(net, learning_rate=1e-5, **opt_kwargs)
    trainer.train(net, 5, optimizer, train_loader, test_loader)
    optimizer = trainer.load_optimizer(net, learning_rate=1e-6, **opt_kwargs)
    trainer.train(net, 5, optimizer, train_loader, test_loader)
    # Stage 2: train on noisy speech with the skip connections re-enabled.
    optimizer = trainer.load_optimizer(net, learning_rate=LEARNING_RATE, **opt_kwargs)
    # net.freeze_encoder()
    net.skips_enabled = True
    trainer.test_set.clean_only = False
    trainer.train_set.clean_only = False
    trainer.train(net, num_epochs, optimizer, train_loader, test_loader)
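# Likewise, a plausible sketch of the MSE helpers used by the denoising
# scripts, assuming they compare the denoised output against the clean
# target; this is an illustrative reconstruction, not the project's code.
import torch.nn.functional as F

def get_mse_loss(inputs, outputs, targets):
    # Mean squared error between the denoised output and the clean target.
    return F.mse_loss(outputs, targets)

def get_mse_metric(inputs, outputs, targets):
    return F.mse_loss(outputs, targets).data.item()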
def train(runtime, training, logging):
    # Load feature loss net
    loss_net = load_checkpoint(LOSS_NET_CHECKPOINT, use_cuda=runtime["cuda"])
    loss_net.set_feature_mode(num_layers=6)
    loss_net.eval()
    feature_loss = AudioFeatureLoss(loss_net, use_cuda=runtime["cuda"])

    def get_feature_loss(inputs, outputs, targets):
        return feature_loss(inputs, outputs, targets)

    def get_feature_loss_metric(inputs, outputs, targets):
        loss_t = feature_loss(inputs, outputs, targets)
        return loss_t.data.item()

    batch_size = training["batch_size"]
    epochs = training["epochs"]
    subsample = training["subsample"]
    trainer = Trainer(**runtime)
    trainer.setup_checkpoints(**logging["checkpoint"])
    trainer.setup_wandb(
        **logging["wandb"],
        run_info={
            "Batch Size": batch_size,
            "Epochs": epochs,
            "Adam Betas": ADAM_BETAS,
            "Learning Rate": [MIN_LR, MAX_LR],
            "Weight Decay": WEIGHT_DECAY,
            "Fine Tuning": False,
        },
    )
    train_loader, test_loader = trainer.load_data_loaders(Dataset, batch_size, subsample)
    trainer.register_loss_fn(get_feature_loss)
    trainer.register_metric_fn(get_mse_metric, "Loss")
    trainer.register_metric_fn(get_feature_loss_metric, "Feature Loss")
    trainer.input_shape = [1, 80, 256]
    trainer.target_shape = [1, 80, 256]
    trainer.output_shape = [1, 80, 256]
    net = trainer.load_net(SpectralUNet)
    optimizer = trainer.load_optimizer(
        net, learning_rate=MIN_LR, adam_betas=ADAM_BETAS, weight_decay=WEIGHT_DECAY
    )
    # One-cycle learning rate schedule, stepped once per batch.
    steps_per_epoch = len(trainer.train_set) // batch_size
    trainer.use_one_cycle_lr_scheduler(optimizer, steps_per_epoch, epochs, MAX_LR)
    trainer.train(net, epochs, optimizer, train_loader, test_loader)
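# AudioFeatureLoss is the project's perceptual ("deep feature") loss. A
# minimal sketch of the idea, assuming that in feature mode the loss net
# returns a list of intermediate activations; illustrative only, and device
# handling is omitted (use_cuda is kept just for signature parity).
import torch
import torch.nn.functional as F

class AudioFeatureLoss:
    def __init__(self, loss_net, use_cuda=False):
        self.loss_net = loss_net

    def __call__(self, inputs, outputs, targets):
        # Run the denoised output and the clean target through the frozen
        # loss net and compare their intermediate feature maps layer by layer.
        with torch.no_grad():
            target_feats = self.loss_net(targets)
        output_feats = self.loss_net(outputs)
        return sum(
            F.l1_loss(out_f, tgt_f)
            for out_f, tgt_f in zip(output_feats, target_feats)
        )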
def train(num_epochs, use_cuda, batch_size, wandb_name, subsample, checkpoint_epochs):
    # Load loss net
    loss_net = load_checkpoint(LOSS_NET_CHECKPOINT, use_cuda=use_cuda)
    loss_net.set_feature_mode()
    loss_net.eval()
    feature_loss = AudioFeatureLoss(loss_net, use_cuda=use_cuda)

    def get_feature_loss(inputs, outputs, targets):
        return feature_loss(inputs, outputs, targets)

    def get_feature_loss_metric(inputs, outputs, targets):
        loss_t = feature_loss(inputs, outputs, targets)
        return loss_t.data.item()

    trainer = Trainer(use_cuda, wandb_name)
    trainer.setup_checkpoints(CHECKPOINT_NAME, checkpoint_epochs)
    trainer.setup_wandb(
        WANDB_PROJECT,
        wandb_name,
        config={
            "Batch Size": batch_size,
            "Epochs": num_epochs,
            "Adam Betas": ADAM_BETAS,
            "Learning Rate": LEARNING_RATE,
            "Weight Decay": WEIGHT_DECAY,
            "Fine Tuning": False,
        },
    )
    train_loader, test_loader = trainer.load_data_loaders(Dataset, batch_size, subsample)
    trainer.register_loss_fn(get_feature_loss)
    trainer.register_metric_fn(get_mse_metric, "Loss")
    trainer.register_metric_fn(get_feature_loss_metric, "Feature Loss")
    trainer.input_shape = [2 ** 15]
    trainer.target_shape = [2 ** 15]
    trainer.output_shape = [2 ** 15]
    net = trainer.load_net(WaveUNet)
    optimizer = trainer.load_optimizer(
        net,
        learning_rate=LEARNING_RATE,
        adam_betas=ADAM_BETAS,
        weight_decay=WEIGHT_DECAY,
    )
    trainer.train(net, num_epochs, optimizer, train_loader, test_loader)
def train(num_epochs, use_cuda, batch_size, wandb_name, subsample, checkpoint_epochs):
    trainer = Trainer(use_cuda, wandb_name)
    trainer.setup_checkpoints(CHECKPOINT_NAME, checkpoint_epochs)
    trainer.setup_wandb(
        WANDB_PROJECT,
        wandb_name,
        config={
            "Batch Size": batch_size,
            "Epochs": num_epochs,
            "Adam Betas": ADAM_BETAS,
            "Learning Rate": LEARNING_RATE,
            "Disc Learning Rate": DISC_LEARNING_RATE,
            "Disc Weight": DISC_WEIGHT,
            "Weight Decay": WEIGHT_DECAY,
            "Fine Tuning": False,
        },
    )
    # Construct generator network
    gen_net = trainer.load_net(WaveUNet)
    gen_optimizer = trainer.load_optimizer(
        gen_net,
        learning_rate=LEARNING_RATE,
        adam_betas=ADAM_BETAS,
        weight_decay=WEIGHT_DECAY,
    )
    train_loader, test_loader = trainer.load_data_loaders(
        NoisySpeechDataset, batch_size, subsample
    )
    # Construct discriminator network
    disc_net = trainer.load_net(MelDiscriminatorNet)
    disc_loss = LeastSquaresLoss(disc_net)
    disc_optimizer = trainer.load_optimizer(
        disc_net,
        learning_rate=DISC_LEARNING_RATE,
        adam_betas=ADAM_BETAS,
        weight_decay=WEIGHT_DECAY,
    )

    # First, train the generator using MSE loss only.
    disc_net.freeze()
    gen_net.unfreeze()
    trainer.register_loss_fn(get_mse_loss)
    trainer.register_metric_fn(get_mse_metric, "Loss")
    trainer.input_shape = [2 ** 15]
    trainer.target_shape = [2 ** 15]
    trainer.output_shape = [2 ** 15]
    trainer.train(gen_net, num_epochs, gen_optimizer, train_loader, test_loader)

    # Next, train the discriminator on the frozen generator's output.
    def get_disc_loss(_, fake_audio, real_audio):
        """
        Compare the real audio with the generated output (fake audio).
        """
        return disc_loss.for_discriminator(real_audio, fake_audio)

    def get_disc_metric(_, fake_audio, real_audio):
        loss_t = disc_loss.for_discriminator(real_audio, fake_audio)
        return loss_t.data.item()

    disc_net.unfreeze()
    gen_net.freeze()
    trainer.loss_fns = []
    trainer.metric_fns = []
    trainer.register_loss_fn(get_disc_loss)
    trainer.register_metric_fn(get_disc_metric, "Discriminator Loss")
    # The generator still runs the forward pass, but only the discriminator's
    # parameters are updated, via disc_optimizer.
    trainer.train(gen_net, num_epochs, disc_optimizer, train_loader, test_loader)

    # Finally, train the generator against the discriminator plus MSE loss.
    def get_gen_loss(_, fake_audio, real_audio):
        return disc_loss.for_generator(real_audio, fake_audio)

    def get_gen_metric(_, fake_audio, real_audio):
        loss_t = disc_loss.for_generator(real_audio, fake_audio)
        return loss_t.data.item()

    disc_net.freeze()
    gen_net.unfreeze()
    trainer.loss_fns = []
    trainer.metric_fns = []
    trainer.register_loss_fn(get_mse_loss)
    trainer.register_loss_fn(get_gen_loss, weight=DISC_WEIGHT)
    trainer.register_metric_fn(get_mse_metric, "Loss")
    trainer.register_metric_fn(get_gen_metric, "Generator Loss")
    trainer.train(gen_net, num_epochs, gen_optimizer, train_loader, test_loader)
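# LeastSquaresLoss follows the LSGAN formulation. A sketch of the two
# objectives used above, assuming the discriminator returns a raw realness
# score; illustrative, not the project's implementation.
class LeastSquaresLoss:
    def __init__(self, disc_net):
        self.disc_net = disc_net

    def for_discriminator(self, real_audio, fake_audio):
        # Push D(real) towards 1 and D(fake) towards 0; detach the fake
        # audio so no gradients flow back into the generator.
        real_score = self.disc_net(real_audio)
        fake_score = self.disc_net(fake_audio.detach())
        return 0.5 * (((real_score - 1) ** 2).mean() + (fake_score ** 2).mean())

    def for_generator(self, real_audio, fake_audio):
        # Push D(fake) towards 1 so the generator fools the discriminator.
        fake_score = self.disc_net(fake_audio)
        return 0.5 * ((fake_score - 1) ** 2).mean()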