def run_models(train_iter, valid_iter, test_iter, num_epochs, device, save_dir,
               load_checkpoint, fine_tune, dataset, do_condition, use_transformer,
               early_stopping=True, print_every_x_epoch=1, validate_every_x_epoch=1):
    """Train the placement CLSTM and the step-selection model jointly.

    Builds (or restores via ``load_save``) a ``PlacementCLSTM`` plus either an
    ``ArrowTransformer`` (``use_transformer``) or a ``SelectionRNN``, trains
    both over ``train_iter`` for the epochs ``0 .. num_epochs - 1`` (skipping
    epochs already done in a resumed checkpoint), validates every
    ``validate_every_x_epoch`` epochs with optional per-model early stopping,
    evaluates on ``test_iter``, merges the test metrics into the existing JSON
    summary in ``save_dir`` and finally writes the placement thresholds
    returned by ``optimize_placement_thresholds`` on ``valid_iter``.

    Args:
        train_iter: training batch iterable; ``len(train_iter)`` must work.
        valid_iter: validation batch iterable.
        test_iter: test batch iterable.
        num_epochs: number of epochs to iterate over.
        device: torch device every model is moved to.
        save_dir: directory for model files, checkpoints, tensorboard runs
            (``<save_dir>/runs/<sub_logdir>``), the summary and thresholds.
        load_checkpoint: checkpoint to resume from; falsy to start fresh.
            Forwarded to ``load_save``.
        fine_tune: forwarded to ``load_save``.
        dataset: provides ``vocab_size`` and ``chart_type``; also forwarded to
            ``log_training_stats``.
        do_condition: forwarded to the models, batch runners and ``evaluate``.
        use_transformer: use ``ArrowTransformer`` instead of ``SelectionRNN``.
        early_stopping: stop training each model independently once its
            validation loss stops improving.
        print_every_x_epoch: print training averages every this many epochs.
        validate_every_x_epoch: validate every this many epochs.

    Raises:
        Whatever ``open`` raises if the summary file in ``save_dir`` does not
        exist yet (it is read, not created, here).
    """
    # setup models and optimizers; each model is moved to `device` before its
    # optimizer is created so the optimizer tracks the on-device parameters
    placement_clstm = PlacementCLSTM(PLACEMENT_CHANNELS, PLACEMENT_FILTERS,
                                     PLACEMENT_KERNEL_SIZES, PLACEMENT_POOL_KERNEL,
                                     PLACEMENT_POOL_STRIDE, NUM_PLACEMENT_LSTM_LAYERS,
                                     PLACEMENT_INPUT_SIZE, HIDDEN_SIZE).to(device)
    placement_optim = optim.Adam(placement_clstm.parameters(), lr=PLACEMENT_LR)

    if use_transformer:
        # moved to `device` like the other models (it was the only one left unmoved)
        selection_model = ArrowTransformer(EMBED_DIM, dataset.vocab_size,
                                           NUM_TRANSFORMER_LAYERS, MAX_SEQ_LEN, PAD_IDX,
                                           TRANSFORMER_DROPOUT, do_condition).to(device)
    else:
        selection_model = SelectionRNN(NUM_SELECTION_LSTM_LAYERS,
                                       SELECTION_INPUT_SIZES[dataset.chart_type],
                                       dataset.vocab_size, HIDDEN_SIZE,
                                       do_condition).to(device)
    selection_optim = optim.Adam(selection_model.parameters(), lr=SELECTION_LR)

    # training state; overwritten below when resuming from a checkpoint
    best_placement_valid_loss = float('inf')
    best_placement_precision = 0
    best_selection_valid_loss = float('inf')
    start_epoch = 0
    start_epoch_batch = 0
    train_clstm = True
    train_selection = True

    selection_save = TRANSFORMER_SAVE if use_transformer else SRNN_SAVE
    run_selection_batch = run_transformer_batch if use_transformer else run_srnn_batch
    selection_criterion = (TransformerLoss(ignore_index=PAD_IDX) if use_transformer
                           else SRNN_CRITERION)
    sub_logdir = datetime.datetime.now().strftime('%m_%d_%y_%H_%M')

    # load model, optimizer states if resuming training
    if load_checkpoint:
        checkpoint = load_save(load_checkpoint, fine_tune, placement_clstm,
                               selection_model, use_transformer, device)
        if checkpoint:
            (start_epoch, start_epoch_batch, best_placement_valid_loss,
             best_placement_precision, best_selection_valid_loss, train_clstm,
             train_selection, sub_logdir) = checkpoint

    # number of fully trained epochs reported in the summary; epochs completed
    # before a resumed checkpoint count as trained
    epochs_trained = start_epoch

    writer = SummaryWriter(log_dir=os.path.join(save_dir, 'runs', sub_logdir))
    try:
        print('Starting training..')
        for epoch in trange(num_epochs):
            if epoch < start_epoch:
                continue

            print('Epoch: {}'.format(epoch))
            epoch_p_loss = 0
            epoch_p_precision = 0
            epoch_s_loss = 0

            for i, batch in enumerate(tqdm(train_iter)):
                # if resuming from checkpoint, skip batches until starting batch for the epoch
                # NOTE(review): this skips batches 0 .. start_epoch_batch - 1, so the batch
                # whose index was stored by the per-batch checkpoint below is trained
                # again -- confirm against save_checkpoint/load_save semantics.
                if start_epoch_batch > 0:
                    if i + 1 == start_epoch_batch:
                        start_epoch_batch = 0
                    continue

                step = epoch * len(train_iter) + i
                with torch.set_grad_enabled(train_clstm):
                    (placement_loss, placement_acc, placement_precision,
                     clstm_hiddens) = run_placement_batch(
                        placement_clstm, placement_optim, PLACEMENT_CRITERION, batch,
                        device, writer, do_condition, do_train=train_clstm,
                        curr_step=step)
                with torch.set_grad_enabled(train_selection):
                    (selection_loss, selection_acc) = run_selection_batch(
                        selection_model, selection_optim, selection_criterion, batch,
                        device, clstm_hiddens, do_train=train_selection)

                epoch_p_loss += placement_loss
                epoch_p_precision += placement_precision
                epoch_s_loss += selection_loss

                writer.add_scalar('loss/train_placement', placement_loss, step)
                writer.add_scalar('accuracy/train_placement', placement_acc, step)
                writer.add_scalar('loss/train_selection', selection_loss, step)
                writer.add_scalar('accuracy/train_selection', selection_acc, step)
                writer.add_scalar('precision/placement', placement_precision, step)

                # per-batch model saves + checkpoint make mid-epoch resumption possible.
                # NOTE(review): these write the same files the early-stopping code below
                # uses for the "best" models, so after a non-improving epoch those files
                # hold that epoch's weights rather than the best ones -- confirm intent.
                if train_clstm:
                    save_model(placement_clstm, save_dir, CLSTM_SAVE)
                if train_selection:
                    save_model(selection_model, save_dir, selection_save)
                save_checkpoint(epoch, i, best_placement_valid_loss,
                                best_placement_precision, best_selection_valid_loss,
                                train_clstm, train_selection, save_dir)

            epochs_trained = epoch + 1
            epoch_p_loss = epoch_p_loss / len(train_iter)
            epoch_p_precision = epoch_p_precision / len(train_iter)
            epoch_s_loss = epoch_s_loss / len(train_iter)

            if epoch % print_every_x_epoch == 0:
                print(f'\tAvg. training placement loss per unrolling: {epoch_p_loss:.5f}')
                print(f'\tAvg. training placement precision: {epoch_p_precision:.5f}')
                print(f'\tAvg. training selection loss per frame: {epoch_s_loss:.5f}')

            if epoch % validate_every_x_epoch == 0:
                # integer validation-step index (exact here, since epoch is a multiple)
                (placement_valid_loss, placement_valid_acc, selection_valid_loss,
                 selection_valid_acc, placement_precision) = evaluate(
                    placement_clstm, selection_model, valid_iter, PLACEMENT_CRITERION,
                    selection_criterion, device, writer,
                    epoch // validate_every_x_epoch, do_condition, use_transformer)

                print(f'\tAvg. validation placement loss per frame: {placement_valid_loss:.5f}')
                print(f'\tAvg. validation placement precision: {placement_precision:.5f}')
                print(f'\tAvg. validation selection loss per frame: {selection_valid_loss:.5f}')

                # track best performing model(s); each model stops training on its
                # own once its validation loss fails to improve (placement is judged
                # by loss, its precision is only recorded alongside)
                if early_stopping:
                    better_placement = placement_valid_loss < best_placement_valid_loss
                    better_selection = selection_valid_loss < best_selection_valid_loss

                    if train_clstm:
                        if better_placement:
                            best_placement_precision = placement_precision
                            best_placement_valid_loss = placement_valid_loss
                            save_model(placement_clstm, save_dir, CLSTM_SAVE)
                        else:
                            print("Placement validation loss increased, stopping CLSTM training")
                            train_clstm = False

                    if train_selection:
                        if better_selection:
                            best_selection_valid_loss = selection_valid_loss
                            save_model(selection_model, save_dir, selection_save)
                        else:
                            print("Selection validation loss increased, stopping selection model training")
                            train_selection = False

                    if not train_clstm and not train_selection:
                        print("Both early stopping criterion met. Stopping early..")
                        break

            # epoch finished: a resumed run starts at the next epoch, batch 0
            save_checkpoint(epoch + 1, 0, best_placement_valid_loss,
                            best_placement_precision, best_selection_valid_loss,
                            train_clstm, train_selection, save_dir)

        # evaluate on test set
        # NOTE(review): this uses the in-memory weights (last trained state), not
        # necessarily the best ones saved by early stopping -- confirm intent.
        (placement_test_loss, placement_test_acc, selection_test_loss,
         selection_test_acc, placement_test_precision) = evaluate(
            placement_clstm, selection_model, test_iter, PLACEMENT_CRITERION,
            selection_criterion, device, writer, -1, do_condition, use_transformer)

        # save training summary stats to json file, merged into the initial summary
        summary_path = os.path.join(save_dir, SUMMARY_SAVE)
        with open(summary_path, 'r') as f:
            summary_json = json.loads(f.read())

        summary_json = {
            **summary_json,
            'epochs_trained': epochs_trained,
            'placement_test_loss': placement_test_loss,
            'placement_test_accuracy': placement_test_acc,
            'placement_test_precision': placement_test_precision,
            'selection_test_loss': selection_test_loss,
            'selection_test_accuracy': selection_test_acc,
        }
        summary_json = log_training_stats(writer, dataset, summary_json, do_condition,
                                          use_transformer)

        with open(summary_path, 'w') as f:
            f.write(json.dumps(summary_json, indent=2))
    finally:
        # flush pending tensorboard events and release the event file
        writer.close()

    # optimize placement thresholds per level (range) which give highest F2 scores on the valid. set
    thresholds = optimize_placement_thresholds(placement_clstm, valid_iter, device)
    with open(os.path.join(save_dir, THRESHOLDS_SAVE), 'w') as f:
        f.write(json.dumps(thresholds, indent=2))
def save(self, tokenizers, output_dirs):
    """Persist the encoder and decoder sub-models with ``train_util.save_model``.

    ``self.encoder`` is written to ``output_dirs.encoder`` and then
    ``self.decoder`` to ``output_dirs.decoder``. ``tokenizers`` is accepted
    for interface compatibility but is not used here.
    """
    # imported at call time rather than at module import
    from train_util import save_model

    for part_name in ('encoder', 'decoder'):
        save_model(getattr(self, part_name), getattr(output_dirs, part_name))
def run_training(train_iter, valid_iter, test_iter, chart_type, save_dir,
                 placement_modelname, selection_modelname, device, num_epochs=15):
    """Train a baseline placement model and a baseline selection model.

    The placement model is ``baseline.PlacementLogReg`` (``'logreg'``) or
    ``baseline.PlacementMLP`` (``'mlp'``) over flattened audio features. The
    selection model is ``baseline.SelectionNGram`` built from n-gram counts of
    the training charts (``'ngram'``) or ``baseline.SelectionNGramMLP``
    (``'ngrammlp'``). Both are trained with SGD for up to ``num_epochs``
    epochs. After each epoch they are validated on ``valid_iter``, and each
    model stops training on its own once its criterion stops improving:
    validation precision for placement, validation loss for selection.
    Training ends early once both models have stopped.

    Args:
        train_iter: training batch iterable; must support ``len()`` and expose
            ``dataset.chart_ids`` (tuples whose first item is a chart JSON path).
        valid_iter: validation batch iterable.
        test_iter: accepted but not used by this function.
        chart_type: key into the per-chart-type selection size tables.
        save_dir: directory the models are saved into.
        placement_modelname: ``'logreg'`` or ``'mlp'``.
        selection_modelname: ``'ngram'`` or ``'ngrammlp'``.
        device: forwarded to the batch runners and ``evaluate``.
        num_epochs: maximum number of epochs.

    Raises:
        ValueError: if either model name is not one of the supported values.
    """
    # reject unknown names up front; otherwise they surface later as an
    # UnboundLocalError on `placement_model` / `selection_model`
    if placement_modelname not in ('logreg', 'mlp'):
        raise ValueError(f'unknown placement model: {placement_modelname!r}')
    if selection_modelname not in ('ngram', 'ngrammlp'):
        raise ValueError(f'unknown selection model: {selection_modelname!r}')

    # flattened audio feats as input
    placement_input_size = N_MELS * len(N_FFTS)
    if placement_modelname == 'logreg':
        placement_model = baseline.PlacementLogReg(
            input_size=placement_input_size, output_size=2)
    else:  # 'mlp'
        placement_model = baseline.PlacementMLP(
            input_size=placement_input_size, output_size=2)

    if selection_modelname == 'ngram':
        # count note n-grams over every chart of every training song
        ngram_counts = Counter()
        for json_fp, _, _, _ in train_iter.dataset.chart_ids:
            with open(json_fp, 'r') as f:
                attrs = json.loads(f.read())
            for chart_attrs in attrs['charts']:
                notes = [note for _, _, _, note in chart_attrs['notes']]
                for ngram in get_ngrams(notes):
                    ngram_counts[ngram] += 1
        selection_model = baseline.SelectionNGram(ngram_counts)
    else:  # 'ngrammlp'
        # NOTE(review): `SELECTION_VCOAB_SIZES` looks like a misspelling of a
        # vocab-size table name -- confirm the constant actually exists.
        selection_model = baseline.SelectionNGramMLP(
            SELECTION_VCOAB_SIZES[chart_type], SELECTION_INPUT_SIZES[chart_type])

    placement_optim = optim.SGD(placement_model.parameters(), lr=PLACEMENT_LR)
    selection_optim = optim.SGD(selection_model.parameters(), lr=SELECTION_LR)

    placement_save = MODEL_SAVENAMES[placement_modelname]
    selection_save = MODEL_SAVENAMES[selection_modelname]

    best_placement_valid_loss = float('inf')
    best_placement_precision = 0
    best_selection_valid_loss = float('inf')
    train_placement, train_selection = True, True

    print('Starting training..')
    for epoch in trange(num_epochs):
        print('Epoch: {}'.format(epoch))
        epoch_p_loss = 0
        epoch_p_precision = 0
        epoch_s_loss = 0

        for i, batch in enumerate(tqdm(train_iter)):
            with torch.set_grad_enabled(train_placement):
                (placement_loss, placement_acc, placement_precision) = run_placement_batch(
                    placement_model, placement_optim, PLACEMENT_CRITERION, batch, device,
                    do_train=train_placement)
            with torch.set_grad_enabled(train_selection):
                (selection_loss, selection_acc) = run_selection_batch(
                    selection_model, selection_optim, SELECTION_CRITERION, batch, device,
                    do_train=train_selection)

            epoch_p_loss += placement_loss
            epoch_p_precision += placement_precision
            epoch_s_loss += selection_loss

            if train_placement:
                save_model(placement_model, save_dir, placement_save)
            if train_selection:
                save_model(selection_model, save_dir, selection_save)

        epoch_p_loss = epoch_p_loss / len(train_iter)
        epoch_p_precision = epoch_p_precision / len(train_iter)
        epoch_s_loss = epoch_s_loss / len(train_iter)

        print(f'\tAvg. training placement loss per unrolling: {epoch_p_loss:.5f}')
        print(f'\tAvg. training placement precision: {epoch_p_precision:.5f}')
        print(f'\tAvg. training selection loss per frame: {epoch_s_loss:.5f}')

        (placement_valid_loss, placement_valid_acc, selection_valid_loss,
         selection_valid_acc, placement_precision) = evaluate(
            placement_model, selection_model, valid_iter, PLACEMENT_CRITERION,
            SELECTION_CRITERION, device)

        print(f'\tAvg. validation placement loss per frame: {placement_valid_loss:.5f}')
        print(f'\tAvg. validation placement precision: {placement_precision:.5f}')
        print(f'\tAvg. validation selection loss per frame: {selection_valid_loss:.5f}')

        # placement is judged by validation precision, selection by validation loss
        better_placement = placement_precision > best_placement_precision
        better_selection = selection_valid_loss < best_selection_valid_loss

        if train_placement:
            if better_placement:
                best_placement_precision = placement_precision
                best_placement_valid_loss = placement_valid_loss
                # the model built in this function (no `placement_clstm` exists here)
                save_model(placement_model, save_dir, placement_save)
            else:
                print("Placement validation precision did not improve, stopping placement model training")
                train_placement = False

        if train_selection:
            if better_selection:
                best_selection_valid_loss = selection_valid_loss
                # the model built in this function (no `selection_rnn` exists here)
                save_model(selection_model, save_dir, selection_save)
            else:
                print("Selection validation loss increased, stopping selection model training")
                train_selection = False

        if not train_placement and not train_selection:
            print("Both early stopping criterion met. Stopping early..")
            break