start_epoch, best_loss = pt_utils.load_checkpoint(
            model, optimizer, filename=args.checkpoint.split(".")[0])

        lr_scheduler = lr_sched.LambdaLR(optimizer,
                                         lr_lbmd,
                                         last_epoch=start_epoch)
        bnm_scheduler = pt_utils.BNMomentumScheduler(model,
                                                     bnm_lmbd,
                                                     last_epoch=start_epoch)

    # Wrap a cross-entropy criterion with model_fn_decorator; the result is
    # the callable handed to the trainer below.
    model_fn = model_fn_decorator(nn.CrossEntropyLoss())

    # The trainer receives the model, the wrapped criterion, the optimizer,
    # both schedulers and the names used for regular / best checkpoints.
    trainer = pt_utils.Trainer(
        model,
        model_fn,
        optimizer,
        checkpoint_name="sem_seg_checkpoint",
        best_name="sem_seg_best",
        lr_scheduler=lr_scheduler,
        bnm_scheduler=bnm_scheduler,
        eval_frequency=10,
    )

    # Train from start_epoch up to args.epochs, carrying over best_loss.
    trainer.train(
        start_epoch,
        args.epochs,
        train_loader,
        test_loader,
        best_loss=best_loss,
    )

    # Taken only when start_epoch already equals args.epochs: set the test
    # dataset's `data_precent` attribute to 1.0 and call eval_epoch once.
    # NOTE(review): the attribute is spelled `data_precent` here; presumably
    # this matches the dataset class -- verify before renaming.
    if start_epoch == args.epochs:
        test_loader.dataset.data_precent = 1.0
        _ = trainer.eval_epoch(start_epoch, test_loader)
示例#2
0
    else:
        lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lambda=lr_lbmd)
        bnm_scheduler = pt_utils.BNMomentumScheduler(model, bn_lambda=bn_lbmd)

        best_loss = 1e10
        start_epoch = 1

    # Wrap a cross-entropy criterion with model_fn_decorator; the result is
    # the callable handed to the trainer below.
    model_fn = model_fn_decorator(nn.CrossEntropyLoss())

    # Open a Visdom visualizer on the configured port and post the parsed
    # command-line arguments to it as text.
    viz = pt_utils.VisdomViz(port=args.visdom_port)
    viz.text(str(vars(args)))

    # The trainer receives the model, the wrapped criterion, the optimizer,
    # both schedulers, the checkpoint names and the visualizer.
    trainer = pt_utils.Trainer(
        model,
        model_fn,
        optimizer,
        checkpoint_name="checkpoints/pointnet2_cls",
        best_name="checkpoints/pointnet2_cls_best",
        lr_scheduler=lr_scheduler,
        bnm_scheduler=bnm_scheduler,
        viz=viz,
    )

    # NOTE(review): the leading 0 is presumably the starting iteration
    # counter expected by this version of Trainer.train -- confirm against
    # its signature.
    trainer.train(
        0,
        start_epoch,
        args.epochs,
        train_loader,
        test_loader,
        best_loss=best_loss,
    )

    # Taken only when start_epoch already equals args.epochs: call
    # eval_epoch once on the test loader.
    if start_epoch == args.epochs:
        _ = trainer.eval_epoch(test_loader)
示例#3
0
            model, bn_lambda=bn_lbmd, last_epoch=start_epoch
        )
    else:
        lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lambda=lr_lbmd)
        bnm_scheduler = pt_utils.BNMomentumScheduler(model, bn_lambda=bn_lbmd)

        best_loss = 1e10
        start_epoch = 1

    # Wrap a cross-entropy criterion with model_fn_decorator; the result is
    # the callable handed to the trainer below.
    model_fn = model_fn_decorator(nn.CrossEntropyLoss())

    # The trainer receives the model, the wrapped criterion, the optimizer,
    # both schedulers and the names used for regular / best checkpoints.
    trainer = pt_utils.Trainer(
        model, model_fn, optimizer,
        checkpoint_name="checkpoints/single_layer",
        best_name="checkpoints/single_layer_best",
        lr_scheduler=lr_scheduler,
        bnm_scheduler=bnm_scheduler,
    )

    # Train from start_epoch up to args.epochs, carrying over best_loss.
    trainer.train(
        start_epoch, args.epochs,
        train_loader, test_loader,
        best_loss=best_loss,
    )

    # Taken only when start_epoch already equals args.epochs: call
    # eval_epoch once on the test loader.
    if start_epoch == args.epochs:
        _ = trainer.eval_epoch(start_epoch, test_loader)