Code example #1
File: test_runner.py  Project: Randl/pytorch-tools
def test_Timer_callback():
    runner = Runner(
        model=TestModel,
        optimizer=TestOptimizer,
        criterion=TestCriterion,
        metrics=TestMetric,
        callbacks=pt_clb.Timer(),
    )
    runner.fit(TestLoader, epochs=2)
Code example #2
File: test_runner.py  Project: bonlime/pytorch-tools
def test_rank_zero_only():
    """check that decorator disables come callbacks"""
    os.environ["RANK"] = "0"
    # check that wrapping an instance works
    timer = pt_clb.rank_zero_only(pt_clb.Timer())
    assert hasattr(timer, "timer")

    os.environ["RANK"] = "1"
    # check that wrapping a class also works
    timer = pt_clb.rank_zero_only(pt_clb.Timer)()
    assert not hasattr(timer, "timer")
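
The test above relies on `pt_clb.rank_zero_only` returning the real callback only when the `RANK` environment variable is `"0"`. A minimal sketch of that idea, assuming a no-op stand-in object and generic hook names (both are illustrative, not the pytorch-tools internals):

import os


class _NoOpCallback:
    """Stand-in used on non-zero ranks; every hook does nothing (hook names assumed)."""

    def on_begin(self, *args, **kwargs): pass
    def on_epoch_begin(self, *args, **kwargs): pass
    def on_batch_end(self, *args, **kwargs): pass
    def on_end(self, *args, **kwargs): pass


def rank_zero_only_sketch(callback_or_cls):
    """Return the callback (or class) unchanged on rank 0, a no-op otherwise."""
    is_rank_zero = os.environ.get("RANK", "0") == "0"
    if isinstance(callback_or_cls, type):
        # Wrapping a class: hand back the class (or the stand-in class) to instantiate later
        return callback_or_cls if is_rank_zero else _NoOpCallback
    # Wrapping an instance: hand back the instance (or a stand-in instance) directly
    return callback_or_cls if is_rank_zero else _NoOpCallback()
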
Code example #3
        criterion=TEST_CRITERION,
        metrics=TEST_METRIC,
        gradient_clip_val=1.0
    )
    runner.fit(TEST_LOADER, epochs=2)


# We only test that callbacks don't crash, NOT that they do what they should do
TMP_PATH = "/tmp/pt_tools2/"
os.makedirs(TMP_PATH, exist_ok=True)


@pytest.mark.parametrize(
    "callback",
    [
        pt_clb.Timer(),
        pt_clb.ReduceLROnPlateau(),
        pt_clb.CheckpointSaver(TMP_PATH, save_name="model.chpn"),
        pt_clb.CheckpointSaver(
            TMP_PATH, save_name="model.chpn", monitor=TEST_METRIC.name, mode="max"
        ),
        pt_clb.TensorBoard(log_dir=TMP_PATH),
        pt_clb.TensorBoardWithCM(log_dir=TMP_PATH),
        pt_clb.ConsoleLogger(),
        pt_clb.FileLogger(TMP_PATH),
        pt_clb.Mixup(0.2, NUM_CLASSES),
        pt_clb.Cutmix(1.0, NUM_CLASSES),
        pt_clb.ScheduledDropout(),
    ],
)
def test_callback(callback):
Code example #4
def main():
    FLAGS = parse_args()
    print(FLAGS)
    pt.utils.misc.set_random_seed(42)  # fix all seeds
    ## dump config
    os.makedirs(FLAGS.outdir, exist_ok=True)
    yaml.dump(vars(FLAGS), open(FLAGS.outdir + '/config.yaml', 'w'))

    ## get dataloaders
    if FLAGS.train_tta:
        FLAGS.bs //= 4  # account for later augmentations to avoid OOM
    train_dtld, val_dtld = get_dataloaders(FLAGS.datasets, FLAGS.augmentation,
                                           FLAGS.bs, FLAGS.size,
                                           FLAGS.val_size,
                                           FLAGS.buildings_only)

    ## get model and optimizer
    model = MODEL_FROM_NAME[FLAGS.segm_arch](FLAGS.arch,
                                             **FLAGS.model_params).cuda()
    if FLAGS.train_tta:
        # idea from https://arxiv.org/pdf/2002.09024.pdf paper
        model = pt.tta_wrapper.TTA(model,
                                   segm=True,
                                   h_flip=True,
                                   rotation=[90],
                                   merge="max")
        model.encoder = model.model.encoder
        model.decoder = model.model.decoder
    optimizer = optimizer_from_name(FLAGS.optim)(
        model.parameters(),
        lr=FLAGS.lr,
        weight_decay=FLAGS.weight_decay,  # **FLAGS.optim_params TODO: add additional optim params if needed
    )
    if FLAGS.lookahead:
        optimizer = pt.optim.Lookahead(optimizer)

    if FLAGS.resume:
        checkpoint = torch.load(
            FLAGS.resume, map_location=lambda storage, loc: storage.cuda())
        model.load_state_dict(checkpoint["state_dict"], strict=False)
    num_params = pt.utils.misc.count_parameters(model)[0]
    print(f"Number of parameters: {num_params / 1e6:.02f}M")

    ## train on fp16 by default
    model, optimizer = apex.amp.initialize(model,
                                           optimizer,
                                           opt_level="O1",
                                           verbosity=0,
                                           loss_scale=1024)

    ## get loss. fixed for now.
    bce_loss = TargetWrapper(
        pt.losses.CrossEntropyLoss(mode="binary").cuda(), "mask")
    bce_loss.name = "BCE"
    loss = criterion_from_list(FLAGS.criterion).cuda()
    # loss = 0.5 * pt.losses.CrossEntropyLoss(mode="binary", weight=[5]).cuda()
    print("Loss for this run is: ", loss)
    ## get runner
    sheduler = pt.fit_wrapper.callbacks.PhasesScheduler(FLAGS.phases)
    runner = pt.fit_wrapper.Runner(
        model,
        optimizer,
        criterion=loss,
        callbacks=[
            pt_clb.Timer(),
            pt_clb.ConsoleLogger(),
            pt_clb.FileLogger(FLAGS.outdir),
            pt_clb.SegmCutmix(1, 1) if FLAGS.cutmix else NoClb(),
            pt_clb.CheckpointSaver(FLAGS.outdir, save_name="model.chpn"),
            sheduler,
            PredictViewer(FLAGS.outdir, num_images=8),
            ScheduledDropout(FLAGS.dropout, FLAGS.dropout_epochs)
            if FLAGS.dropout else NoClb()
        ],
        metrics=[
            bce_loss,
            TargetWrapper(
                pt.metrics.JaccardScore(mode="binary").cuda(), "mask"),
            TargetWrapper(ThrJaccardScore(thr=0.5), "mask"),
            TargetWrapper(BalancedAccuracy2(balanced=False), "mask"),
        ],
    )

    if FLAGS.decoder_warmup_epochs > 0:
        ## freeze encoder
        for p in model.encoder.parameters():
            p.requires_grad = False
        runner.fit(
            train_dtld,
            val_loader=val_dtld,
            epochs=FLAGS.decoder_warmup_epochs,
            steps_per_epoch=50 if FLAGS.short_epoch else None,
            val_steps=50 if FLAGS.short_epoch else None,
        )

        ## unfreeze all
        for p in model.parameters():
            p.requires_grad = True

        # need to init again to avoid nan's in loss
        optimizer = optimizer_from_name(FLAGS.optim)(
            model.parameters(),
            lr=FLAGS.lr,
            weight_decay=FLAGS.weight_decay,  # **FLAGS.optim_params TODO: add additional optim params if needed
        )
        model, optimizer = apex.amp.initialize(model,
                                               optimizer,
                                               opt_level="O1",
                                               verbosity=0,
                                               loss_scale=2048)
        runner.state.model = model
        runner.state.optimizer = optimizer

    runner.fit(
        train_dtld,
        val_loader=val_dtld,
        start_epoch=FLAGS.decoder_warmup_epochs,
        epochs=sheduler.tot_epochs,
        steps_per_epoch=50 if FLAGS.short_epoch else None,
        val_steps=50 if FLAGS.short_epoch else None,
    )
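
The `pt.tta_wrapper.TTA` call in this example merges the model's predictions over flipped and rotated views of the input. A minimal sketch of the flip-only version of that idea for a segmentation model (the `HFlipTTA` class below is hypothetical, an illustration rather than the pytorch-tools wrapper):

import torch
import torch.nn as nn


class HFlipTTA(nn.Module):
    """Average or max-merge predictions over the original and the horizontally
    flipped input (segmentation case, so the flipped prediction is flipped back)."""

    def __init__(self, model, merge="max"):
        super().__init__()
        self.model = model
        self.merge = merge

    def forward(self, x):
        out = self.model(x)
        flipped = self.model(torch.flip(x, dims=[-1]))  # flip along width
        flipped = torch.flip(flipped, dims=[-1])        # undo the flip on the prediction
        if self.merge == "max":
            return torch.max(out, flipped)
        return (out + flipped) / 2
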
Code example #5
def main():
    # Get config for this run
    hparams = parse_args()

    # Setup logger
    config = {
        "handlers": [
            {
                "sink": sys.stdout,
                "format": "{time:[MM-DD HH:mm]} - {message}"
            },
            {
                "sink": f"{hparams.outdir}/logs.txt",
                "format": "{time:[MM-DD HH:mm]} - {message}"
            },
        ],
    }
    logger.configure(**config)
    logger.info(f"Parameters used for training: {hparams}")

    # Fix seeds for reproducibility
    pt.utils.misc.set_random_seed(hparams.seed)

    # Save config
    os.makedirs(hparams.outdir, exist_ok=True)
    yaml.dump(vars(hparams), open(hparams.outdir + "/config.yaml", "w"))

    # Get model
    model = Model(arch=hparams.arch,
                  model_params=hparams.model_params,
                  embedding_size=hparams.embedding_size,
                  pooling=hparams.pooling).cuda()

    # Get loss
    # loss = LOSS_FROM_NAME[hparams.criterion](in_features=hparams.embedding_size, **hparams.criterion_params).cuda()
    loss = LOSS_FROM_NAME["cross_entropy"].cuda()
    logger.info(f"Loss for this run is: {loss}")

    if hparams.resume:
        checkpoint = torch.load(
            hparams.resume, map_location=lambda storage, loc: storage.cuda())
        model.load_state_dict(checkpoint["state_dict"], strict=True)
        loss.load_state_dict(checkpoint["loss"], strict=True)

    if hparams.freeze_bn:
        freeze_batch_norm(model)

    # Get optimizer
    # optim_params = pt.utils.misc.filter_bn_from_wd(model)
    optim_params = list(loss.parameters()) + list(model.parameters())  # add loss params
    optimizer = optimizer_from_name(hparams.optim)(
        optim_params, lr=0, weight_decay=hparams.weight_decay, amsgrad=True)

    num_params = pt.utils.misc.count_parameters(model)[0]
    logger.info(f"Model size: {num_params / 1e6:.02f}M")
    # logger.info(model)

    # Scheduler is an advanced way of planning experiment
    sheduler = pt.fit_wrapper.callbacks.PhasesScheduler(hparams.phases)

    # Save logs
    TB_callback = pt_clb.TensorBoard(hparams.outdir, log_every=20)

    # Get dataloaders
    train_loader, val_loader, val_indexes = get_dataloaders(
        root=hparams.root,
        augmentation=hparams.augmentation,
        size=hparams.size,
        val_size=hparams.val_size,
        batch_size=hparams.batch_size,
        workers=hparams.workers,
    )

    # Load validation query / gallery split and resort it according to indexes from sampler
    df_val = pd.read_csv(os.path.join(hparams.root, "train_val.csv"))
    df_val = df_val[df_val["is_train"].astype(bool) == False]
    val_is_query = df_val.is_query.values[val_indexes].astype(bool)

    logger.info(f"Start training")
    # Init runner
    runner = pt.fit_wrapper.Runner(
        model,
        optimizer,
        criterion=loss,
        callbacks=[
            # pt_clb.BatchMetrics([pt.metrics.Accuracy(topk=1)]),
            ContestMetricsCallback(
                is_query=val_is_query[:1280] if hparams.debug else val_is_query
            ),
            pt_clb.Timer(),
            pt_clb.ConsoleLogger(),
            pt_clb.FileLogger(),
            TB_callback,
            CheckpointSaver(hparams.outdir,
                            save_name="model.chpn",
                            monitor="target",
                            mode="max"),
            CheckpointSaver(hparams.outdir,
                            save_name="model_mapr.chpn",
                            monitor="mAP@R",
                            mode="max"),
            CheckpointSaver(hparams.outdir, save_name="model_loss.chpn"),
            sheduler,
            # EMA must go after other checkpoints
            pt_clb.ModelEma(model, hparams.ema_decay)
            if hparams.ema_decay else pt_clb.Callback(),
        ],
        use_fp16=hparams.use_fp16,  # use mixed precision by default  # hparams.opt_level != "O0"
    )

    if hparams.head_warmup_epochs > 0:
        #Freeze model
        for p in model.parameters():
            p.requires_grad = False

        runner.fit(
            train_loader,
            # val_loader=val_loader,
            epochs=hparams.head_warmup_epochs,
            steps_per_epoch=20 if hparams.debug else None,
            # val_steps=20 if hparams.debug else None,
        )

        # Unfreeze model
        for p in model.parameters():
            p.requires_grad = True

        if hparams.freeze_bn:
            freeze_batch_norm(model)

        # Re-init to avoid nan's in loss
        optim_params = list(loss.parameters()) + list(model.parameters())

        optimizer = optimizer_from_name(hparams.optim)(
            optim_params,
            lr=0,
            weight_decay=hparams.weight_decay,
            amsgrad=True)

        runner.state.model = model
        runner.state.optimizer = optimizer
        runner.state.criterion = loss

    # Train
    runner.fit(
        train_loader,
        # val_loader=val_loader,
        start_epoch=hparams.head_warmup_epochs,
        epochs=sheduler.tot_epochs,
        steps_per_epoch=20 if hparams.debug else None,
        # val_steps=20 if hparams.debug else None,
    )

    logger.info(f"Loading best model")
    checkpoint = torch.load(os.path.join(hparams.outdir, f"model.chpn"))
    model.load_state_dict(checkpoint["state_dict"], strict=True)
    # runner.state.model = model
    # loss.load_state_dict(checkpoint["loss"], strict=True)

    # Evaluate
    _, [acc1, map10, target, mapR] = runner.evaluate(
        val_loader,
        steps=20 if hparams.debug else None,
    )

    logger.info(
        f"Val: Acc@1 {acc1:0.5f}, mAP@10 {map10:0.5f}, Target {target:0.5f}, mAP@R {mapR:0.5f}"
    )

    # Save params used for training and final metrics into separate TensorBoard file
    metric_dict = {
        "hparam/Acc@1": acc1,
        "hparam/mAP@10": map10,
        "hparam/Target": target,
        "hparam/mAP@R": mapR,
    }

    # Convert all lists / dicts to str to avoid a TensorBoard error
    hparams.phases = str(hparams.phases)
    hparams.model_params = str(hparams.model_params)
    hparams.criterion_params = str(hparams.criterion_params)

    with pt.utils.tensorboard.CorrectedSummaryWriter(hparams.outdir) as writer:
        writer.add_hparams(hparam_dict=vars(hparams), metric_dict=metric_dict)
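
The last few lines above stringify `phases`, `model_params` and `criterion_params` because TensorBoard's `add_hparams` only accepts scalar values. A small self-contained sketch of that pattern using the stock `torch.utils.tensorboard.SummaryWriter` (the project's `CorrectedSummaryWriter` is assumed to behave similarly):

from torch.utils.tensorboard import SummaryWriter


def log_hparams(hparams: dict, metrics: dict, logdir: str):
    """add_hparams only accepts int/float/str/bool values, so cast lists and
    dicts (e.g. phases, model_params) to str before logging."""
    safe_hparams = {
        k: v if isinstance(v, (int, float, str, bool)) else str(v)
        for k, v in hparams.items()
    }
    with SummaryWriter(logdir) as writer:
        writer.add_hparams(hparam_dict=safe_hparams, metric_dict=metrics)
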
Code example #6
File: train.py  Project: zakajd/metrics-comparison
def main():

    # Get config for this run
    hparams = parse_args()

    # Setup logger
    config = {
        "handlers": [
            {
                "sink": sys.stdout,
                "format": "{time:[MM-DD HH:mm]} - {message}"
            },
            {
                "sink": f"{hparams.outdir}/logs.txt",
                "format": "{time:[MM-DD HH:mm]} - {message}"
            },
        ],
    }
    logger.configure(**config)
    # Use print instead of logger to have alphabetic order.
    logger.info(f"Parameters used for training: {vars(hparams)}")

    # Fix all seeds for reproducibility
    pt.utils.misc.set_random_seed(hparams.seed)

    # Save config
    os.makedirs(hparams.outdir, exist_ok=True)
    yaml.dump(vars(hparams), open(hparams.outdir + "/config.yaml", "w"))

    # Get models and optimizers
    model = MODEL_FROM_NAME[hparams.model](**hparams.model_params).cuda()
    logger.info(
        f"Model size: {pt.utils.misc.count_parameters(model)[0] / 1e6:.02f}M")

    optimizer = torch.optim.Adam(model.parameters(),
                                 weight_decay=hparams.weight_decay,
                                 amsgrad=True)  # Get LR from phases later

    # Get loss
    loss = criterion_from_list(hparams.criterion).cuda()

    # Init per-image metrics and add names
    metrics = metrics_from_list(hparams.metrics, reduction='mean')
    logger.info(f"Metrics: {[m.name for m in metrics]}")

    # Init feature metrics and add names
    feature_metrics = []
    feature_extractor = "vgg16"
    for name in hparams.feature_metrics:
        metric = copy.copy(METRIC_FROM_NAME[name])
        metric.name = f"{name}_{feature_extractor}"
        feature_metrics.append(metric)

    # Scheduler is an advanced way of planning experiment
    sheduler = pt_clb.PhasesScheduler(hparams.phases)

    save_name = "model_{monitor}.chpn"
    # Init train loop
    runner = pt.fit_wrapper.Runner(
        model=model,
        optimizer=optimizer,
        criterion=loss,
        callbacks=[
            pt_clb.Timer(),
            clb.FeatureLoaderMetrics(metrics=feature_metrics,
                                     feature_extractor="vgg16"),
            pt_clb.BatchMetrics(metrics=metrics),
            clb.ConsoleLogger(metrics=["ssim", "psnr"]),
            clb.TensorBoard(hparams.outdir, log_every=40, num_images=2),

            # List of CheckpointSavers, one per metric
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='loss',
                                mode='min',
                                verbose=False),
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='psnr',
                                mode='max',
                                verbose=False),
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='ssim',
                                mode='max',
                                verbose=False),
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='ms-ssim',
                                mode='max',
                                verbose=False),
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='gmsd',
                                mode='min',
                                verbose=False),
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='ms-gmsd',
                                mode='min',
                                verbose=False),
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='ms-gmsdc',
                                mode='min',
                                verbose=False),
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='fsim',
                                mode='max',
                                verbose=False),
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='fsimc',
                                mode='max',
                                verbose=False),
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='vsi',
                                mode='max',
                                verbose=False),
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='mdsi',
                                mode='max',
                                verbose=False),
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='vifp',
                                mode='max',
                                verbose=False),
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='content_vgg16_ap',
                                mode='min',
                                verbose=False),
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='style_vgg16',
                                mode='min',
                                verbose=False),
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='lpips',
                                mode='min',
                                verbose=False),
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='dists',
                                mode='min',
                                verbose=False),
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='brisque',
                                mode='min',
                                verbose=False),
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='is_metric_vgg16',
                                mode='min',
                                verbose=False),
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='is_vgg16',
                                mode='min',
                                verbose=False),
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='kid_vgg16',
                                mode='min',
                                verbose=False),
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='fid_vgg16',
                                mode='min',
                                verbose=False),
            clb.CheckpointSaver(hparams.outdir,
                                save_name=save_name,
                                monitor='msid_vgg16',
                                mode='min',
                                verbose=False),
            sheduler,
        ],
    )

    # Get dataloaders
    transform = get_aug(aug_type=hparams.aug_type,
                        task=hparams.task,
                        dataset=hparams.train_dataset,
                        size=hparams.size)
    train_loader = get_dataloader(dataset=hparams.train_dataset,
                                  train=True,
                                  transform=transform,
                                  batch_size=hparams.batch_size)

    transform = get_aug(aug_type="val",
                        task=hparams.task,
                        dataset=hparams.val_dataset,
                        size=hparams.size)
    val_loader = get_dataloader(dataset=hparams.val_dataset,
                                train=False,
                                transform=transform,
                                batch_size=hparams.batch_size)

    # Train
    runner.fit(
        train_loader,
        epochs=sheduler.tot_epochs,
        val_loader=val_loader,
        steps_per_epoch=2 if hparams.debug else None,
        val_steps=2 if hparams.debug else None,
    )

    logger.info("Finished training!")
Code example #7
def main():

    ## get config for this run
    FLAGS = parse_args()
    os.makedirs(FLAGS.outdir, exist_ok=True)
    config = {
        "handlers": [
            {
                "sink": sys.stdout,
                "format": "{time:[MM-DD HH:mm:ss]} - {message}"
            },
            {
                "sink": f"{FLAGS.outdir}/logs.txt",
                "format": "{time:[MM-DD HH:mm:ss]} - {message}"
            },
        ],
    }
    if FLAGS.is_master:
        logger.configure(**config)
        ## dump config and diff for reproducibility
        yaml.dump(vars(FLAGS), open(FLAGS.outdir + "/config.yaml", "w"))
        kwargs = {"universal_newlines": True, "stdout": subprocess.PIPE}
        with open(FLAGS.outdir + "/commit_hash.txt", "w") as fp:
            fp.write(
                subprocess.run(["git", "rev-parse", "--short", "HEAD"],
                               **kwargs).stdout)
        with open(FLAGS.outdir + "/diff.txt", "w") as fp:
            fp.write(subprocess.run(["git", "diff"], **kwargs).stdout)
    else:
        logger.configure(handlers=[])
    logger.info(FLAGS)

    ## makes it slightly faster
    cudnn.benchmark = True
    if FLAGS.deterministic:
        pt.utils.misc.set_random_seed(42)  # fix all seeds

    ## setup distributed
    if FLAGS.distributed:
        logger.info("Distributed initializing process group")
        torch.cuda.set_device(FLAGS.local_rank)
        dist.init_process_group(backend="nccl",
                                init_method="env://",
                                world_size=FLAGS.world_size)

    ## get dataloaders
    train_loader = DaliLoader(True, FLAGS.batch_size, FLAGS.workers,
                              FLAGS.size)
    val_loader = DaliLoader(False, FLAGS.batch_size, FLAGS.workers, FLAGS.size)

    ## get model
    logger.info(f"=> Creating model '{FLAGS.arch}'")
    model = det_models.__dict__[FLAGS.arch](**FLAGS.model_params)
    if FLAGS.weight_standardization:
        model = pt.modules.weight_standartization.conv_to_ws_conv(model)
    model = model.cuda()

    ## get optimizer
    # want to filter BN from weight decay by default. It never hurts
    optim_params = pt.utils.misc.filter_bn_from_wd(model)
    # start with 0 lr. Scheduler will change this later
    optimizer = optimizer_from_name(FLAGS.optim)(
        optim_params,
        lr=0,
        weight_decay=FLAGS.weight_decay,
        **FLAGS.optim_params)
    if FLAGS.lookahead:
        optimizer = pt.optim.Lookahead(optimizer, la_alpha=0.5)

    ## load weights from previous run if given
    if FLAGS.resume:
        checkpoint = torch.load(
            FLAGS.resume,
            map_location=lambda s, loc: s.cuda())  # map for multi-gpu
        model.load_state_dict(checkpoint["state_dict"])  # strict=False
        FLAGS.start_epoch = checkpoint["epoch"]
        try:
            optimizer.load_state_dict(checkpoint["optimizer"])
        except Exception:  # may raise an error if another optimizer was used or no optimizer in the state dict
            logger.info("Failed to load state dict into optimizer")

    # Important to create EMA Callback after cuda() and AMP but before DDP wrapper
    ema_clb = pt_clb.ModelEma(
        model, FLAGS.ema_decay) if FLAGS.ema_decay else NoClbk()
    if FLAGS.distributed:
        model = DDP(model, delay_allreduce=True)

    ## define loss function (criterion)
    anchors = pt.utils.box.generate_anchors_boxes(FLAGS.size)[0]
    # script loss to lower memory consumption and make it faster
    # as of 1.5 it does run but loss doesn't decrease for some reason
    # FIXME: uncomment after 1.6
    criterion = torch.jit.script(DetectionLoss(anchors).cuda())
    # criterion = DetectionLoss(anchors).cuda()

    ## load COCO (needed for evaluation)
    val_coco_api = COCO("data/annotations/instances_val2017.json")

    model_saver = (pt_clb.CheckpointSaver(FLAGS.outdir, save_name="model.chpn")
                   if FLAGS.is_master else NoClbk())
    sheduler = pt.fit_wrapper.callbacks.PhasesScheduler(FLAGS.phases)
    # common callbacks
    callbacks = [
        pt_clb.StateReduce(),  # MUST go first
        sheduler,
        pt_clb.Mixup(FLAGS.mixup, 1000) if FLAGS.mixup else NoClbk(),
        pt_clb.Cutmix(FLAGS.cutmix, 1000) if FLAGS.cutmix else NoClbk(),
        model_saver,  # need to have CheckpointSaver before EMA so moving it here
        ema_clb,  # ModelEMA MUST go after checkpoint saver to work, otherwise it would save main model instead of EMA
        CocoEvalClbTB(FLAGS.outdir, val_coco_api, anchors),
    ]
    if FLAGS.is_master:  # callback for master process
        master_callbacks = [
            pt_clb.Timer(),
            pt_clb.ConsoleLogger(),
            pt_clb.FileLogger(FLAGS.outdir, logger=logger),
        ]
        callbacks.extend(master_callbacks)

    runner = pt.fit_wrapper.Runner(
        model,
        optimizer,
        criterion,
        # metrics=[pt.metrics.Accuracy(), pt.metrics.Accuracy(5)],
        callbacks=callbacks,
        use_fp16=FLAGS.opt_level != "O0",
    )
    if FLAGS.evaluate:
        return runner.evaluate(val_loader)

    runner.fit(
        train_loader,
        steps_per_epoch=(None, 10)[FLAGS.short_epoch],
        val_loader=val_loader,
        # val_steps=(None, 20)[FLAGS.short_epoch],
        epochs=sheduler.tot_epochs,
        # start_epoch=FLAGS.start_epoch, # TODO: maybe want to continue from epoch
    )

    # TODO: maybe return best loss?
    return runner.state.val_loss.avg, (0, 0)  # [m.avg for m in runner.state.val_metrics]
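
`pt.utils.misc.filter_bn_from_wd(model)` above builds optimizer parameter groups so that normalization layers are excluded from weight decay. An illustrative stand-alone version of that idea (a sketch; the real helper may differ in detail):

import torch.nn as nn


def split_weight_decay_params(model, weight_decay):
    """Return two param groups: regular weights with weight decay, and
    norm-layer parameters plus biases with weight_decay=0."""
    norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                  nn.GroupNorm, nn.LayerNorm)
    decay, no_decay = [], []
    for module in model.modules():
        for name, param in module.named_parameters(recurse=False):
            if not param.requires_grad:
                continue
            if isinstance(module, norm_types) or name.endswith("bias"):
                no_decay.append(param)
            else:
                decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
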
Code example #8
def main():
    # Get config for this run
    hparams = parse_args()

    # Setup logger (parse hparams first, the file sink lives in hparams.outdir)
    config = {
        "handlers": [
            {"sink": sys.stdout, "format": "{time:[MM-DD HH:mm:ss]} - {message}"},
            {"sink": f"{hparams.outdir}/logs.txt", "format": "{time:[MM-DD HH:mm:ss]} - {message}"},
        ],
    }
    logger.configure(**config)
    logger.info(f"Parameters used for training: {hparams}")

    # Fix seeds for reproducibility
    pt.utils.misc.set_random_seed(hparams.seed)

    ## Save config and Git diff (don't know how to do it without subprocess)
    os.makedirs(hparams.outdir, exist_ok=True)
    yaml.dump(vars(hparams), open(hparams.outdir + '/config.yaml', 'w'))
    kwargs = {"universal_newlines": True, "stdout": subprocess.PIPE}
    with open(hparams.outdir + '/commit_hash.txt', 'w') as f:
        f.write(subprocess.run(["git", "rev-parse", "--short", "HEAD"], **kwargs).stdout)
    with open(hparams.outdir + '/diff.txt', 'w') as f:
        f.write(subprocess.run(["git", "diff"], **kwargs).stdout)

    ## Get dataloaders
    train_loader, val_loader = get_dataloaders(
        root=hparams.root, 
        augmentation=hparams.augmentation,
        fold=hparams.fold,
        pos_weight=hparams.pos_weight,
        batch_size=hparams.batch_size,
        size=hparams.size, 
        val_size=hparams.val_size,
        workers=hparams.workers
    )

    # Get model and optimizer
    model = MODEL_FROM_NAME[hparams.segm_arch](hparams.backbone, **hparams.model_params).cuda()
    optimizer = optimizer_from_name(hparams.optim)(
        model.parameters(), # Get LR from phases later
        weight_decay=hparams.weight_decay
    )

    # Convert all Conv2D -> WS_Conv2d if needed
    if hparams.ws:
        model = pt.modules.weight_standartization.conv_to_ws_conv(model).cuda()

    # Load weights if needed
    if hparams.resume:
        checkpoint = torch.load(hparams.resume, map_location=lambda storage, loc: storage.cuda())
        model.load_state_dict(checkpoint["state_dict"], strict=False)
    
    num_params = pt.utils.misc.count_parameters(model)[0]
    logger.info(f"Model size: {num_params / 1e6:.02f}M")  

    ## Use AMP
    model, optimizer = apex.amp.initialize(
        model, optimizer, opt_level=hparams.opt_level, verbosity=0, loss_scale=1024
    )

    # Get loss
    loss = criterion_from_list(hparams.criterion).cuda()
    logger.info(f"Loss for this run is: {loss}")

    bce_loss = pt.losses.CrossEntropyLoss(mode="binary").cuda() # Used as a metric
    bce_loss.name = "BCE"

    # Scheduler is an advanced way of planning experiment
    sheduler = pt.fit_wrapper.callbacks.PhasesScheduler(hparams.phases)

    # Init runner 
    runner = pt.fit_wrapper.Runner(
        model,
        optimizer,
        criterion=loss,
        callbacks=[
            pt_clb.Timer(),
            pt_clb.ConsoleLogger(),
            pt_clb.FileLogger(hparams.outdir, logger=logger),
            pt_clb.CheckpointSaver(hparams.outdir, save_name="model.chpn"),
            sheduler,
            PredictViewer(hparams.outdir, num_images=4)
        ],
        metrics=[
            bce_loss,
            pt.metrics.JaccardScore(mode="binary").cuda(),
            # ThrJaccardScore(thr=0.5),
        ],
    )

    if hparams.decoder_warmup_epochs > 0:
        # Freeze encoder
        for p in model.encoder.parameters():
            p.requires_grad = False

        runner.fit(
            train_loader,
            val_loader=val_loader,
            epochs=hparams.decoder_warmup_epochs,
            steps_per_epoch=10 if hparams.debug else None,
            val_steps=10 if hparams.debug else None,
            # val_steps=50 if hparams.debug else None,
        )

        # Unfreeze all
        for p in model.parameters():
            p.requires_grad = True

        # Reinit again to avoid NaN's in loss
        optimizer = optimizer_from_name(hparams.optim)(
            model.parameters(),
            weight_decay=hparams.weight_decay
        )
        model, optimizer = apex.amp.initialize(
            model, optimizer, opt_level=hparams.opt_level, verbosity=0, loss_scale=2048
        )
        runner.state.model = model
        runner.state.optimizer = optimizer

    # Train both encoder and decoder
    runner.fit(
        train_loader,
        val_loader=val_loader,
        start_epoch=hparams.decoder_warmup_epochs,
        epochs=sheduler.tot_epochs,
        steps_per_epoch=10 if hparams.debug else None,
        val_steps=10 if hparams.debug else None,
    )