Example #1
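# Imports assumed from the surrounding module (not shown in this excerpt);
# the model classes, run_*_batch/evaluate/save_* helpers, and the upper-case
# hyperparameter constants come from the project itself.
import datetime
import json
import os

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm, trange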
def run_models(train_iter, valid_iter, test_iter, num_epochs, device, save_dir, load_checkpoint, fine_tune, 
               dataset, do_condition, use_transformer, early_stopping=True, print_every_x_epoch=1, validate_every_x_epoch=1):
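    """Train the placement CLSTM and a step-selection model (SelectionRNN or
    ArrowTransformer), checkpointing each batch, validating every
    `validate_every_x_epoch` epochs with optional early stopping, then
    evaluating on the test set and tuning placement thresholds."""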
    # setup or load models, optimizers
    placement_clstm = PlacementCLSTM(PLACEMENT_CHANNELS, PLACEMENT_FILTERS, PLACEMENT_KERNEL_SIZES,
                                     PLACEMENT_POOL_KERNEL, PLACEMENT_POOL_STRIDE, NUM_PLACEMENT_LSTM_LAYERS,
                                     PLACEMENT_INPUT_SIZE, HIDDEN_SIZE).to(device)
    placement_optim = optim.Adam(placement_clstm.parameters(), lr=PLACEMENT_LR)

    if use_transformer:
        selection_model = ArrowTransformer(EMBED_DIM, dataset.vocab_size, NUM_TRANSFORMER_LAYERS, MAX_SEQ_LEN,
                                           PAD_IDX, TRANSFORMER_DROPOUT, do_condition).to(device)
    else:
        selection_model = SelectionRNN(NUM_SELECTION_LSTM_LAYERS, SELECTION_INPUT_SIZES[dataset.chart_type], 
                                       dataset.vocab_size, HIDDEN_SIZE, do_condition).to(device)
    selection_optim = optim.Adam(selection_model.parameters(), lr=SELECTION_LR)

    # load model, optimizer states if resuming training
    best_placement_valid_loss = float('inf')
    best_placement_precision = 0
    best_selection_valid_loss = float('inf')
    start_epoch = 0
    start_epoch_batch = 0
    train_clstm = True
    train_selection = True
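    # the selection-model plumbing differs between the two variants: save name,
    # per-batch runner, and loss (the transformer's loss must ignore padding)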
    selection_save = TRANSFORMER_SAVE if use_transformer else SRNN_SAVE
    run_selection_batch = run_transformer_batch if use_transformer else run_srnn_batch
    selection_criterion = TransformerLoss(ignore_index=PAD_IDX) if use_transformer else SRNN_CRITERION
    sub_logdir = datetime.datetime.now().strftime('%m_%d_%y_%H_%M')
  
    if load_checkpoint:
        checkpoint = load_save(load_checkpoint, fine_tune, placement_clstm, selection_model, use_transformer, device)
        if checkpoint:
            (start_epoch, start_epoch_batch, best_placement_valid_loss, best_placement_precision,
             best_selection_valid_loss, train_clstm, train_selection, sub_logdir) = checkpoint

    writer = SummaryWriter(log_dir=os.path.join(save_dir, 'runs', sub_logdir))

    print('Starting training..')
    for epoch in trange(num_epochs):
        if epoch < start_epoch:
            continue

        print('Epoch: {}'.format(epoch))
        epoch_p_loss = 0
        epoch_p_precision = 0
        epoch_s_loss = 0
        # report_memory(device=device, show_tensors=True)

        for i, batch in enumerate(tqdm(train_iter)):
            # if resuming from checkpoint, skip batches until starting batch for the epoch
            if start_epoch_batch > 0:
                if i + 1 == start_epoch_batch:
                    start_epoch_batch = 0
                continue
            
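            # global step across all epochs, used as the x-axis for TensorBoard scalars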
            step = epoch * len(train_iter) + i

            with torch.set_grad_enabled(train_clstm):
                (placement_loss,
                 placement_acc,
                 placement_precision,
                 clstm_hiddens) = run_placement_batch(placement_clstm, placement_optim, PLACEMENT_CRITERION, 
                                                      batch, device, writer, do_condition, do_train=train_clstm, 
                                                      curr_step=step)

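            # the selection model is conditioned on the hidden states produced by
            # the placement CLSTM for this batch (clstm_hiddens)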
            with torch.set_grad_enabled(train_selection):
                (selection_loss,
                 selection_acc) = run_selection_batch(selection_model, selection_optim, selection_criterion,
                                                      batch, device, clstm_hiddens, do_train=train_selection)

            epoch_p_loss += placement_loss
            epoch_p_precision += placement_precision
            epoch_s_loss += selection_loss

            writer.add_scalar('loss/train_placement', placement_loss, step)
            writer.add_scalar('accuracy/train_placement', placement_acc, step)
            writer.add_scalar('loss/train_selection', selection_loss, step)
            writer.add_scalar('accuracy/train_selection', selection_acc, step)
            writer.add_scalar('precision/placement', placement_precision, step)

            if train_clstm:
                save_model(placement_clstm, save_dir, CLSTM_SAVE)
            if train_selection:
                save_model(selection_model, save_dir, selection_save)

            save_checkpoint(epoch, i, best_placement_valid_loss, best_placement_precision,
                            best_selection_valid_loss, train_clstm, train_selection, save_dir)

        epoch_p_loss = epoch_p_loss / len(train_iter)
        epoch_p_precision = epoch_p_precision / len(train_iter)
        epoch_s_loss = epoch_s_loss / len(train_iter)

        if epoch % print_every_x_epoch == 0:
            print(f'\tAvg. training placement loss per unrolling: {epoch_p_loss:.5f}')
            print(f'\tAvg. training placement precision: {epoch_p_precision:.5f}')
            print(f'\tAvg. training selection loss per frame: {epoch_s_loss:.5f}')

        if epoch % validate_every_x_epoch == 0:
            (placement_valid_loss,
             placement_valid_acc,
             selection_valid_loss,
             selection_valid_acc,
             placement_precision) = evaluate(placement_clstm, selection_model, valid_iter, 
                                             PLACEMENT_CRITERION, selection_criterion, device, 
                                             writer, epoch // validate_every_x_epoch, do_condition, use_transformer)
                
            print(f'\tAvg. validation placement loss per frame: {placement_valid_loss:.5f}')
            print(f'\tAvg. validation placement precision: {placement_precision:.5f}')
            print(f'\tAvg. validation selection loss per frame: {selection_valid_loss:.5f}')

            # track best performing model(s) or save every epoch
            if early_stopping:
                better_placement = placement_valid_loss < best_placement_valid_loss
                #better_placement = placement_precision > best_placement_precision
                better_selection = selection_valid_loss < best_selection_valid_loss

                if train_clstm:
                    if better_placement:
                        best_placement_precision = placement_precision
                        best_placement_valid_loss = placement_valid_loss
                        save_model(placement_clstm, save_dir, CLSTM_SAVE)
                    else:
                        print("Placement validation loss increased, stopping CLSTM training")
                        train_clstm = False

                if train_selection:
                    if better_selection:
                        best_selection_valid_loss = selection_valid_loss
                        save_model(selection_model, save_dir, selection_save)
                    else:
                        print("Selection validation loss increased, stopping selection model training")
                        train_selection = False

                if not train_clstm and not train_selection:
                    print("Both early stopping criterion met. Stopping early..")
                    break

            save_checkpoint(epoch + 1, 0, best_placement_valid_loss, best_placement_precision,
                            best_selection_valid_loss, train_clstm, train_selection, save_dir)

    # evaluate on test set
    (placement_test_loss,
     placement_test_acc,
     selection_test_loss,
     selection_test_acc,
     placement_precision) = evaluate(placement_clstm, selection_model, test_iter, 
                                     PLACEMENT_CRITERION, selection_criterion,
                                     device, writer, -1, do_condition, use_transformer)

    # save training summary stats to json file
    # load initial summary
    with open(os.path.join(save_dir, SUMMARY_SAVE), 'r') as f:
        summary_json = json.loads(f.read())
    summary_json = {
        **summary_json,
        'epochs_trained': num_epochs,
        'placement_test_loss': placement_test_loss,
        'placement_test_accuracy': placement_test_acc,
        'placement_test_precision': placement_precision,
        'selection_test_loss': selection_test_loss,
        'selection_test_accuracy': selection_test_acc,
    }

    summary_json = log_training_stats(writer, dataset, summary_json, do_condition, use_transformer)

    with open(os.path.join(save_dir, SUMMARY_SAVE), 'w') as f:
        f.write(json.dumps(summary_json, indent=2))

    # optimize placement thresholds per level (range) which give highest F2 scores on the valid. set
    thresholds = optimize_placement_thresholds(placement_clstm, valid_iter, device)

    with open(os.path.join(save_dir, THRESHOLDS_SAVE), 'w') as f:
        f.write(json.dumps(thresholds, indent=2))
Example #2
    def save(self, tokenizers, output_dirs):
        from train_util import save_model

        save_model(self.encoder, output_dirs.encoder)
        save_model(self.decoder, output_dirs.decoder)
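Below is a minimal usage sketch for this method; `SimpleNamespace`, the paths,
and the `model`/`tokenizers` names are illustrative assumptions, not part of
the original source:

from types import SimpleNamespace

# hypothetical: persist both halves of the encoder-decoder model
output_dirs = SimpleNamespace(encoder='checkpoints/encoder',
                              decoder='checkpoints/decoder')
model.save(tokenizers, output_dirs)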
Example #3
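# Imports assumed from the surrounding module (not shown in this excerpt);
# `baseline`, the run_*_batch/evaluate/save_model helpers, and constants such
# as N_MELS, N_FFTS, MODEL_SAVENAMES, and the criteria come from the project.
import json
from collections import Counter

import torch
import torch.optim as optim
from tqdm import tqdm, trange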
def run_training(train_iter,
                 valid_iter,
                 test_iter,
                 chart_type,
                 save_dir,
                 placement_modelname,
                 selection_modelname,
                 device,
                 num_epochs=15):
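    """Train the baseline placement (logreg/MLP) and selection (n-gram) models,
    validating each epoch and stopping early when precision/loss stop improving."""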
    # flattened audio feats as input
    placement_input_size = N_MELS * len(N_FFTS)
    if placement_modelname == 'logreg':
        placement_model = baseline.PlacementLogReg(
            input_size=placement_input_size, output_size=2)
    elif placement_modelname == 'mlp':
        placement_model = baseline.PlacementMLP(
            input_size=placement_input_size, output_size=2)
    else:
        raise ValueError(f'unknown placement model: {placement_modelname}')

    if selection_modelname == 'ngram':
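        # build the count-based n-gram baseline by counting note n-grams
        # across every chart in the training set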
        ngram_counts = Counter()
        for json_fp, _, _, _ in train_iter.dataset.chart_ids:
            with open(json_fp, 'r') as f:
                attrs = json.loads(f.read())
            for chart_attrs in attrs['charts']:
                notes = [note for _, _, _, note in chart_attrs['notes']]
                for ngram in get_ngrams(notes):
                    ngram_counts[ngram] += 1
        selection_model = baseline.SelectionNGram(ngram_counts)
    elif selection_modelname == 'ngrammlp':
        selection_model = baseline.SelectionNGramMLP(
            SELECTION_VOCAB_SIZES[chart_type],
            SELECTION_INPUT_SIZES[chart_type])
    else:
        raise ValueError(f'unknown selection model: {selection_modelname}')

    placement_optim = optim.SGD(placement_model.parameters(), lr=PLACEMENT_LR)
    selection_optim = optim.SGD(selection_model.parameters(), lr=SELECTION_LR)

    placement_save = MODEL_SAVENAMES[placement_modelname]
    selection_save = MODEL_SAVENAMES[selection_modelname]

    best_placement_valid_loss = float('inf')
    best_placement_precision = 0
    best_selection_valid_loss = float('inf')
    train_placement, train_selection = True, True

    print('Starting training..')
    for epoch in trange(num_epochs):
        print('Epoch: {}'.format(epoch))
        epoch_p_loss = 0
        epoch_p_precision = 0
        epoch_s_loss = 0

        for i, batch in enumerate(tqdm(train_iter)):
            with torch.set_grad_enabled(train_placement):
                (placement_loss, placement_acc,
                 placement_precision) = run_placement_batch(
                     placement_model,
                     placement_optim,
                     PLACEMENT_CRITERION,
                     batch,
                     device,
                     do_train=train_placement)

            with torch.set_grad_enabled(train_selection):
                (selection_loss,
                 selection_acc) = run_selection_batch(selection_model,
                                                      selection_optim,
                                                      SELECTION_CRITERION,
                                                      batch,
                                                      device,
                                                      do_train=train_selection)

            epoch_p_loss += placement_loss
            epoch_p_precision += placement_precision
            epoch_s_loss += selection_loss

            if train_placement:
                save_model(placement_model, save_dir, placement_save)
            if train_selection:
                save_model(selection_model, save_dir, selection_save)

        epoch_p_loss = epoch_p_loss / len(train_iter)
        epoch_p_precision = epoch_p_precision / len(train_iter)
        epoch_s_loss = epoch_s_loss / len(train_iter)

        print(
            f'\tAvg. training placement loss per unrolling: {epoch_p_loss:.5f}'
        )
        print(f'\tAvg. training placement precision: {epoch_p_precision:.5f}')
        print(f'\tAvg. training selection loss per frame: {epoch_s_loss:.5f}')

        (placement_valid_loss, placement_valid_acc, selection_valid_loss,
         selection_valid_acc,
         placement_precision) = evaluate(placement_model, selection_model,
                                         valid_iter, PLACEMENT_CRITERION,
                                         SELECTION_CRITERION, device)

        print(
            f'\tAvg. validation placement loss per frame: {placement_valid_loss:.5f}'
        )
        print(
            f'\tAvg. validation placement precision: {placement_precision:.5f}')
        print(
            f'\tAvg. validation selection loss per frame: {selection_valid_loss:.5f}'
        )

        better_placement = placement_precision > best_placement_precision
        better_selection = selection_valid_loss < best_selection_valid_loss

        if train_placement:
            if better_placement:
                best_placement_precision = placement_precision
                best_placement_valid_loss = placement_valid_loss
                save_model(placement_model, save_dir, placement_save)
            else:
                print(
                    "Placement validation precision did not improve, stopping placement model training"
                )
                train_placement = False

        if train_selection:
            if better_selection:
                best_selection_valid_loss = selection_valid_loss
                save_model(selection_model, save_dir, selection_save)
            else:
                print(
                    "Selection validation loss did not improve, stopping selection model training"
                )
                train_selection = False

        if not train_placement and not train_selection:
            print("Both early stopping criterion met. Stopping early..")
            break