Example 1
    saver = ModelSaver(model_path, init_val=0)
    # load_ckpt restores model/optimizer state and returns the epoch to resume from
    offset_ep = saver.load_ckpt(model, optimizer, device)
    if offset_ep > CONFIG.hyperparam.misc.max_epoch:
        raise RuntimeError(
            "trying to restart at epoch {} while max training is set to "
            "{} epochs".format(offset_ep, CONFIG.hyperparam.misc.max_epoch))
    ########################################################

    # wrap the model for multi-GPU training when more than one GPU is available
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    if CONFIG.use_wandb:
        wandb.watch(model)

    ################### training loop #####################
    for ep in range(offset_ep - 1, CONFIG.hyperparam.misc.max_epoch):
        print("global {} | begin training for epoch {}".format(
            global_timer, ep + 1))
        train_epoch(train_loader, model, optimizer, device, ep, CONFIG)
        print(
            "global {} | done with training for epoch {}, beginning validation"
            .format(global_timer, ep + 1))
        metrics = validate(val_loader, model, tokenizer, evaluator, device,
                           CONFIG)
        if "METEOR" in metrics.keys():
            saver.save_ckpt_if_best(model, optimizer, metrics["METEOR"])
        print("global {} | end epoch {}".format(global_timer, ep + 1))
    print("done training!")
    #######################################################
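
The snippet above leans on a ModelSaver helper whose implementation is not shown. Below is a minimal sketch, assuming a single checkpoint file and the interface implied by the calls: load_ckpt restores model and optimizer state and returns the epoch to resume from, and save_ckpt_if_best writes a checkpoint only when the tracked validation metric improves. The checkpoint keys, the epoch argument of save_ckpt_if_best, and the file-existence check are illustrative assumptions, not the original implementation.

import os

import torch


class ModelSaver:
    """Hypothetical checkpoint helper matching the calls used above."""

    def __init__(self, model_path, init_val=0):
        self.path = model_path
        self.best_val = init_val  # best validation metric seen so far

    def load_ckpt(self, model, optimizer, device):
        # Nothing saved yet: start training from the first epoch.
        if not os.path.exists(self.path):
            return 1
        ckpt = torch.load(self.path, map_location=device)
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        self.best_val = ckpt.get("best_val", self.best_val)
        # Resume from the epoch after the last checkpointed one.
        return ckpt.get("epoch", 0) + 1

    def save_ckpt_if_best(self, model, optimizer, metric, epoch=0):
        # Overwrite the checkpoint only when the metric improves.
        if metric <= self.best_val:
            return False
        self.best_val = metric
        # Unwrap nn.DataParallel so keys are saved without a "module." prefix.
        to_save = model.module if hasattr(model, "module") else model
        torch.save({
            "model": to_save.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "best_val": self.best_val,
        }, self.path)
        return True

Note that the training loops here call save_ckpt_if_best without an epoch argument, so a real resume would also need the current epoch to be passed in or tracked by the saver itself.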
Example 2
        offset_ep = saver.load_ckpt(model, optimizer, device)
        if offset_ep > CONFIG.max_epoch:
            logging.error("trying to restart at epoch {} while max training is set to {} epochs".format(offset_ep, CONFIG.max_epoch))
            sys.exit(1)
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.padidx)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    logging.debug("done!")

    logging.debug("loading evaluator...")
    #evaluator = BleuComputer()
    evaluator = NLGEval(metrics_to_omit=["METEOR"]) # meteor has problems, so omit
    logging.debug("done!")

    for ep in range(offset_ep-1, CONFIG.max_epoch):
        logging.info("global {} | begin training for epoch {}".format(global_timer, ep+1))
        train_epoch(train_loader, model, optimizer, criterion, device, tb_logger, ep)
        logging.info("global {} | done with training for epoch {}, beginning validation".format(global_timer, ep+1))
        metrics = validate(val_loader, model, tokenizer, evaluator, device)
        for key, val in metrics.items():
            tb_logger.add_scalar("metrics/{}".format(key), val, ep+1)
        if "Bleu_4" in metrics.keys():
            saver.save_ckpt_if_best(model, optimizer, metrics["Bleu_4"])
        logging.info("global {} | end epoch {}".format(global_timer, ep+1))
    logging.info("done training!!")