Example #1
0
def visualize(vectorizer, data=None):
    """Decode a single validation batch and write an attention visualization.

    Args:
        vectorizer: project vectorizer used to encode/decode text.
        data: optional pre-loaded validation data; when None it is read from
            ``p.VAL_FILE``.
    """
    if data is None:
        data = get_data(p.VAL_FILE,
                        vectorizer,
                        with_oov=p.POINTER_GEN,
                        aspect_file=p.ASPECT_FILE)
    model = new_model(vectorizer, data.dataset.aspects).eval()
    # NOTE(review): weights are unpickled from p.MODEL_FILE — only load
    # checkpoints from a trusted source.
    with open(p.MODEL_FILE, 'rb') as modelfile:
        model.load_state_dict(pkl.load(modelfile))
    # Fixed window starting at example 271, one decoding batch wide.
    batch = data[271:271 + p.DECODING_BATCH_SIZE]
    aspect_results = summarize(batch, model, beam_size=p.BEAM_SIZE)
    # Keep only the top-ranked result per aspect.
    top_results = [result[0] for result in aspect_results]
    print_batch(batch, top_results, vectorizer, model.aspects)
    vis(p.VISUALIZATION_FILE,
        batch,
        top_results,
        vectorizer,
        model.aspects,
        0,
        0,
        pointer_gen=p.POINTER_GEN)
Example #2
0
def evaluate(vectorizer, data=None):
    """Produce summaries for the validation set and score them with ROUGE.

    Args:
        vectorizer: project vectorizer used to encode/decode text.
        data: optional pre-loaded validation data; when None it is read from
            ``p.VAL_FILE``.
    """
    if data is None:
        data = get_data(p.VAL_FILE,
                        vectorizer,
                        with_oov=p.POINTER_GEN,
                        aspect_file=p.ASPECT_FILE)
    model = new_model(vectorizer, data.dataset.aspects).eval()
    # NOTE(review): weights are unpickled from p.MODEL_FILE — only load
    # checkpoints from a trusted source.
    with open(p.MODEL_FILE, 'rb') as modelfile:
        model.load_state_dict(pkl.load(modelfile))
    text_path = p.TEXT_PATH
    # Write 'system' and 'reference' summary files under text_path;
    # max_num_batch=None decodes the full dataset.
    produce_summary_files(data,
                          p.DECODING_BATCH_SIZE,
                          vectorizer,
                          model,
                          text_path,
                          beam_size=p.BEAM_SIZE,
                          max_num_batch=None)
    run_rouge_1(os.path.join(text_path, 'system'),
                os.path.join(text_path, 'reference'),
                save_to=os.path.join(text_path, 'rouge_scores.txt'))
Example #3
0
def train(vectorizer, data=None, val=None):
    """Train the aspect summarizer, optionally resuming from a checkpoint.

    Args:
        vectorizer: project vectorizer used to encode text.
        data: optional pre-loaded training data; read from ``p.DATA_FILE``
            when None.
        val: optional pre-loaded validation data; read from ``p.VAL_FILE``
            when None.

    Side effects: may flip ``p.CONTINUE_FROM_CHECKPOINT`` to False when the
    checkpoint directory is incomplete, writes checkpoints/plots under
    ``p.CHECKPOINT_PATH`` (when set), and prints progress messages.
    """
    if data is None:
        data = get_data(p.DATA_FILE,
                        vectorizer,
                        with_oov=p.POINTER_GEN,
                        aspect_file=p.ASPECT_FILE)
    if val is None:
        val = get_data(p.VAL_FILE,
                       vectorizer,
                       with_oov=p.POINTER_GEN,
                       aspect_file=p.ASPECT_FILE)

    if p.CONTINUE_FROM_CHECKPOINT:
        if not TrainingTracker.valid_checkpoint(p.CHECKPOINT_PATH):
            # The checkpoint directory is missing required files: fall back
            # to a fresh run and re-record the training parameters.
            print(
                "Cannot continue from checkpoint in \"" + p.CHECKPOINT_PATH +
                "\" because not all of the proper files exist; restarting training."
            )
            p.CONTINUE_FROM_CHECKPOINT = False
            print("Saving parameters to file again.")
            p.save_params(
                os.path.join(p.CHECKPOINT_PATH, 'train_param_info.txt'))
        else:
            print("Loading from the last checkpoint in \"" +
                  p.CHECKPOINT_PATH +
                  "\"; model parameters must match the saved model state.")
            if p.NEW_EPOCH:
                print("Starting from a new epoch.")
            else:
                print(
                    "Continuing from the same place in the epoch; this expects the same datafile."
                )

    model = new_model(vectorizer, data.dataset.aspects).train()
    # Re-read the flag here: it may have been cleared just above.
    if p.CONTINUE_FROM_CHECKPOINT:
        TrainingTracker.load_model_state_(model, p.CHECKPOINT_PATH)

    optimizer = torch.optim.Adagrad(
        model.parameters(),
        lr=p.LEARNING_RATE,
        initial_accumulator_value=p.INITIAL_ACCUMULATOR_VALUE,
    )
    if p.CONTINUE_FROM_CHECKPOINT:
        TrainingTracker.load_optimizer_state(optimizer, p.CHECKPOINT_PATH)

    model_manip = ModelManipulator(model,
                                   optimizer,
                                   aspect_summarizer_loss,
                                   aspect_summarizer_error,
                                   grad_mod=clip_grad_norm,
                                   no_nan_grad=True)
    # Anomaly detection is opt-in via params; it slows training noticeably.
    with torch.autograd.set_detect_anomaly(p.DETECT_ANOMALY):
        train_stats, val_stats = model_manip.train(
            data,
            p.BATCH_SIZE,
            p.NUM_EPOCHS,
            dataset_val=val,
            stats_every=10,
            verbose_every=10,
            checkpoint_every=10,
            checkpoint_path=p.CHECKPOINT_PATH,
            restart=not p.CONTINUE_FROM_CHECKPOINT,
            new_epoch=p.NEW_EPOCH,
            max_steps=p.MAX_TRAINING_STEPS,
            save_whole_model=False)

    # Plot from the on-disk checkpoint history when one exists; otherwise
    # plot the in-memory stats returned by training.
    if p.CHECKPOINT_PATH is not None:
        plot_checkpoint(p.CHECKPOINT_PATH,
                        figure_name='plot',
                        show=False,
                        average_over=p.AVERAGE_OVER)
    else:
        plot_learning_curves(training_values=train_stats,
                             validation_values=val_stats,
                             figure_name=os.path.join(p.TRAINING_PLOTS_PATH,
                                                      'plot'),
                             show=False,
                             average_over=p.AVERAGE_OVER)
Example #4
0
         with_coverage=True),
    dict(max_training_steps=233000,
         max_text_length=400,
         max_summary_length=100,
         with_coverage=True),
]

if __name__ == '__main__':
    checkpoint_path = os.path.join(CHECKPOINTS_FOLDER, 'checkpoint')
    vectorizer = setup(checkpoint_path=checkpoint_path,
                       device=DEVICE,
                       pointer_gen=POINTER_GEN,
                       use_transformer=USE_TRANSFORMER,
                       mode='train')
    data = get_data(p.DATA_FILE,
                    vectorizer,
                    with_oov=p.POINTER_GEN,
                    aspect_file=p.ASPECT_FILE)
    val = get_data(p.VAL_FILE,
                   vectorizer,
                   with_oov=p.POINTER_GEN,
                   aspect_file=p.ASPECT_FILE)
    for i, params in enumerate(sections):
        print(('starting section %i:\n' % (i + 1)) + str(params))
        set_params(**params, continue_from_checkpoint=True)
        try:
            train(vectorizer, data=data, val=val)
            subprocess.run([
                'cp', '-r', checkpoint_path,
                os.path.join(CHECKPOINTS_FOLDER, 'checkpoint%i' % (i + 1))
            ])
        except StopEarlyWithoutSavingException: