Example #1
def main(parser, fast_dev_run) -> None:
    args = parser.parse_args()
    set_seed(args.seed)
    args.savedir = os.path.join(args.savedir, args.name)
    os.makedirs(args.savedir, exist_ok=True)

    model = get_model(args)

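    # stop early when the monitored validation metric fails to improve for `patience` epochs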
    early_stop_callback = EarlyStopping(
        monitor=args.monitor,
        min_delta=0.0,
        patience=args.patience,
        verbose=True,
        mode=args.metric_mode,
    )

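    # configure the Trainer: logging, early stopping, mixed precision, gradient accumulation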
    trainer = Trainer(logger=setup_testube_logger(args),
                      checkpoint_callback=True,
                      early_stop_callback=early_stop_callback,
                      default_root_dir=args.savedir,
                      gpus=args.gpus,
                      distributed_backend=args.distributed_backend,
                      precision=args.precision,
                      amp_level=args.amp_level,
                      max_epochs=args.max_epochs,
                      min_epochs=args.min_epochs,
                      accumulate_grad_batches=args.accumulate_grad_batches,
                      val_percent_check=args.val_percent_check,
                      fast_dev_run=fast_dev_run,
                      num_sanity_val_steps=0)

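    # checkpoints are written under <savedir>/<logger name>/version_<n>/checkpoints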
    ckpt_path = os.path.join(
        trainer.default_root_dir,
        trainer.logger.name,
        f"version_{trainer.logger.version}",
        "checkpoints",
    )
    # initialize Model Checkpoint Saver
    checkpoint_callback = ModelCheckpoint(
        filepath=ckpt_path,
        save_top_k=args.save_top_k,
        verbose=True,
        monitor=args.monitor,
        period=1,
        mode=args.metric_mode,
    )
    trainer.checkpoint_callback = checkpoint_callback

    trainer.fit(model)
Example #2
def train(args):

    set_seed(args.seed)
    args.savedir = os.path.join(args.savedir, args.name)
    os.makedirs(args.savedir, exist_ok=True)

    train_loader, val_loader, test_loaders = get_data_loaders(args)

    model = get_model(args)
    criterion = get_criterion(args)
    optimizer = get_optimizer(model, args)
    scheduler = get_scheduler(optimizer, args)

    logger = create_logger("%s/logfile.log" % args.savedir, args)
    logger.info(model)
    model.cuda()

    torch.save(args, os.path.join(args.savedir, "args.pt"))

    start_epoch, global_step, n_no_improve, best_metric = 0, 0, 0, -np.inf

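    # resume from the latest checkpoint if one exists in the save directory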
    if os.path.exists(os.path.join(args.savedir, "checkpoint.pt")):
        checkpoint = torch.load(os.path.join(args.savedir, "checkpoint.pt"))
        start_epoch = checkpoint["epoch"]
        n_no_improve = checkpoint["n_no_improve"]
        best_metric = checkpoint["best_metric"]
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])

    logger.info("Training..")
    for i_epoch in range(start_epoch, args.max_epochs):
        train_losses = []
        model.train()
        optimizer.zero_grad()

        for batch in tqdm(train_loader, total=len(train_loader)):
            loss, _, _ = model_forward(i_epoch, model, args, criterion, batch)
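            # scale the loss so gradients accumulated over several batches average correctly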
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            train_losses.append(loss.item())
            loss.backward()
            global_step += 1
            if global_step % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

        model.eval()
        metrics = model_eval(i_epoch, val_loader, model, args, criterion)
        logger.info("Train Loss: {:.4f}".format(np.mean(train_losses)))
        log_metrics("Val", metrics, args, logger)

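        # tune on micro-F1 for multilabel tasks, accuracy otherwise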
        tuning_metric = (metrics["micro_f1"]
                         if args.task_type == "multilabel" else metrics["acc"])
        scheduler.step(tuning_metric)
        is_improvement = tuning_metric > best_metric
        if is_improvement:
            best_metric = tuning_metric
            n_no_improve = 0
        else:
            n_no_improve += 1

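        # save the latest state; the best-so-far copy (model_best.pt) is reloaded for testing below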
        save_checkpoint(
            {
                "epoch": i_epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "n_no_improve": n_no_improve,
                "best_metric": best_metric,
            },
            is_improvement,
            args.savedir,
        )

        if n_no_improve >= args.patience:
            logger.info("No improvement. Breaking out of loop.")
            break

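    # reload the best checkpoint and evaluate on every test split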
    load_checkpoint(model, os.path.join(args.savedir, "model_best.pt"))
    model.eval()
    for test_name, test_loader in test_loaders.items():
        model_predict(np.inf, test_loader, model, args, criterion)
        test_metrics = model_eval(np.inf,
                                  test_loader,
                                  model,
                                  args,
                                  criterion,
                                  store_preds=True)
        log_metrics(f"Test - {test_name}", test_metrics, args, logger)
Example #3
def train(args):

    set_seed(args.seed)
    args.savedir = os.path.join(args.savedir, args.name)
    os.makedirs(args.savedir, exist_ok=True)

    train_loader, val_loader, test_loaders = get_data_loaders(args)

    model = get_model(args)
    criterion = get_criterion(args)
    optimizer = get_optimizer(model, args)
    scheduler = get_scheduler(optimizer, args)

    logger = create_logger("%s/logfile.log" % args.savedir, args)
    logger.info(model)
    model.cuda()
    # if args.multiGPU:
    #     model = nn.DataParallel(model)
    torch.save(args, os.path.join(args.savedir, "args.pt"))

    start_epoch, global_step, n_no_improve, best_metric = 0, 0, 0, -np.inf

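    # resume from the latest checkpoint if one exists in the save directory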
    if os.path.exists(os.path.join(args.savedir, "checkpoint.pt")):
        checkpoint = torch.load(os.path.join(args.savedir, "checkpoint.pt"))
        start_epoch = checkpoint["epoch"]
        n_no_improve = checkpoint["n_no_improve"]
        best_metric = checkpoint["best_metric"]
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])

    logger.info("Training..")
    if APEX_AVAILABLE and args.fp16:
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level="O2",
            keep_batchnorm_fp32=True,
            loss_scale="dynamic",
        )
    for i_epoch in range(start_epoch, args.max_epochs):
        train_losses = []
        model.train()
        optimizer.zero_grad()
        logger.info(f"total data:{len(train_loader)}")
        train_batch_start = time.time()
        for batch in tqdm(train_loader, total=(len(train_loader))):
            loss, _, _ = model_forward(i_epoch, model, args, criterion, batch)
            if args.multiGPU:
                loss = loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            train_losses.append(loss.item())
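            # with apex, backpropagate through the scaled loss to avoid fp16 underflow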
            if APEX_AVAILABLE and args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            global_step += 1
            if global_step % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
        train_batch_end = time.time()
        logger.info(
            f"EPOCH: {i_epoch}, Train Loss: {np.mean(train_losses):.4f}, time: {(train_batch_end - train_batch_start)/60:.1f}mins"
        )
        eval_start = time.time()
        model.eval()
        metrics = model_eval(i_epoch, val_loader, model, args, criterion)
        eval_end = time.time()
        log_metrics("Val", metrics, args, logger)
        if args.task_type == "multilabel":
            tuning_metric = metrics["micro_f1"]
            logger.info(
                f"Val micro_f1 {tuning_metric:.3f}, time: {(eval_end - eval_start)/60:.2f}mins"
            )
        else:
            tuning_metric = metrics["acc"]
            logger.info(
                f'Val acc {metrics["acc"]:.3f}, precision {metrics["prec"]:.3f}, recall {metrics["recall"]:.3f}, time: {(eval_end - eval_start)/60:.2f}mins'
            )

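        # step the scheduler on the validation metric and track epochs without improvement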
        scheduler.step(tuning_metric)
        is_improvement = tuning_metric > best_metric
        if is_improvement:
            best_metric = tuning_metric
            n_no_improve = 0
        else:
            n_no_improve += 1

        save_checkpoint(
            {
                "epoch": i_epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "n_no_improve": n_no_improve,
                "best_metric": best_metric,
            },
            is_improvement,
            args.savedir,
        )
        if not args.no_cuda:
            torch.cuda.empty_cache()
        if n_no_improve >= args.patience:
            logger.info("No improvement. Breaking out of loop.")
            break

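    # reload the best checkpoint and evaluate on every test split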
    load_checkpoint(model, os.path.join(args.savedir, "model_best.pt"))
    model.eval()
    for test_name, test_loader in test_loaders.items():
        test_metrics = model_eval(np.inf,
                                  test_loader,
                                  model,
                                  args,
                                  criterion,
                                  store_preds=True)
        log_metrics(f"Test - {test_name}", test_metrics, args, logger)