def plot_latent_interpolations(self, attr_str, dim, num_points=10):
    x1 = torch.linspace(-4.0, 4.0, num_points)
    _, _, data_loader = self.dataset.data_loaders(batch_size=1)
    for sample_id, batch in tqdm(enumerate(data_loader)):
        if sample_id in [0, 1, 2]:
            inputs, labels = self.process_batch_data(batch)
            inputs = to_cuda_variable(inputs)
            recons, _, _, z, _ = self.model(inputs)
            recons = torch.sigmoid(recons)
            z = z.repeat(num_points, 1)
            z[:, dim] = x1.contiguous()
            outputs = torch.sigmoid(self.model.decode(z))
            # save interpolation
            save_filepath = os.path.join(
                Trainer.get_save_dir(self.model),
                f'latent_interpolations_{attr_str}_{sample_id}.png')
            save_image(outputs.cpu(), save_filepath, nrow=num_points, pad_value=1.0)
            # save original image
            org_save_filepath = os.path.join(
                Trainer.get_save_dir(self.model),
                f'original_{sample_id}.png')
            save_image(inputs.cpu(), org_save_filepath, nrow=1, pad_value=1.0)
            # save reconstruction
            recons_save_filepath = os.path.join(
                Trainer.get_save_dir(self.model),
                f'recons_{sample_id}.png')
            save_image(recons.cpu(), recons_save_filepath, nrow=1, pad_value=1.0)
        if sample_id == 5:
            break
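# Usage sketch only: sweep one latent dimension per regularized attribute.
# "evaluator" and the attribute-to-dimension mapping are hypothetical names,
# not from the source.
def plot_all_interpolations(evaluator, attr_dims):
    # e.g. attr_dims = {"thickness": 0, "slant": 1}
    for attr, dim in attr_dims.items():
        evaluator.plot_latent_interpolations(attr_str=attr, dim=dim, num_points=10)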
def train(model: tf.keras.Model, dataset: DatasetCreator):
    optimizer = tf.keras.optimizers.RMSprop(learning_rate=ModelConfig.LR)
    loss_fn = tf.losses.SparseCategoricalCrossentropy(from_logits=False)
    trainer = Trainer(model, optimizer, loss_fn, dataset)

    if DataConfig.USE_TB:
        tensorboard = TensorBoard(tb_dir=DataConfig.TB_DIR)
        # Creates the TensorBoard graph. The result is not great, but it's a start.
        tf.summary.trace_on(graph=True, profiler=False)
        trainer.val_step(tf.expand_dims(tf.zeros(dataset.input_shape), 0), 0)
        with tensorboard.file_writer.as_default():
            tf.summary.trace_export(name=ModelConfig.NETWORK_NAME, step=0)

    best_loss = 1000
    last_checkpoint_epoch = 0
    for epoch in range(ModelConfig.MAX_EPOCHS):
        print(f"\nEpoch {epoch}/{ModelConfig.MAX_EPOCHS}")
        epoch_start_time = time.time()

        train_loss, train_acc = trainer.train_epoch()
        if DataConfig.USE_TB:
            tensorboard.write_metrics(train_loss, train_acc, epoch)
            tensorboard.write_lr(ModelConfig.LR, epoch)

        if (train_loss < best_loss and DataConfig.USE_CHECKPOINT
                and epoch >= DataConfig.RECORD_DELAY
                and (epoch - last_checkpoint_epoch) >= DataConfig.CHECKPT_SAVE_FREQ):
            save_path = os.path.join(DataConfig.CHECKPOINT_DIR, f'train_{epoch}')
            print(f"Loss improved from {best_loss} to {train_loss}, saving model to {save_path}")
            best_loss = train_loss
            last_checkpoint_epoch = epoch
            model.save(save_path)

        print(f"\nEpoch loss: {train_loss}, Train accuracy: {train_acc}"
              f" - Took {time.time() - epoch_start_time:.5f}s")

        # Validation and (expensive to compute) metrics
        if epoch % DataConfig.VAL_FREQ == 0 and epoch > DataConfig.RECORD_DELAY:
            validation_start_time = time.time()
            val_loss, val_acc = trainer.val_epoch()
            if DataConfig.USE_TB:
                tensorboard.write_metrics(val_loss, val_acc, epoch, mode="Validation")
                # Metrics for the train dataset
                imgs, labels = list(dataset.train_dataset.take(1).as_numpy_iterator())[0]
                predictions = model.predict(imgs)
                tensorboard.write_predictions(imgs, predictions, labels, epoch, mode="Train")
                # Metrics for the validation dataset
                imgs, labels = list(dataset.val_dataset.take(1).as_numpy_iterator())[0]
                predictions = model.predict(imgs)
                tensorboard.write_predictions(imgs, predictions, labels, epoch, mode="Validation")
            print(f"\nValidation loss: {val_loss}, Validation accuracy: {val_acc}",
                  f"- Took {time.time() - validation_start_time:.5f}s",
                  flush=True)
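# Minimal usage sketch, assuming ModelConfig/DataConfig are populated elsewhere.
# "build_model" is a hypothetical factory standing in for however the network is
# actually constructed, and DatasetCreator's constructor arguments are assumed.
if __name__ == "__main__":
    dataset = DatasetCreator()
    model = build_model(dataset.input_shape)
    train(model, dataset)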
def train(num_epochs, use_cuda, batch_size, wandb_name, subsample, checkpoint_epochs): trainer = Trainer(use_cuda, wandb_name) trainer.setup_checkpoints(CHECKPOINT_NAME, checkpoint_epochs) trainer.setup_wandb( WANDB_PROJECT, wandb_name, config={ "Batch Size": batch_size, "Epochs": num_epochs, "Adam Betas": ADAM_BETAS, "Learning Rate": LEARNING_RATE, "Weight Decay": WEIGHT_DECAY, "Fine Tuning": False, }, ) train_loader, test_loader = trainer.load_data_loaders( Dataset, batch_size, subsample) trainer.register_loss_fn(get_mse_loss) trainer.register_metric_fn(get_mse_metric, "Loss") trainer.input_shape = [2**15] trainer.target_shape = [2**15] trainer.output_shape = [2**15] net = trainer.load_net(WaveUNet) opt_kwargs = { "adam_betas": ADAM_BETAS, "weight_decay": WEIGHT_DECAY, } # Set net to train on clean speech only as an autoencoder net.skips_enabled = False trainer.test_set.clean_only = True trainer.train_set.clean_only = True # Fiddle with learning rate because autoencoder is not very good w/o skip conns. optimizer = trainer.load_optimizer(net, learning_rate=1e-4, **opt_kwargs) trainer.train(net, 5, optimizer, train_loader, test_loader) optimizer = trainer.load_optimizer(net, learning_rate=1e-5, **opt_kwargs) trainer.train(net, 5, optimizer, train_loader, test_loader) optimizer = trainer.load_optimizer(net, learning_rate=1e-6, **opt_kwargs) trainer.train(net, 5, optimizer, train_loader, test_loader) # Set net to train on noisy speech optimizer = trainer.load_optimizer( net, learning_rate=LEARNING_RATE, **opt_kwargs, ) # net.freeze_encoder() net.skips_enabled = True trainer.test_set.clean_only = False trainer.train_set.clean_only = False trainer.train(net, num_epochs, optimizer, train_loader, test_loader)
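# The three pretraining runs above are one staircase schedule written out in
# full; the equivalent loop (same Trainer API as above, nothing new assumed) is:
#
#     for lr in (1e-4, 1e-5, 1e-6):
#         optimizer = trainer.load_optimizer(net, learning_rate=lr, **opt_kwargs)
#         trainer.train(net, 5, optimizer, train_loader, test_loader)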
def main(args):
    config = tf.ConfigProto(
        allow_soft_placement=args.allow_soft_placement,
        gpu_options=tf.GPUOptions(allow_growth=args.allow_gpu_growth))
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    if args.no_trans_repr:
        kbqa_model = None
        trans_sess = None
    else:
        # load transferred model params
        config_path = "%s/config.json" % args.model_dir
        with open(config_path, 'r') as fr:
            kbqa_model_config = json.load(fr)
        trans_graph = tf.Graph()
        with trans_graph.as_default():
            kbqa_model = KbqaModel(**kbqa_model_config)
            trans_saver = tf.train.Saver()
        trans_sess = tf.Session(config=config, graph=trans_graph)
        model_path = '%s/model_best/best.model' % args.model_dir
        trans_saver.restore(trans_sess, save_path=model_path)

    if args.no_trans_select:
        feed_parm = None
    else:
        param_path = '%s/detail/param.best.pkl' % args.model_dir
        with open(param_path, 'rb') as fr:
            param_dict = pickle.load(fr)
        if len(param_dict.keys()) == 1:
            feed_parm = {
                'bilinear_mat': param_dict['rm_task/rm_forward/bilinear_mat']
            }
            print("bilinear_mat:", feed_parm['bilinear_mat'].shape)
        else:
            feed_parm = {
                'fc1_weights': param_dict['rm_task/rm_forward/fc1/weights'],
                'fc1_biases': param_dict['rm_task/rm_forward/fc1/biases'],
                'fc2_weights': param_dict['rm_task/rm_forward/fc2/weights'],
                'fc2_biases': param_dict['rm_task/rm_forward/fc2/biases']
            }
            print("fc1_weights:", feed_parm['fc1_weights'].shape)
            print("fc1_biases:", feed_parm['fc1_biases'].shape)
            print("fc2_weights:", feed_parm['fc2_weights'].shape)
            print("fc2_biases:", feed_parm['fc2_biases'].shape)

    # load knowledge
    kd_loader = KnowledgeLoader(args.kd_dir)
    word_vocab, word_embed = kd_loader.load_vocab(vocab_size=args.vocab_size,
                                                  embed_dim=args.dim_emb)
    kd_vocab, kd_embed = kd_loader.load_entity_relation()
    csk_entity_list = kd_loader.load_csk_entities()

    main_graph = tf.Graph()
    with main_graph.as_default():
        use_trans_repr = not args.no_trans_repr
        use_trans_select = not args.no_trans_select
        use_guiding = not args.no_use_guiding
        if args.multi_step:
            model = TransDGModelMultistep(word_embed, kd_embed, feed_parm,
                                          use_trans_repr=use_trans_repr,
                                          use_trans_select=use_trans_select,
                                          use_guiding=use_guiding,
                                          vocab_size=args.vocab_size,
                                          dim_emb=args.dim_emb,
                                          dim_trans=args.dim_trans,
                                          cell_class=args.cell_class,
                                          num_units=args.num_units,
                                          num_layers=args.num_layers,
                                          max_length=args.max_dec_len,
                                          lr_rate=args.lr_rate,
                                          max_grad_norm=args.max_grad_norm,
                                          drop_rate=args.drop_rate)
        else:
            model = TransDGModel(word_embed, kd_embed, feed_parm,
                                 use_trans_repr=use_trans_repr,
                                 use_trans_select=use_trans_select,
                                 vocab_size=args.vocab_size,
                                 dim_emb=args.dim_emb,
                                 dim_trans=args.dim_trans,
                                 cell_class=args.cell_class,
                                 num_units=args.num_units,
                                 num_layers=args.num_layers,
                                 max_length=args.max_dec_len,
                                 lr_rate=args.lr_rate,
                                 max_grad_norm=args.max_grad_norm,
                                 drop_rate=args.drop_rate)
        saver = tf.train.Saver(max_to_keep=5)
        best_saver = tf.train.Saver()
    sess = tf.Session(config=config, graph=main_graph)

    if args.mode == 'train':
        if tf.train.get_checkpoint_state("%s/models" % args.log_dir):
            model_path = tf.train.latest_checkpoint("%s/models" % args.log_dir)
            print("model restored from [%s]" % model_path)
            saver.restore(sess, model_path)
        else:
            print("create model with init parameters...")
            with main_graph.as_default():
                sess.run(tf.global_variables_initializer())
                model.set_vocabs(sess, word_vocab, kd_vocab)
        if args.verbose > 0:
            model.show_parameters()

        train_chunk_list = []
        with open("%s/all_list" % args.data_dir, 'r') as fr:
            for line in fr:
                train_chunk_list.append(line.strip())
        train_batcher = DataBatcher(data_dir=args.data_dir,
                                    file_list=train_chunk_list,
                                    batch_size=args.batch_size,
                                    num_epoch=args.max_epoch,
                                    shuffle=True)
        # wait for the train_batcher queue to fill its cache
        print("Loading data from [%s/all_list]" % args.data_dir)
        while not train_batcher.full():
            time.sleep(5)
            print("loader queue caching...")

        valid_loader = DataLoader(batch_size=args.batch_size)
        valid_path = "%s/valid.pkl" % args.data_dir
        valid_loader.load_data(file_path=valid_path)

        # train model
        trainer = Trainer(model=model, sess=sess,
                          trans_model=kbqa_model, trans_sess=trans_sess,
                          saver=saver, best_saver=best_saver,
                          log_dir=args.log_dir,
                          save_per_step=args.save_per_step)
        for epoch_idx in range(args.max_epoch):
            print("Epoch %d:" % (epoch_idx + 1))
            trainer.train(train_batcher, valid_loader, epoch_idx=epoch_idx)
    else:
        if args.ckpt == 'best.model':
            model_path = "%s/models/best_model/best.model" % args.log_dir
        else:
            model_path = "%s/models/model-%s" % (args.log_dir, args.ckpt)
        print("model restored from [%s]" % model_path)
        saver.restore(sess, model_path)

        test_loader = DataLoader(batch_size=args.batch_size)
        test_path = "%s/test.pkl" % args.data_dir
        test_loader.load_data(file_path=test_path)
        with open('%s/stopwords' % args.kd_dir, 'r') as f:
            stop_words = json.loads(f.readline())

        # test model on the test set
        generator = Generator(model=model, sess=sess,
                              trans_model=kbqa_model, trans_sess=trans_sess,
                              log_dir=args.log_dir, ckpt=args.ckpt,
                              stop_words=stop_words,
                              csk_entities=csk_entity_list)
        generator.generate(test_loader)
def run_train(args):
    # read data
    task_data = TaskData(task="task1")
    processor = BertProcessor(vocab_path="%s/vocab.txt" % args.bert_model_dir,
                              do_lower_case=True)

    train_path = "%s/train.data.txt" % args.data_dir
    train_data = task_data.read_data(data_path=train_path, is_train=True,
                                     with_pattern=True, shuffle=True)
    train_examples = processor.create_examples(lines=train_data, example_type='train')
    train_features = processor.create_features(examples=train_examples,
                                               max_seq_len=args.max_seq_len)
    train_dataset = processor.create_dataset(train_features, with_pattern=True,
                                             is_sorted=args.sorted)
    train_sampler = SequentialSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                  batch_size=args.batch_size, collate_fn=collate_fn)

    valid_path = "%s/dev.data.txt" % args.data_dir
    valid_data = task_data.read_data(data_path=valid_path, is_train=True,
                                     with_pattern=True, shuffle=False)
    valid_examples = processor.create_examples(lines=valid_data, example_type='valid')
    valid_features = processor.create_features(examples=valid_examples,
                                               max_seq_len=args.max_seq_len)
    valid_dataset = processor.create_dataset(valid_features, with_pattern=True)
    valid_sampler = SequentialSampler(valid_dataset)
    valid_dataloader = DataLoader(valid_dataset, sampler=valid_sampler,
                                  batch_size=args.batch_size, collate_fn=collate_fn)

    # save vocab and args to the log dir
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    processor.tokenizer.save_pretrained(args.log_dir)
    torch.save(args, "%s/train_args.bin" % args.log_dir)
    logging.info("Training/evaluation parameters %s", args)

    # set model
    logging.info("initializing model")
    if args.resume_path:
        args.resume_path = Path(args.resume_path)
        model = BertForClassifier.from_pretrained(
            args.resume_path, num_labels=task_data.get_num_labels())
    else:
        model = BertForClassifier.from_pretrained(
            args.bert_model_dir, num_labels=task_data.get_num_labels())

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [p for n, p in param_optimizer
                   if not any(nd in n for nd in no_decay)],
        'weight_decay': args.weight_decay
    }, {
        'params': [p for n, p in param_optimizer
                   if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    t_total = int(len(train_dataloader) / args.gradient_accumulation_steps * args.num_epochs)
    warmup_steps = int(t_total * args.warmup_proportion)
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=warmup_steps,
                                                num_training_steps=t_total)
    model_checkpoint = ModelCheckpoint(checkpoint_dir=args.log_dir, mode=args.mode,
                                       save_best_only=args.save_best)

    # training
    logging.info("======= Running training =======")
    logging.info("Num examples = %d" % len(train_examples))
    logging.info("Num epochs = %d" % args.num_epochs)
    logging.info("Gradient accumulation steps = %d" % args.gradient_accumulation_steps)
    logging.info("Total optimization steps = %d" % t_total)
    trainer = Trainer(
        model=model,
        num_epochs=args.num_epochs,
        criterion=CrossEntropyLoss(),
        optimizer=optimizer,
        scheduler=scheduler,
        model_checkpoint=model_checkpoint,
        batch_metrics=[F1Score(task_type='multiclass', average='micro')],
        epoch_metrics=[F1Score(task_type='multiclass', average='micro')])
    trainer.train(train_data=train_dataloader, valid_data=valid_dataloader)
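# The grouping above (no weight decay on biases or LayerNorm weights) is a common
# pattern; a self-contained helper sketch of the same idea, assuming any PyTorch
# model exposing named_parameters():
def group_parameters_for_weight_decay(model, weight_decay,
                                      no_decay=('bias', 'LayerNorm.weight')):
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    return [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]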
def test_train_with_lr_scheduler(mock_checkpoint):
    """
    Check that the training loop runs without crashing when no checkpoint
    interval is set and a one-cycle learning rate scheduler is used.
    """
    trainer = Trainer(use_cuda=USE_CUDA, wandb_name="my-model")
    trainer.setup_checkpoints("my-checkpoint", checkpoint_epochs=None)
    train_loader, test_loader = trainer.load_data_loaders(
        DummyDataset,
        batch_size=16,
        subsample=None,
        build_output=_build_output,
        length=128,
    )
    trainer.register_loss_fn(_get_mse_loss)
    trainer.register_metric_fn(_get_mse_metric, "Loss")
    trainer.input_shape = [1, 80, 256]
    trainer.target_shape = [1, 80, 256]
    trainer.output_shape = [1, 80, 256]
    net = trainer.load_net(
        DummyNet,
        input_shape=(16,) + INPUT_SHAPE,
        output_shape=(16,) + OUTPUT_SHAPE,
        use_cuda=USE_CUDA,
    )
    optimizer = trainer.load_optimizer(
        net, learning_rate=1e-4, adam_betas=[0.9, 0.99], weight_decay=1e-6
    )
    epochs = 5
    # One-cycle learning rate schedule
    steps_per_epoch = len(trainer.train_set) // 16
    trainer.use_one_cycle_lr_scheduler(optimizer, steps_per_epoch, epochs, 1e-3)
    mock_checkpoint.save.assert_not_called()
    trainer.train(net, epochs, optimizer, train_loader, test_loader)
    mock_checkpoint.save.assert_called_once_with(
        net, "my-checkpoint", name="my-model", use_wandb=False
    )
def train(num_epochs, use_cuda, batch_size, wandb_name, subsample, checkpoint_epochs): trainer = Trainer(use_cuda, wandb_name) trainer.setup_checkpoints(CHECKPOINT_NAME, checkpoint_epochs) trainer.setup_wandb( WANDB_PROJECT, wandb_name, config={ "Batch Size": batch_size, "Epochs": num_epochs, "Adam Betas": ADAM_BETAS, "Learning Rate": [MIN_LR, MAX_LR], "Weight Decay": WEIGHT_DECAY, "Fine Tuning": False, }, ) train_loader, test_loader = trainer.load_data_loaders( Dataset, batch_size, subsample) trainer.register_loss_fn(get_ce_loss) trainer.register_metric_fn(get_ce_metric, "Loss") trainer.register_metric_fn(get_accuracy_metric, "Accuracy") trainer.input_shape = [1, 80, 256] trainer.output_shape = [15] net = trainer.load_net(SpectralSceneNet) optimizer = trainer.load_optimizer(net, learning_rate=MIN_LR, adam_betas=ADAM_BETAS, weight_decay=WEIGHT_DECAY) # One cycle learning rate steps_per_epoch = len(trainer.train_set) // batch_size trainer.use_one_cycle_lr_scheduler(optimizer, steps_per_epoch, num_epochs, MAX_LR) trainer.train(net, num_epochs, optimizer, train_loader, test_loader)
def train(num_epochs, use_cuda, batch_size, wandb_name, subsample, checkpoint_epochs):
    # Load loss net
    loss_net = load_checkpoint(LOSS_NET_CHECKPOINT, use_cuda=use_cuda)
    loss_net.set_feature_mode()
    loss_net.eval()
    feature_loss = AudioFeatureLoss(loss_net, use_cuda=use_cuda)

    def get_feature_loss(inputs, outputs, targets):
        return feature_loss(inputs, outputs, targets)

    def get_feature_loss_metric(inputs, outputs, targets):
        loss_t = feature_loss(inputs, outputs, targets)
        return loss_t.data.item()

    trainer = Trainer(use_cuda, wandb_name)
    trainer.setup_checkpoints(CHECKPOINT_NAME, checkpoint_epochs)
    trainer.setup_wandb(
        WANDB_PROJECT,
        wandb_name,
        config={
            "Batch Size": batch_size,
            "Epochs": num_epochs,
            "Adam Betas": ADAM_BETAS,
            "Learning Rate": LEARNING_RATE,
            "Weight Decay": WEIGHT_DECAY,
            "Fine Tuning": False,
        },
    )
    train_loader, test_loader = trainer.load_data_loaders(Dataset, batch_size, subsample)
    trainer.register_loss_fn(get_feature_loss)
    trainer.register_metric_fn(get_mse_metric, "Loss")
    trainer.register_metric_fn(get_feature_loss_metric, "Feature Loss")
    trainer.input_shape = [2 ** 15]
    trainer.target_shape = [2 ** 15]
    trainer.output_shape = [2 ** 15]
    net = trainer.load_net(WaveUNet)
    optimizer = trainer.load_optimizer(
        net,
        learning_rate=LEARNING_RATE,
        adam_betas=ADAM_BETAS,
        weight_decay=WEIGHT_DECAY,
    )
    trainer.train(net, num_epochs, optimizer, train_loader, test_loader)
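# AudioFeatureLoss's internals aren't shown here; a sketch of the usual deep
# feature loss pattern (L1 distance between the frozen loss net's activations),
# assuming loss_net returns a list of feature maps when in feature mode:
import torch
import torch.nn.functional as F

def feature_loss_sketch(loss_net, outputs, targets):
    with torch.no_grad():
        target_feats = loss_net(targets)   # no gradients through the targets
    output_feats = loss_net(outputs)       # gradients flow back to the generator
    return sum(F.l1_loss(o, t) for o, t in zip(output_feats, target_feats))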
def reconstruction_loss(x, x_recons):
    return Trainer.mean_crossentropy_loss(weights=x_recons, targets=x)
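# Trainer.mean_crossentropy_loss is defined elsewhere; a minimal sketch of one
# plausible reading (categorical cross-entropy, with the reconstruction treated
# as unnormalized scores and the input as class targets). This is an assumption,
# not the repo's actual definition.
import torch.nn.functional as F

def mean_crossentropy_loss_sketch(weights, targets):
    # weights: (batch, num_classes, ...) scores; targets: (batch, ...) class ids
    return F.cross_entropy(weights, targets, reduction='mean')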
def train(num_epochs, use_cuda, batch_size, wandb_name, subsample, checkpoint_epochs):
    # Note: the batch_size argument is overridden by the module-level constant.
    batch_size = BATCH_SIZE
    trainer = Trainer(use_cuda, wandb_name)
    trainer.setup_checkpoints(CHECKPOINT_NAME, checkpoint_epochs)
    trainer.setup_wandb(
        WANDB_PROJECT,
        wandb_name,
        config={
            "Batch Size": batch_size,
            "Epochs": num_epochs,
            "Adam Betas": ADAM_BETAS,
            "Learning Rate": LEARNING_RATE,
            "Weight Decay": WEIGHT_DECAY,
            "Fine Tuning": False,
        },
    )
    train_loader, test_loader = trainer.load_data_loaders(Dataset, batch_size, subsample)
    trainer.register_loss_fn(get_ce_loss)
    trainer.register_metric_fn(get_ce_metric, "Loss")
    trainer.input_shape = [32767]
    net = trainer.load_net(SceneNet)
    optimizer = trainer.load_optimizer(
        net,
        learning_rate=LEARNING_RATE,
        adam_betas=ADAM_BETAS,
        weight_decay=WEIGHT_DECAY,
    )
    trainer.train(net, num_epochs, optimizer, train_loader, test_loader)

    # Do a fine-tuning run with 1/10th the learning rate for 1/3rd of the epochs.
    optimizer = trainer.load_optimizer(
        net,
        learning_rate=LEARNING_RATE / 10,
        adam_betas=ADAM_BETAS,
        weight_decay=WEIGHT_DECAY / 10,
    )
    num_epochs = num_epochs // 3
    trainer.train(net, num_epochs, optimizer, train_loader, test_loader)
def main():
    """ main """
    config = model_config()
    if config.check:
        config.save_dir = "./tmp/"
    config.use_gpu = torch.cuda.is_available() and config.gpu >= 0
    if config.use_gpu:
        torch.cuda.set_device(config.gpu)

    # Data definition
    train_file = "%s/json.train.txt" % config.data_dir
    dev_file = "%s/json.dev.txt" % config.data_dir
    test_file = "%s/json.test.txt" % config.data_dir
    vocab_file = "%s/train.vocab" % config.data_dir
    word2index, index2word = load_vocab(vocab_file, config.vocab_size)
    vocab_size = len(word2index)
    train_iter = prepare_batcher(train_file, word2index,
                                 batch_size=config.batch_size, is_shuffle=True)
    valid_iter = prepare_batcher(dev_file, word2index,
                                 batch_size=config.batch_size, is_shuffle=True)
    test_iter = prepare_batcher(test_file, word2index,
                                batch_size=config.batch_size, is_shuffle=False)

    # Model definition
    model = GLMP(index2word,
                 vocab_size=vocab_size,
                 hidden_size=config.hidden_size,
                 embed_dim=config.embed_size,
                 max_resp_len=config.max_dec_len,
                 n_layers=config.num_layers,
                 hop=config.hop,
                 dropout=config.dropout,
                 teacher_forcing_ratio=config.teacher_forcing_ratio,
                 use_cuda=config.use_gpu,
                 use_record=config.use_record,
                 unk_mask=config.unk_mask)

    if config.test and config.ckpt:
        # Testing
        print(model)
        model.load(save_dir=config.save_dir, file_prefix=config.ckpt)
        print("Generating ...")
        if not os.path.exists(config.output_dir):
            os.makedirs(config.output_dir)
        model.generate(test_iter, output_dir=config.output_dir, verbose=True)
    else:
        # Save directory
        if not os.path.exists(config.save_dir):
            os.makedirs(config.save_dir)

        # Optimizer definition
        optimizer = getattr(torch.optim, config.optimizer)(model.parameters(),
                                                           lr=config.lr)
        # Learning rate scheduler
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=optimizer, factor=config.lr_decay, patience=1,
            verbose=True, min_lr=1e-5)

        # Logger definition
        logger = logging.getLogger()
        logging.basicConfig(level=logging.DEBUG, format="%(message)s")
        fh = logging.FileHandler(os.path.join(config.save_dir, "train.log"))
        logger.addHandler(fh)

        # Save config
        params_file = os.path.join(config.save_dir, "params.json")
        with open(params_file, 'w') as fp:
            json.dump(config.__dict__, fp, indent=4, sort_keys=True)
        print("Saved params to '{}'".format(params_file))
        logger.info(model)

        # Train
        logger.info("Training starts ...")
        trainer = Trainer(model=model, optimizer=optimizer,
                          train_iter=train_iter, valid_iter=valid_iter,
                          logger=logger, valid_metric_name="-loss",
                          num_epochs=config.num_epochs,
                          save_dir=config.save_dir,
                          log_steps=config.log_steps,
                          valid_steps=config.valid_steps,
                          grad_clip=config.grad_clip,
                          lr_scheduler=lr_scheduler)
        if config.ckpt is not None:
            trainer.load(save_dir=config.save_dir, file_prefix=config.ckpt)
        trainer.train()
        logger.info("Training done!")
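# ReduceLROnPlateau only decays the LR when stepped with the monitored value.
# A sketch of how the Trainer presumably drives it on each validation pass
# ("evaluate" is a hypothetical helper, not from this file):
#
#     val_loss = evaluate(model, valid_iter)
#     lr_scheduler.step(val_loss)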
def train(runtime, training, logging):
    # Load feature loss net
    loss_net = load_checkpoint(LOSS_NET_CHECKPOINT, use_cuda=runtime["cuda"])
    loss_net.set_feature_mode(num_layers=6)
    loss_net.eval()
    feature_loss = AudioFeatureLoss(loss_net, use_cuda=runtime["cuda"])

    def get_feature_loss(inputs, outputs, targets):
        return feature_loss(inputs, outputs, targets)

    def get_feature_loss_metric(inputs, outputs, targets):
        loss_t = feature_loss(inputs, outputs, targets)
        return loss_t.data.item()

    batch_size = training["batch_size"]
    epochs = training["epochs"]
    subsample = training["subsample"]
    trainer = Trainer(**runtime)
    trainer.setup_checkpoints(**logging["checkpoint"])
    trainer.setup_wandb(
        **logging["wandb"],
        run_info={
            "Batch Size": batch_size,
            "Epochs": epochs,
            "Adam Betas": ADAM_BETAS,
            "Learning Rate": [MIN_LR, MAX_LR],
            "Weight Decay": WEIGHT_DECAY,
            "Fine Tuning": False,
        })
    train_loader, test_loader = trainer.load_data_loaders(Dataset, batch_size, subsample)
    trainer.register_loss_fn(get_feature_loss)
    trainer.register_metric_fn(get_mse_metric, "Loss")
    trainer.register_metric_fn(get_feature_loss_metric, "Feature Loss")
    trainer.input_shape = [1, 80, 256]
    trainer.target_shape = [1, 80, 256]
    trainer.output_shape = [1, 80, 256]
    net = trainer.load_net(SpectralUNet)
    optimizer = trainer.load_optimizer(net,
                                       learning_rate=MIN_LR,
                                       adam_betas=ADAM_BETAS,
                                       weight_decay=WEIGHT_DECAY)
    steps_per_epoch = len(trainer.train_set) // batch_size
    trainer.use_one_cycle_lr_scheduler(optimizer, steps_per_epoch, epochs, MAX_LR)
    trainer.train(net, epochs, optimizer, train_loader, test_loader)
def train(num_epochs, use_cuda, batch_size, wandb_name, subsample, checkpoint_epochs):
    trainer = Trainer(use_cuda, wandb_name)
    trainer.setup_checkpoints(CHECKPOINT_NAME, checkpoint_epochs)
    trainer.setup_wandb(
        WANDB_PROJECT,
        wandb_name,
        config={
            "Batch Size": batch_size,
            "Epochs": num_epochs,
            "Adam Betas": ADAM_BETAS,
            "Learning Rate": LEARNING_RATE,
            "Disc Learning Rate": DISC_LEARNING_RATE,
            "Disc Weight": DISC_WEIGHT,
            "Weight Decay": WEIGHT_DECAY,
            "Fine Tuning": False,
        },
    )

    # Construct generator network
    gen_net = trainer.load_net(WaveUNet)
    gen_optimizer = trainer.load_optimizer(
        gen_net,
        learning_rate=LEARNING_RATE,
        adam_betas=ADAM_BETAS,
        weight_decay=WEIGHT_DECAY,
    )
    train_loader, test_loader = trainer.load_data_loaders(
        NoisySpeechDataset, batch_size, subsample
    )

    # Construct discriminator network
    disc_net = trainer.load_net(MelDiscriminatorNet)
    disc_loss = LeastSquaresLoss(disc_net)
    disc_optimizer = trainer.load_optimizer(
        disc_net,
        learning_rate=DISC_LEARNING_RATE,
        adam_betas=ADAM_BETAS,
        weight_decay=WEIGHT_DECAY,
    )

    # First, train the generator using MSE loss
    disc_net.freeze()
    gen_net.unfreeze()
    trainer.register_loss_fn(get_mse_loss)
    trainer.register_metric_fn(get_mse_metric, "Loss")
    trainer.input_shape = [2 ** 15]
    trainer.target_shape = [2 ** 15]
    trainer.output_shape = [2 ** 15]
    trainer.train(gen_net, num_epochs, gen_optimizer, train_loader, test_loader)

    # Next, train the discriminator on the output of the (frozen) generator
    def get_disc_loss(_, fake_audio, real_audio):
        """
        Compare the targets (real audio) with the generated output (fake audio).
        """
        return disc_loss.for_discriminator(real_audio, fake_audio)

    def get_disc_metric(_, fake_audio, real_audio):
        loss_t = disc_loss.for_discriminator(real_audio, fake_audio)
        return loss_t.data.item()

    disc_net.unfreeze()
    gen_net.freeze()
    trainer.loss_fns = []
    trainer.metric_fns = []
    trainer.register_loss_fn(get_disc_loss)
    trainer.register_metric_fn(get_disc_metric, "Discriminator Loss")
    trainer.train(gen_net, num_epochs, disc_optimizer, train_loader, test_loader)

    # Finally, train the generator using the discriminator and MSE loss
    def get_gen_loss(_, fake_audio, real_audio):
        return disc_loss.for_generator(real_audio, fake_audio)

    def get_gen_metric(_, fake_audio, real_audio):
        loss_t = disc_loss.for_generator(real_audio, fake_audio)
        return loss_t.data.item()

    disc_net.freeze()
    gen_net.unfreeze()
    trainer.loss_fns = []
    trainer.metric_fns = []
    trainer.register_loss_fn(get_mse_loss)
    trainer.register_loss_fn(get_gen_loss, weight=DISC_WEIGHT)
    trainer.register_metric_fn(get_mse_metric, "Loss")
    trainer.register_metric_fn(get_gen_metric, "Generator Loss")
    trainer.train(gen_net, num_epochs, gen_optimizer, train_loader, test_loader)
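# LeastSquaresLoss isn't defined in this file; a minimal LSGAN sketch consistent
# with the call sites above (for_discriminator(real, fake), for_generator(real, fake)),
# labeled an assumption rather than the repo's actual implementation:
class LeastSquaresLossSketch:
    def __init__(self, disc_net):
        self.disc_net = disc_net

    def for_discriminator(self, real_audio, fake_audio):
        # Push real scores towards 1 and fake scores towards 0.
        real_score = self.disc_net(real_audio)
        fake_score = self.disc_net(fake_audio.detach())
        return 0.5 * (((real_score - 1) ** 2).mean() + (fake_score ** 2).mean())

    def for_generator(self, real_audio, fake_audio):
        # Push the discriminator's score for fake audio towards 1.
        fake_score = self.disc_net(fake_audio)
        return 0.5 * ((fake_score - 1) ** 2).mean()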