Example #1
0
def do_train(cfg):
    """
    Build the model, optimizer, data loader and trainer from a lazy config and
    run training, resuming from the latest checkpoint when one exists.

    Args:
        cfg: a config object; this function reads
            model: instantiates to the module to train
            optimizer: instantiates to an optimizer (``params.model`` is set here)
            dataloader.train: instantiates to the training data loader
            lr_multiplier: instantiates to the scheduler given to ``hooks.LRScheduler``
            train: ``device``, ``ddp`` (kwargs), ``amp.enabled``, ``output_dir``,
                ``checkpointer`` (kwargs), ``eval_period``, ``log_period``,
                ``max_iter`` and ``init_checkpoint``
            The whole ``cfg`` is also passed to ``do_test`` by the evaluation hook.
    """
    model = instantiate(cfg.model)
    logger = logging.getLogger("detectron2")
    logger.info("Model:\n{}".format(model))
    model.to(cfg.train.device)

    # The optimizer config needs the live model to collect its parameters.
    cfg.optimizer.params.model = model
    optim = instantiate(cfg.optimizer)

    train_loader = instantiate(cfg.dataloader.train)

    model = create_ddp_model(model, **cfg.train.ddp)
    trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)(
        model, train_loader, optim)
    # The trainer is registered as a checkpointable so its state is saved and
    # restored together with the model and optimizer.
    checkpointer = DetectionCheckpointer(
        model,
        cfg.train.output_dir,
        optimizer=optim,
        trainer=trainer,
    )
    # Checkpointing and metric writing are only set up on the main process;
    # the other processes put ``None`` in those slots.
    trainer.register_hooks([
        hooks.IterationTimer(),
        hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
        hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
        if comm.is_main_process() else None,
        hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
        hooks.PeriodicWriter(
            default_writers(cfg.train.output_dir, cfg.train.max_iter),
            period=cfg.train.log_period,
        ) if comm.is_main_process() else None,
    ])

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=True)
    if checkpointer.has_checkpoint():
        # The checkpoint stores the training iteration that just finished, thus
        # we start at the next iteration instead of restarting from zero.
        start_iter = trainer.iter + 1
    else:
        start_iter = 0
    trainer.train(start_iter, cfg.train.max_iter)
Example #2
0
def do_train(args, cfg):
    """
    Train the model described by a lazy config, resuming when requested.

    Args:
        args: parsed command-line arguments; only ``args.resume`` is read here.
        cfg: an object with the following attributes:
            model: instantiate to a module
            dataloader.train: instantiate to the training dataloader
            optimizer: instantiate to an optimizer
            lr_multiplier: instantiate to the scheduler given to ``hooks.LRScheduler``
            train: other misc config, of which this function reads:
                output_dir (str)
                init_checkpoint (str)
                amp.enabled (bool)
                max_iter (int)
                eval_period, log_period (int)
                device (str)
                checkpointer (dict)
                ddp (dict)
            The whole ``cfg`` is also handed to ``do_test`` by the evaluation hook.
    """
    train_cfg = cfg.train

    net = instantiate(cfg.model)
    log = logging.getLogger("detectron2")
    log.info("Model:\n{}".format(net))
    net.to(train_cfg.device)

    # The optimizer config needs the live module before it can be instantiated.
    cfg.optimizer.params.model = net
    optim = instantiate(cfg.optimizer)
    train_loader = instantiate(cfg.dataloader.train)

    net = create_ddp_model(net, **train_cfg.ddp)

    if train_cfg.amp.enabled:
        trainer_cls = AMPTrainer
    else:
        trainer_cls = SimpleTrainer
    trainer = trainer_cls(net, train_loader, optim)

    checkpointer = DetectionCheckpointer(
        net, train_cfg.output_dir, optimizer=optim, trainer=trainer)

    # Hooks that write to disk are only created on the main process; the
    # remaining processes keep a ``None`` placeholder in the same position.
    ckpt_hook = (
        hooks.PeriodicCheckpointer(checkpointer, **train_cfg.checkpointer)
        if comm.is_main_process() else None
    )
    eval_hook = hooks.EvalHook(train_cfg.eval_period, lambda: do_test(cfg, net))
    writer_hook = (
        hooks.PeriodicWriter(
            default_writers(train_cfg.output_dir, train_cfg.max_iter),
            period=train_cfg.log_period,
        )
        if comm.is_main_process() else None
    )
    hook_list = [
        hooks.IterationTimer(),
        hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
        ckpt_hook,
        eval_hook,
        writer_hook,
    ]
    trainer.register_hooks(hook_list)

    checkpointer.resume_or_load(train_cfg.init_checkpoint, resume=args.resume)
    # A checkpoint records the iteration that just finished, so a resumed run
    # continues from the one after it; otherwise start from scratch.
    resumed = args.resume and checkpointer.has_checkpoint()
    start_iter = trainer.iter + 1 if resumed else 0
    trainer.train(start_iter, train_cfg.max_iter)
def evaluate(dataset):
    """
    Register ``dataset``, load the final trained weights from the output
    directory and print the bounding-box evaluation results.

    Args:
        dataset: dataset name passed to ``register_one_set``, the evaluator
            and the test loader.
    """
    register_one_set(dataset)

    cfg = get_my_cfg()
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7

    # Build the bare architecture first, then fill in the trained weights.
    net = DefaultTrainer.build_model(cfg)
    DetectionCheckpointer(net, cfg.OUTPUT_DIR).resume_or_load(
        cfg.MODEL.WEIGHTS, resume=False)

    eval_dir = os.path.join("output", "evaluate")
    bbox_evaluator = COCOEvaluator(dataset, ("bbox", ), False, output_dir=eval_dir)
    test_loader = build_detection_test_loader(cfg, dataset)

    metrics = inference_on_dataset(net, test_loader, bbox_evaluator)
    print(metrics)
Example #4
0
def do_train(cfg, model, resume=False):
    """
    Minimal hand-written training loop.

    Builds the optimizer, LR scheduler, checkpointer and writers from ``cfg``
    and then iterates over the training loader from the resumed iteration up
    to ``cfg.SOLVER.MAX_ITER``. Unlike "train_net.py", accurate timing and
    precise BN are not supported here, because they are not trivial to
    implement in a small training loop.

    Args:
        cfg: config providing ``OUTPUT_DIR``, ``MODEL.WEIGHTS``,
            ``SOLVER.MAX_ITER``, ``SOLVER.CHECKPOINT_PERIOD`` and
            ``TEST.EVAL_PERIOD``.
        model: the module to train; called with a batch, it must return a
            dict of losses.
        resume (bool): forwarded to ``checkpointer.resume_or_load``.
    """
    model.train()
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    checkpointer = DetectionCheckpointer(
        model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler)
    # A missing "iteration" entry yields -1 + 1 == 0, i.e. a fresh start.
    loaded = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume)
    start_iter = loaded.get("iteration", -1) + 1
    max_iter = cfg.SOLVER.MAX_ITER

    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter)

    if comm.is_main_process():
        writers = default_writers(cfg.OUTPUT_DIR, max_iter)
    else:
        writers = []

    data_loader = build_detection_train_loader(cfg)
    logger.info("Starting training from iteration {}".format(start_iter))
    with EventStorage(start_iter) as storage:
        for data, iteration in zip(data_loader, range(start_iter, max_iter)):
            storage.iter = iteration
            is_last = iteration == max_iter - 1
            done = iteration + 1  # number of iterations completed after this step

            loss_dict = model(data)
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            # Reduce the losses across processes for logging only.
            loss_dict_reduced = {
                name: value.item()
                for name, value in comm.reduce_dict(loss_dict).items()
            }
            losses_reduced = sum(loss_dict_reduced.values())
            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            current_lr = optimizer.param_groups[0]["lr"]
            storage.put_scalar("lr", current_lr, smoothing_hint=False)
            scheduler.step()

            eval_period = cfg.TEST.EVAL_PERIOD
            if eval_period > 0 and done % eval_period == 0 and not is_last:
                do_test(cfg, model)
                # Compared to "train_net.py", the test results are not dumped to EventStorage
                comm.synchronize()

            # Skip the first few iterations after (re)start, then write every
            # 20 iterations and on the final one.
            if iteration - start_iter > 5 and (done % 20 == 0 or is_last):
                for writer in writers:
                    writer.write()
            periodic_checkpointer.step(iteration)
Example #5
0
def do_train(cfg, model, resume=False):
    """
    Hand-written training loop using Adam with a ``MultiStepLR`` schedule.

    Args:
        cfg: config providing ``SOLVER.BASE_LR``, ``OUTPUT_DIR``,
            ``MODEL.WEIGHTS``, ``SOLVER.MAX_ITER``,
            ``SOLVER.CHECKPOINT_PERIOD`` and ``TEST.EVAL_PERIOD``.
        model: the module to train; called with a batch, it must return a
            dict of losses.
        resume (bool): forwarded to ``checkpointer.resume_or_load``.
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=cfg.SOLVER.BASE_LR)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20], gamma=0.1)

    checkpointer = DetectionCheckpointer(
        model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler)
    # A missing "iteration" entry yields -1 + 1 == 0, i.e. a fresh start.
    resume_state = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume)
    start_iter = resume_state.get("iteration", -1) + 1
    max_iter = cfg.SOLVER.MAX_ITER

    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter)

    # Metrics go to the console, a JSON file and tensorboard event files.
    writers = [
        CommonMetricPrinter(max_iter),
        JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
        TensorboardXWriter(cfg.OUTPUT_DIR),
    ]

    data_loader = build_detection_train_loader(cfg)
    logger.info("Starting training from iteration {}".format(start_iter))
    with EventStorage(start_iter) as storage:
        for data, zero_based in zip(data_loader, range(start_iter, max_iter)):
            # Everything below works with a 1-based iteration count.
            iteration = zero_based + 1
            storage.step()

            loss_dict = model(data)
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            storage.put_scalars(total_loss=losses, **loss_dict)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            current_lr = optimizer.param_groups[0]["lr"]
            storage.put_scalar("lr", current_lr, smoothing_hint=False)

            eval_period = cfg.TEST.EVAL_PERIOD
            if eval_period > 0 and iteration % eval_period == 0 and iteration != max_iter:
                do_test(cfg, model)
                # NOTE(review): the scheduler only advances after an evaluation,
                # so the milestones [10, 20] count evaluations rather than
                # iterations - confirm this is intended.
                scheduler.step()

            # Skip the first few iterations after (re)start, then write every
            # 20 iterations and on the final one.
            warmed_up = iteration - start_iter > 5
            if warmed_up and (iteration % 20 == 0 or iteration == max_iter):
                for writer in writers:
                    writer.write()
            periodic_checkpointer.step(iteration)
Example #6
0
def load_model(config_file, model_weights, model_device):
    """
    Build a model from a config file and load weights into it.

    Args:
        config_file: path given as ``--config-file``.
        model_weights: value used to override ``MODEL.WEIGHTS``.
        model_device: value used to override ``MODEL.DEVICE``.

    Returns:
        The model built by ``build_model`` after the checkpointer has loaded
        weights into it.
    """
    print('Loading cfg')
    # Emulate a command line: the config path plus KEY VALUE override pairs.
    cli_tokens = ['--config-file', config_file]
    cli_tokens += ['MODEL.WEIGHTS', model_weights]
    cli_tokens += ['MODEL.DEVICE', model_device]
    args = default_argument_parser().parse_args(cli_tokens)

    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)

    print('Loading model')
    model = build_model(cfg)
    # NOTE(review): resume=True presumably prefers a checkpoint already present
    # in cfg.OUTPUT_DIR over ``model_weights`` - confirm that is wanted here.
    DetectionCheckpointer(model, cfg.OUTPUT_DIR).resume_or_load(
        cfg.MODEL.WEIGHTS, resume=True)
    return model
Example #7
0
    def __init__(self, context: PyTorchTrialContext):
        """
        Set up the trial: config, model, optimizer, LR scheduler and evaluators.

        Args:
            context: the trial context used to wrap the model, optimizer,
                scheduler and reducer, and to read hyperparameters.
        """
        self.context = context
        self.cfg = self.setup_cfg()

        # Build the model and load the configured weights (never resuming from
        # a previous run's checkpoint) before handing it to the context.
        base_model = build_model(self.cfg)
        weight_loader = DetectionCheckpointer(base_model, self.cfg.OUTPUT_DIR)
        weight_loader.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=False)
        self.model = self.context.wrap_model(weight_loader.model)

        self.optimizer = self.context.wrap_optimizer(
            build_optimizer(self.cfg, self.model))

        # The scheduler is stepped once per batch.
        self.scheduler = build_lr_scheduler(self.cfg, self.optimizer)
        self.scheduler = self.context.wrap_lr_scheduler(
            self.scheduler, LRScheduler.StepMode.STEP_EVERY_BATCH)

        # Only the first test dataset is evaluated.
        self.dataset_name = self.cfg.DATASETS.TEST[0]
        output_dir = self.context.get_hparam("output_dir")
        fake_data = self.context.get_hparam('fake_data')
        self.evaluators = get_evaluator(
            self.cfg, self.dataset_name, output_dir, fake_data)
        self.val_reducer = self.context.wrap_reducer(
            EvaluatorReducer(self.evaluators), for_training=False)

        self.context.experimental.disable_dataset_reproducibility_checks()
Example #8
0
def do_train(cfg, args, model, resume=False):
    """
    Training loop that also measures training speed over short sample windows.

    NOTE(review): ``optimizer``, ``data_loader``, ``grc``, ``memory_partition``,
    ``hvd`` and ``time_`` are not defined in this function; presumably they are
    module-level names - confirm before reusing this function elsewhere.

    Args:
        cfg: config providing ``OUTPUT_DIR``, ``MODEL.WEIGHTS``,
            ``SOLVER.MAX_ITER``, ``SOLVER.CHECKPOINT_PERIOD`` and
            ``TEST.EVAL_PERIOD``.
        args: options object; ``scheduler``, ``scheduler_baseline`` and
            ``compress`` are read here.
        model: the module to train; called with a batch, it must return a
            dict of losses.
        resume (bool): forwarded to ``checkpointer.resume_or_load``.
    """
    # default batch size is 16
    model.train()

    scheduler = build_lr_scheduler(cfg, optimizer)

    checkpointer = DetectionCheckpointer(model,
                                         cfg.OUTPUT_DIR,
                                         optimizer=optimizer,
                                         scheduler=scheduler)
    max_iter = cfg.SOLVER.MAX_ITER
    # A missing "iteration" entry yields -1 + 1 == 0, i.e. a fresh start.
    start_iter = (checkpointer.resume_or_load(
        cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1)

    # NOTE(review): ``periodic_checkpointer`` and ``writers`` are built here but
    # are not used anywhere in the loop below.
    periodic_checkpointer = PeriodicCheckpointer(checkpointer,
                                                 cfg.SOLVER.CHECKPOINT_PERIOD,
                                                 max_iter=max_iter)

    writers = ([
        CommonMetricPrinter(max_iter),
        JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
        TensorboardXWriter(cfg.OUTPUT_DIR),
    ]
               #if comm.is_main_process()
               #else []
               )

    # compared to "train_net.py", we do not support accurate timing and
    # precise BN here, because they are not trivial to implement in a small training loop
    #logger.info("Starting training from iteration {}".format(start_iter))

    # ``iters`` counts every loop pass; ``iter_cnt`` counts passes within the
    # current timing window, which spans ``iter_sample_start``..``iter_sample_end``.
    iters = 0
    iter_cnt = 0
    iter_sample_start = 1
    iter_sample_end = 20
    # The loop stops as soon as ``iters`` reaches this value, regardless of max_iter.
    iter_end = 300
    start_time, end_time = 0, 0
    sample_iters = iter_sample_end - iter_sample_start + 1

    if args.scheduler:
        if args.scheduler_baseline:
            grc.memory.clean()
            grc.compressor.clean()
            grc.memory.partition()
        else:
            from mergeComp_dl.torch.scheduler.scheduler import Scheduler
            Scheduler(grc, memory_partition, args)

    with EventStorage(start_iter) as storage:
        for data, iteration in zip(data_loader, range(start_iter, max_iter)):
            iters += 1
            iter_cnt += 1
            if iters == iter_end:
                break

            # Start of a timing window (local rank 0 only).
            if hvd.local_rank() == 0 and iter_cnt == iter_sample_start:
                torch.cuda.synchronize()
                start_time = time_()

            storage.iter = iteration
            #torch.cuda.synchronize()
            #iter_start_time = time_()

            loss_dict = model(data)

            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            #torch.cuda.synchronize()
            #iter_model_time = time_()

            #loss_dict_reduced = {k: v.item() for k, v in comm.reduce_dict(loss_dict).items()}
            #losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            #if comm.is_main_process():
            #    storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced)

            #print("loss dict:", loss_dict, "losses:", losses, "reduced loss dict:", loss_dict_reduced, "reduced losses:", losses_reduced)
            losses.backward()

            #torch.cuda.synchronize()
            #iter_backward_time = time_()

            # Gradients are cleared after the step, ready for the next pass.
            optimizer.step()
            optimizer.zero_grad()

            #torch.cuda.synchronize()
            #print("Iteration: {}\tmodel time: {:.3f} \tbackward time: {:.3f}\tFP+BP Time: {:.3f}\tstep time: {:.3f}\tData size: {}".format(
            #    iteration,
            #    (iter_model_time - iter_start_time),
            #    (iter_backward_time - iter_model_time),
            #    (iter_backward_time - iter_start_time),
            #    time_() - iter_start_time,
            #    len(data)))

            storage.put_scalar("lr",
                               optimizer.param_groups[0]["lr"],
                               smoothing_hint=False)
            scheduler.step()
            if args.compress:
                grc.memory.update_lr(optimizer.param_groups[0]['lr'])

            # End of a timing window (local rank 0 only): report the speed and
            # reset the window counter so the next pass opens a new window.
            # On other ranks ``iter_cnt`` is never reset.
            if hvd.local_rank() == 0 and iter_cnt == iter_sample_end:
                torch.cuda.synchronize()
                end_time = time_()
                iter_cnt = 0
                print(
                    "Iterations: {}\tTime: {:.3f} s\tTraining speed: {:.3f} iters/s"
                    .format(sample_iters, end_time - start_time,
                            sample_iters / (end_time - start_time)))

            # Periodic evaluation, skipped on the very last iteration.
            if (cfg.TEST.EVAL_PERIOD > 0
                    and (iteration + 1) % cfg.TEST.EVAL_PERIOD == 0
                    and iteration != max_iter - 1):
                do_test(cfg, model)
Example #9
0
def do_train(cfg, args, myargs):
    """
    Epoch-configured training entry point driven by ``cfg.start``.

    Builds the data loader and the trainer object, optionally resumes from a
    checkpoint, and then either runs an alternative method of the trainer
    (``cfg.start.run_func``) and exits, or runs the training loop.

    Args:
        cfg: config whose ``start`` node supplies the options read at the top
            of this function; ``DATASETS``, ``SOLVER``, ``DATALOADER`` and
            ``OUTPUT_DIR`` are also used.
        args: options object; ``resume`` and (optionally) ``finetune`` are read.
        myargs: forwarded to ``build_trainer``; ``myargs.args.time_str_suffix``
            and ``myargs.stdout`` are used for the progress bar.
    """
    run_func = cfg.start.get('run_func', 'train_func')
    dataset_name = cfg.start.dataset_name
    IMS_PER_BATCH = cfg.start.IMS_PER_BATCH
    max_epoch = cfg.start.max_epoch
    ASPECT_RATIO_GROUPING = cfg.start.ASPECT_RATIO_GROUPING
    NUM_WORKERS = cfg.start.NUM_WORKERS
    checkpoint_period = cfg.start.checkpoint_period
    dataset_mapper = cfg.start.dataset_mapper
    resume_ckpt_dir = get_attr_kwargs(cfg.start,
                                      'resume_ckpt_dir',
                                      default=None)
    resume_ckpt_epoch = get_attr_kwargs(cfg.start,
                                        'resume_ckpt_epoch',
                                        default=0)
    resume_ckpt_iter_every_epoch = get_attr_kwargs(
        cfg.start, 'resume_ckpt_iter_every_epoch', default=0)

    # Copy the ``start`` options into the standard config locations.
    cfg.defrost()
    cfg.DATASETS.TRAIN = (dataset_name, )
    cfg.SOLVER.IMS_PER_BATCH = IMS_PER_BATCH
    cfg.DATALOADER.ASPECT_RATIO_GROUPING = ASPECT_RATIO_GROUPING
    cfg.DATALOADER.NUM_WORKERS = NUM_WORKERS
    cfg.freeze()

    # build dataset
    mapper = build_dataset_mapper(dataset_mapper)
    data_loader = build_detection_train_loader(cfg, mapper=mapper)
    metadata = MetadataCatalog.get(dataset_name)
    # NOTE(review): if the metadata has no 'num_images' entry this is None and
    # the division below fails - confirm it is always registered.
    num_images = metadata.get('num_images')
    # Integer division: a trailing partial batch does not count as an iteration.
    iter_every_epoch = num_images // IMS_PER_BATCH
    max_iter = iter_every_epoch * max_epoch

    model = build_trainer(cfg,
                          myargs=myargs,
                          iter_every_epoch=iter_every_epoch,
                          img_size=dataset_mapper.img_size,
                          dataset_name=dataset_name,
                          train_bs=IMS_PER_BATCH,
                          max_iter=max_iter)
    model.train()

    # optimizer = build_optimizer(cfg, model)
    optims_dict = model.build_optimizer()
    # scheduler = build_lr_scheduler(cfg, optimizer)

    # The optimizers are passed as extra checkpointables.
    checkpointer = DetectionCheckpointer(model.get_saved_model(),
                                         cfg.OUTPUT_DIR, **optims_dict)
    if args.resume:
        resume_ckpt_dir = model._get_ckpt_path(
            ckpt_dir=resume_ckpt_dir,
            ckpt_epoch=resume_ckpt_epoch,
            iter_every_epoch=resume_ckpt_iter_every_epoch)
        # A missing "iteration" entry yields -1 + 1 == 0.
        start_iter = (
            checkpointer.resume_or_load(resume_ckpt_dir).get("iteration", -1) +
            1)
        # Fine-tuning loads the checkpoint but restarts the iteration count.
        if get_attr_kwargs(args, 'finetune', default=False):
            start_iter = 0
    else:
        start_iter = 0

    model.after_resume()

    if run_func != 'train_func':
        # SECURITY NOTE(review): eval() runs a config-supplied string as code;
        # it resolves ``model`` from this function's local names, so that
        # variable must keep its name. Only use with trusted configs.
        eval(f'model.{run_func}()')
        exit(0)

    # SECURITY NOTE(review): ``checkpoint_period`` is a config-supplied
    # expression evaluated with ``iter_every_epoch`` in scope; only use with
    # trusted configs.
    checkpoint_period = eval(checkpoint_period,
                             dict(iter_every_epoch=iter_every_epoch))
    periodic_checkpointer = PeriodicCheckpointer(checkpointer,
                                                 checkpoint_period,
                                                 max_iter=max_iter)
    logger.info("Starting training from iteration {}".format(start_iter))

    with EventStorage(start_iter) as storage:
        pbar = zip(data_loader, range(start_iter, max_iter))
        # Only the main process shows a progress bar.
        if comm.is_main_process():
            pbar = tqdm.tqdm(
                pbar,
                desc=f'do_train, {myargs.args.time_str_suffix}, '
                f'iters {iter_every_epoch} * bs {IMS_PER_BATCH} = imgs {iter_every_epoch*IMS_PER_BATCH}',
                file=myargs.stdout,
                initial=start_iter,
                total=max_iter)

        for data, iteration in pbar:
            comm.synchronize()
            # The checkpointer sees a 1-based count; train_func gets the
            # 0-based index.
            iteration = iteration + 1
            storage.step()

            model.train_func(data, iteration - 1, pbar=pbar)

            periodic_checkpointer.step(iteration)
            pass

    comm.synchronize()
Example #10
0
    return cfg


# Emulate a command line: the config file plus a KEY VALUE override for the
# weights to load.
args = [
    '--config-file',
    'configs/COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml',
    'MODEL.WEIGHTS',
    'models/model_final_280758.pkl',
]
args = default_argument_parser().parse_args(args)
cfg = setup(args)
model = build_model(cfg)

from detectron2.checkpoint import DetectionCheckpointer
checkpointer = DetectionCheckpointer(model, cfg.OUTPUT_DIR)
# A missing "iteration" entry yields -1 + 1 == 0.
# NOTE(review): resume=True presumably prefers a checkpoint already present in
# cfg.OUTPUT_DIR over MODEL.WEIGHTS - confirm that is wanted here.
start_iter = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=True).get(
    "iteration", -1) + 1

import torch
from ESNAC.graph import get_graph_resnet, get_plot
from detectron2.modeling.backbone import ESNACArchitecture

# Rebuild the backbone's bottom-up network as an ESNAC architecture from the
# graph extracted by get_graph_resnet.
resnet = model.backbone.bottom_up
arch = ESNACArchitecture(*get_graph_resnet(resnet))
'''
input = torch.rand(1, 3, 256, 256)
o1 = resnet(input)
o2 = arch(input)
print(torch.sum(torch.abs(o1['res2'] - o2['res2'])))
print(torch.sum(torch.abs(o1['res3'] - o2['res3'])))
print(torch.sum(torch.abs(o1['res4'] - o2['res4'])))
print(torch.sum(torch.abs(o1['res5'] - o2['res5'])))
Example #11
0
class LOFARTrainer(SimpleTrainer):
    """
    A trainer with default training logic. Compared to `SimpleTrainer`, it
    contains the following logic in addition:

    1. Create model, optimizer, scheduler, dataloader from the given config.
    2. Load a checkpoint or `cfg.MODEL.WEIGHTS`, if exists.
    3. Register a few common hooks.

    It is created to simplify the **standard model training workflow** and reduce code boilerplate
    for users who only need the standard training workflow, with standard features.
    It means this class makes *many assumptions* about your training logic that
    may easily become invalid in a new research. In fact, any assumptions beyond those made in the
    :class:`SimpleTrainer` are too much for research.

    The code of this class has been annotated about the restrictive assumptions it makes.
    When they do not work for you, you're encouraged to:

    1. Overwrite methods of this class, OR:
    2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
       nothing else. You can then add your own hooks if needed. OR:
    3. Write your own training loop similar to `tools/plain_train_net.py`.

    Also note that the behavior of this class, like other functions/classes in
    this file, is not stable, since it is meant to represent the "common default behavior".
    It is only guaranteed to work well with the standard models and training workflow in detectron2.
    To obtain more stable behavior, write your own training logic with other public APIs.

    Attributes:
        scheduler: the LR scheduler returned by :meth:`build_lr_scheduler`.
        checkpointer (DetectionCheckpointer): saves/loads model, optimizer and scheduler.
        cfg (CfgNode): the config the trainer was built from.
        start_iter (int): iteration to start from; updated by :meth:`resume_or_load`.
        max_iter (int): taken from ``cfg.SOLVER.MAX_ITER``.

    Examples:

    .. code-block:: python

        trainer = LOFARTrainer(cfg)
        trainer.resume_or_load()  # load last checkpoint or MODEL.WEIGHTS
        trainer.train()
    """
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        logger = logging.getLogger("detectron2")
        if not logger.isEnabledFor(
                logging.INFO):  # setup_logger is not called for d2
            setup_logger()
        # Assume these objects must be constructed in this order.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

        # For training, wrap with DDP. But don't need this for inference.
        if comm.get_world_size() > 1:
            model = DistributedDataParallel(model,
                                            device_ids=[comm.get_local_rank()],
                                            broadcast_buffers=False)
        super().__init__(model, data_loader, optimizer)

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

    def resume_or_load(self, resume=True):
        """
        If `resume==True`, and last checkpoint exists, resume from it.

        Otherwise, load a model specified by the config.

        Args:
            resume (bool): whether to do resume or not
        """
        # The checkpoint stores the training iteration that just finished, thus we start
        # at the next iteration (or iter zero if there's no checkpoint).
        self.start_iter = (self.checkpointer.resume_or_load(
            self.cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1)

    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.

        Returns:
            list[HookBase]: may contain ``None`` where an optional hook is disabled.
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
            hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            ) if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None,
        ]

        # Do PreciseBN before checkpointer, because it updates the model and need to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(
                hooks.PeriodicCheckpointer(self.checkpointer,
                                           cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            # One LOFAREvaluator per test dataset, in the order of DATASETS.TEST.
            # NOTE(review): the segmentation cache directory is a hard-coded
            # absolute path; consider moving it into the config.
            self._last_eval_results = self.test(
                self.cfg,
                self.model,
                evaluators=[
                    LOFAREvaluator(
                        t,
                        cfg.OUTPUT_DIR,
                        sigmabox=cfg.SIGMABOX,
                        segmentation_dir=
                        f'/data1/mostertrij/data/cache/segmentation_maps_{cfg.TEST.REMOVE_THRESHOLD}',
                        remove_unresolved=cfg.TEST.REMOVE_UNRESOLVED)
                    for t in self.cfg.DATASETS.TEST
                ])
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(
            hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results,
                           cfg.TEST.EXTRA_EVAL))

        if comm.is_main_process():
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers()))
        return ret

    def build_writers(self):
        """
        Build a list of writers to be used. By default it contains
        writers that write metrics to the screen,
        a json file, and a tensorboard event file respectively.
        If you'd like a different list of writers, you can overwrite it in
        your trainer.

        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.

        It is now implemented by:

        .. code-block:: python

            return [
                LOFARMetricPrinter(self.max_iter),
                JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
                TensorboardXWriter(self.cfg.OUTPUT_DIR),
            ]

        """
        # Assume the default print/log frequency.
        return [
            # It may not always print what you want to see, since it prints "common" metrics only.
            LOFARMetricPrinter(self.max_iter),
            JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(self.cfg.OUTPUT_DIR),
        ]

    def train(self):
        """
        Run training.

        Returns:
            OrderedDict of results, if evaluation is enabled and this is the
            main process. Otherwise None.
        """
        super().train(self.start_iter, self.max_iter)
        if hasattr(self, "_last_eval_results") and comm.is_main_process():
            verify_results(self.cfg, self._last_eval_results)
            return self._last_eval_results

    @classmethod
    def build_model(cls, cfg):
        """
        Build the model and, when ``cfg.MODEL.PRETRAINED_WEIGHTS`` is set,
        initialise the backbone's bottom-up network from that checkpoint.

        Returns:
            torch.nn.Module:

        It now calls :func:`detectron2.modeling.build_model`.
        Overwrite it if you'd like a different model.
        """
        model = build_model(cfg)
        if cfg.MODEL.PRETRAINED_WEIGHTS != "":
            assert os.path.exists(
                cfg.MODEL.PRETRAINED_WEIGHTS
            ), f'Pretrain path does not exist: {cfg.MODEL.PRETRAINED_WEIGHTS}'

            # Load the pretrained model. Loading onto the CPU keeps this
            # working on machines without the device the checkpoint was saved
            # from; load_state_dict copies the values into the model anyway.
            ckpt = torch.load(cfg.MODEL.PRETRAINED_WEIGHTS, map_location="cpu")
            state = ckpt['resnet50_parameters']

            # Change keynames of simCLR pretrained model state_dict to match detectron2 state_dict.
            # Keys are paired by position and may only differ in their first
            # dotted component. A fresh mapping is built instead of renaming in
            # place, so a destination key that equals a pretrained key can
            # neither be overwritten early nor dropped afterwards.
            destination_keys = list(model.backbone.bottom_up.state_dict())
            renamed_state = OrderedDict()
            for old_key, dest_key in zip(list(state), destination_keys):
                assert old_key.split('.')[1:] == dest_key.split('.')[1:], (
                    f'Pretrained key {old_key} does not match {dest_key}')
                renamed_state[dest_key] = state[old_key]

            # Load partial model weights
            model.backbone.bottom_up.load_state_dict(renamed_state,
                                                     strict=False)

        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))
        return model

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Returns:
            torch.optim.Optimizer:

        It now calls :func:`detectron2.solver.build_optimizer`.
        Overwrite it if you'd like a different optimizer.
        """
        return build_optimizer(cfg, model)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It now calls :func:`detectron2.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_train_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_train_loader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_test_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_test_loader(cfg, dataset_name)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        """
        Returns:
            DatasetEvaluator

        It is not implemented by default.

        Raises:
            NotImplementedError: always, unless overridden in a subclass.
        """
        raise NotImplementedError(
            "Please either implement `build_evaluator()` in subclasses, or pass "
            "your evaluator as arguments to `DefaultTrainer.test()`.")

    @classmethod
    def test(cls, cfg, model, evaluators=None):
        """
        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                `cfg.DATASETS.TEST`. A single evaluator is also accepted.

        Returns:
            dict: a dict of result metrics, keyed by dataset name unless there
            is exactly one dataset, in which case its results are returned directly.
        """
        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]
        if evaluators is not None:
            assert len(
                cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                    len(cfg.DATASETS.TEST), len(evaluators))

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader = cls.build_test_loader(cfg, dataset_name)
            # When evaluators are passed in as arguments,
            # implicitly assume that evaluators can be created before data_loader.
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, dataset_name)
                except NotImplementedError:
                    # ``Logger.warn`` is deprecated; ``warning`` is the
                    # supported spelling.
                    logger.warning(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method.")
                    results[dataset_name] = {}
                    continue
            results_i = inference_on_dataset(model, data_loader, evaluator)
            results[dataset_name] = results_i
            if comm.is_main_process():
                assert isinstance(
                    results_i, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results_i)
                logger.info("Evaluation results for {} in csv format:".format(
                    dataset_name))
                print_csv_format(results_i)

        if len(results) == 1:
            results = list(results.values())[0]
        return results

    @staticmethod
    def auto_scale_workers(cfg, num_workers: int):
        """
        When the config is defined for certain number of workers (according to
        ``cfg.SOLVER.REFERENCE_WORLD_SIZE``) that's different from the number of
        workers currently in use, returns a new cfg where the total batch size
        is scaled so that the per-GPU batch size stays the same as the
        original ``IMS_PER_BATCH // REFERENCE_WORLD_SIZE``.

        Other config options are also scaled accordingly:
        * training steps and warmup steps are scaled inverse proportionally.
        * learning rate are scaled proportionally, following :paper:`ImageNet in 1h`.

        It returns the original config if ``cfg.SOLVER.REFERENCE_WORLD_SIZE==0``
        or if it already equals ``num_workers``.

        Returns:
            CfgNode: a new config
        """
        old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE
        if old_world_size == 0 or old_world_size == num_workers:
            return cfg
        cfg = cfg.clone()
        # Remember the frozen state so it can be restored on the clone.
        frozen = cfg.is_frozen()
        cfg.defrost()

        assert (cfg.SOLVER.IMS_PER_BATCH %
                old_world_size == 0), "Invalid REFERENCE_WORLD_SIZE in config!"
        scale = num_workers / old_world_size
        bs = cfg.SOLVER.IMS_PER_BATCH = int(
            round(cfg.SOLVER.IMS_PER_BATCH * scale))
        lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale
        max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER /
                                                   scale))
        warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(
            round(cfg.SOLVER.WARMUP_ITERS / scale))
        cfg.SOLVER.STEPS = tuple(
            int(round(s / scale)) for s in cfg.SOLVER.STEPS)
        cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale))
        cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers  # maintain invariant
        logger = logging.getLogger(__name__)
        logger.info(
            f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, "
            f"max_iter={max_iter}, warmup={warmup_iter}.")

        if frozen:
            cfg.freeze()
        return cfg
Example #12
0
def do_train(cfg, model, resume=False):
    """
    Run a plain training loop and report metrics to the event writers and to
    Weights & Biases.

    Args:
        cfg: config node. This function reads ``OUTPUT_DIR``, ``MODEL.WEIGHTS``,
            ``SOLVER.MAX_ITER``, ``SOLVER.CHECKPOINT_PERIOD`` and
            ``TEST.EVAL_PERIOD`` directly and forwards the whole config to the
            optimizer / scheduler / data-loader builders and to ``do_test``.
        model: the module to train. It is called with one batch and must return
            a mapping whose values are summed into the total loss.
        resume (bool): forwarded to ``checkpointer.resume_or_load``.
    """
    # Set the model to train
    model.train()

    # Create torch optimiser & schedulers
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    # Create a torch checkpointer
    checkpointer = DetectionCheckpointer(
        model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler
    )

    # Load the starting weights and continue from the iteration after the one
    # recorded in the checkpoint (0 when no iteration is recorded).
    start_iter = (
        checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1
    )

    # Define the number of iterations
    max_iter = cfg.SOLVER.MAX_ITER

    # Create a periodic checkpointer at the configured period
    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
    )

    # Export checkpoint data to terminal, JSON & tensorboard files
    # (only the main process gets writers).
    writers = (
        [
            CommonMetricPrinter(max_iter),
            JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(cfg.OUTPUT_DIR),
        ]
        if comm.is_main_process()
        else []
    )

    # Create a data loader to supply the model with training data
    data_loader = build_detection_train_loader(cfg)

    logger.info("Starting training from iteration {}".format(start_iter))
    with EventStorage(start_iter) as storage:
        for data, iteration in zip(data_loader, range(start_iter, max_iter)):
            # From here on ``iteration`` is 1-based: the number of the step being run.
            iteration = iteration + 1
            storage.step()

            loss_dict = model(data)
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            # Plain-float copies of the losses, used only for logging.
            loss_dict_reduced = {k: v.item() for k, v in comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss_dict_reduced.values())
            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
            scheduler.step()

            # If eval period has been set, run test at defined interval
            # (the final iteration is excluded here).
            if (
                cfg.TEST.EVAL_PERIOD > 0
                and iteration % cfg.TEST.EVAL_PERIOD == 0
                and iteration != max_iter
            ):
                do_test(cfg, model)
                comm.synchronize()

            # Skip the first few iterations after (re)start, then report every 20
            # iterations and on the last one.
            if iteration - start_iter > 5 and (iteration % 20 == 0 or iteration == max_iter):
                logger.debug('Logging iteration and loss to Weights & Biases')
                # Send everything in ONE call: every separate wandb.log() call
                # advances the W&B step, which would spread the values belonging
                # to a single iteration over three different steps.
                wandb.log(
                    {
                        "iteration": iteration,
                        "total_loss": losses_reduced,
                        **loss_dict_reduced,
                    }
                )

                for writer in writers:
                    writer.write()
            periodic_checkpointer.step(iteration)
 def build_model(self) -> nn.Module:
     """Build the model from ``self.cfg`` and load weights from the ``data_loc`` hyperparameter.

     Returns:
         The model held by the checkpointer after ``resume_or_load`` ran with
         ``resume=False``.
     """
     net = build_model(self.cfg)
     weights_location = self.context.get_hparam('data_loc')
     weight_loader = DetectionCheckpointer(net)
     weight_loader.resume_or_load(path=weights_location, resume=False)
     return weight_loader.model
Example #14
0
class DefaultTrainer(SimpleTrainer):
    """
    A trainer with default training logic. Compared to `SimpleTrainer`, it
    contains the following logic in addition:

    1. Create model, optimizer, scheduler, dataloader from the given config.
    2. Load a checkpoint or `cfg.MODEL.WEIGHTS`, if exists.
    3. Register a few common hooks.

    It is created to simplify the **standard model training workflow** and reduce code boilerplate
    for users who only need the standard training workflow, with standard features.
    It means this class makes *many assumptions* about your training logic that
    may easily become invalid in a new research. In fact, any assumptions beyond those made in the
    :class:`SimpleTrainer` are too much for research.

    The code of this class has been annotated about restrictive assumptions it makes.
    When they do not work for you, you're encouraged to:

    1. Overwrite methods of this class, OR:
    2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
       nothing else. You can then add your own hooks if needed. OR:
    3. Write your own training loop similar to `tools/plain_train_net.py`.

    Also note that the behavior of this class, like other functions/classes in
    this file, is not stable, since it is meant to represent the "common default behavior".
    It is only guaranteed to work well with the standard models and training workflow in detectron2.
    To obtain more stable behavior, write your own training logic with other public APIs.

    Attributes:
        scheduler: LR scheduler built by :meth:`build_lr_scheduler`
            (only set in single-optimizer mode).
        checkpointer (DetectionCheckpointer):
        cfg (CfgNode):
        apply_mul_opts (bool): whether two optimizers/schedulers are built.
            Currently hard-coded to False in ``__init__``.
        start_iter (int): first iteration passed to ``train``; updated by
            :meth:`resume_or_load`.
        max_iter (int): ``cfg.SOLVER.MAX_ITER``.

    Examples:

    .. code-block:: python

        trainer = DefaultTrainer(cfg)
        trainer.resume_or_load()  # load last checkpoint or MODEL.WEIGHTS
        trainer.train()
    """
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        # Assume these objects must be constructed in this order.
        # The two-optimizer path is currently switched off; it used to be enabled with:
        # self.apply_mul_opts = True if cfg.MODEL.ROI_MASK_HEAD.RECON_NET.NAME != "" else False
        self.apply_mul_opts = False
        model = self.build_model(cfg)
        if self.apply_mul_opts:
            optimizer_main = self.build_optimizer(cfg, model, ty_opt="M")
            optimizer_recon = self.build_optimizer(cfg, model, ty_opt="A")
            optimizer = [optimizer_main, optimizer_recon]

            self.scheduler_main = self.build_lr_scheduler(cfg, optimizer_main)
            self.scheduler_recon = self.build_lr_scheduler(
                cfg, optimizer_recon)
            self.checkpointer = DetectionCheckpointer(
                # Assume you want to save checkpoints together with logs/statistics
                model,
                cfg.OUTPUT_DIR,
                optimizer_gen=optimizer_main,
                optimizer_dis=optimizer_recon,
                scheduler_gen=self.scheduler_main,
                scheduler_dis=self.scheduler_recon,
            )
        else:
            optimizer = self.build_optimizer(cfg,
                                             model,
                                             ty_opt=cfg.SOLVER.OPT_TYPE)
            self.scheduler = self.build_lr_scheduler(cfg, optimizer)
            # Assume no other objects need to be checkpointed.
            # We can later make it checkpoint the stateful hooks
            self.checkpointer = DetectionCheckpointer(
                # Assume you want to save checkpoints together with logs/statistics
                model,
                cfg.OUTPUT_DIR,
                optimizer=optimizer,
                scheduler=self.scheduler,
            )

        logger = logging.getLogger(__name__)
        logger.info("optimizer information:{}".format(type(optimizer)))
        data_loader = self.build_train_loader(cfg)

        # For training, wrap with DDP. But don't need this for inference.
        # The checkpointer above keeps referring to the unwrapped model.
        if comm.get_world_size() > 1:
            model = DistributedDataParallel(model,
                                            device_ids=[comm.get_local_rank()],
                                            broadcast_buffers=False)
        super().__init__(model, data_loader, optimizer, cfg)

        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

    def resume_or_load(self, resume=True):
        """
        If `resume==True`, and last checkpoint exists, resume from it.

        Otherwise, load a model specified by the config.

        Args:
            resume (bool): whether to do resume or not
        """
        # The checkpoint stores the training iteration that just finished, thus we start
        # at the next iteration (or iter zero if there's no checkpoint).
        self.start_iter = (self.checkpointer.resume_or_load(
            self.cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1)

    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.

        The list may contain a ``None`` entry in place of the PreciseBN hook
        when precise BN is disabled or the model has no BN modules.

        Returns:
            list[HookBase]:
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        # Both modes share every hook except the LR scheduler hook(s).
        ret = [hooks.IterationTimer()]
        if self.apply_mul_opts:
            # NOTE(review): optimizer_gen / optimizer_dis are not assigned anywhere in
            # this class; presumably the base trainer provides them -- confirm before
            # enabling apply_mul_opts.
            ret.append(hooks.LRScheduler(self.optimizer_gen, self.scheduler_main))
            ret.append(hooks.LRScheduler(self.optimizer_dis, self.scheduler_recon))
        else:
            ret.append(hooks.LRScheduler(self.optimizer, self.scheduler))
        ret.append(
            hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            ) if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None)

        # Do PreciseBN before checkpointer, because it updates the model and need to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(
                hooks.PeriodicCheckpointer(self.checkpointer,
                                           cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            # Keep the latest results so that train() can verify and return them.
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers()))

        return ret

    def build_writers(self):
        """
        Build a list of writers to be used. By default it contains
        writers that write metrics to the screen,
        a json file, and a tensorboard event file respectively.
        If you'd like a different list of writers, you can overwrite it in
        your trainer.

        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.

        It is now implemented by:

        .. code-block:: python

            return [
                CommonMetricPrinter(self.max_iter),
                JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
                TensorboardXWriter(self.cfg.OUTPUT_DIR),
            ]

        """
        # Assume the default print/log frequency.
        return [
            # It may not always print what you want to see, since it prints "common" metrics only.
            CommonMetricPrinter(self.max_iter),
            JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(self.cfg.OUTPUT_DIR),
        ]

    def train(self):
        """
        Run training from ``self.start_iter`` to ``self.max_iter``.

        Returns:
            OrderedDict of results, if evaluation is enabled. Otherwise None.
            Results are only verified and returned on the main process.
        """
        super().train(self.start_iter, self.max_iter)
        if hasattr(self, "_last_eval_results") and comm.is_main_process():
            verify_results(self.cfg, self._last_eval_results)
            return self._last_eval_results

    @classmethod
    def build_model(cls, cfg):
        """
        Returns:
            torch.nn.Module:

        It now calls :func:`detectron2.modeling.build_model`.
        Overwrite it if you'd like a different model.

        If ``cfg.MODEL.ROI_MASK_HEAD.RECON_NET.LOAD_CODEBOOK`` is set and a
        ``<dataset>_codebook.npy`` file exists, the recon net weights and the
        codebook are loaded into ``model.roi_heads.recon_net``. ``<dataset>`` is
        the part of the first test dataset name before the first underscore.
        """
        model = build_model(cfg)
        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))

        dataset_name = cfg.DATASETS.TEST[0].split("_")[0]
        # NOTE(review): these files are looked up relative to the current working
        # directory, while test() writes them under OUTPUT_DIR -- confirm this is intended.
        if os.path.exists(
                "{}_codebook.npy".format(dataset_name)
        ) and cfg.MODEL.ROI_MASK_HEAD.RECON_NET.LOAD_CODEBOOK:
            logger.info("Loading recon net and codebook")
            model.roi_heads.recon_net.load_state_dict(
                torch.load("{}_recon_net.pth".format(dataset_name)))
            # ``[()]`` unwraps the object that np.save stored as a 0-d array.
            model.roi_heads.recon_net.vector_dict = np.load(
                "{}_codebook.npy".format(dataset_name))[()]
        return model

    @classmethod
    def build_optimizer(cls, cfg, model, ty_opt=None):
        """
        Args:
            ty_opt: optimizer type selector, forwarded unchanged to the
                module-level ``build_optimizer``.

        Returns:
            torch.optim.Optimizer:

        It now calls :func:`detectron2.solver.build_optimizer`.
        Overwrite it if you'd like a different optimizer.
        """
        return build_optimizer(cfg, model, ty_opt)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It now calls :func:`detectron2.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_train_loader`,
        with an :class:`AmodalDatasetMapper` when
        ``cfg.DATALOADER.MAPPER == "amodal_and_visible"``.
        Overwrite it if you'd like a different data loader.
        """
        if cfg.DATALOADER.MAPPER == "amodal_and_visible":
            return build_detection_train_loader(cfg,
                                                mapper=AmodalDatasetMapper(
                                                    cfg, is_train=True))
        return build_detection_train_loader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_test_loader`.
        Overwrite it if you'd like a different data loader.
        """
        if cfg.DATALOADER.MAPPER == "amodal_and_visible":
            # NOTE(review): dataset_name is not forwarded on this path, unlike the
            # default path below -- confirm the loader builder does not need it.
            return build_detection_test_loader(cfg,
                                               mapper=AmodalDatasetMapper(
                                                   cfg, is_train=False))
        return build_detection_test_loader(cfg, dataset_name)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        """
        Returns:
            DatasetEvaluator

        It is not implemented by default.
        """
        raise NotImplementedError(
            "Please either implement `build_evaluator()` in subclasses, or pass "
            "your evaluator as arguments to `DefaultTrainer.test()`.")

    @classmethod
    def test(cls, cfg, model, evaluators=None):
        """
        Evaluate ``model`` on every dataset in ``cfg.DATASETS.TEST``.

        Before evaluating, when a recon net is configured and its codebook is
        not going to be loaded from disk, embeddings are collected over
        ``cfg.DATASETS.TRAIN``, clustered, and the recon net weights and
        codebook are saved under ``model._cfg.OUTPUT_DIR``.

        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                `cfg.DATASETS.TEST`.

        Returns:
            dict: a dict of result metrics
        """
        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]
        if evaluators is not None:
            assert len(
                cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                    len(cfg.DATASETS.TEST), len(evaluators))

        logger.info("Version:{}".format(cfg.OUTPUT.TRAIN_VERSION))
        results = OrderedDict()

        dataset = cfg.DATASETS.TEST[0].split("_")[0]
        codebook_unavailable = (
            not cfg.MODEL.ROI_MASK_HEAD.RECON_NET.LOAD_CODEBOOK
            or not os.path.exists("{}_codebook.npy".format(dataset)))
        if codebook_unavailable and cfg.MODEL.ROI_MASK_HEAD.RECON_NET.NAME != "":
            # Build the codebook from the training datasets before evaluating.
            for dataset_name in cfg.DATASETS.TRAIN:
                embedding_dataloader = cls.build_test_loader(cfg, dataset_name)
                model = embedding_inference_on_train_dataset(
                    model, embedding_dataloader)
            logger.info("Start KMEANS clustering")
            model.roi_heads.recon_net.cluster()
            logger.info("KMEANS clustering has finished")
            torch.save(
                model.roi_heads.recon_net.state_dict(),
                "{}/{}_recon_net.pth".format(model._cfg.OUTPUT_DIR, dataset))
            logger.info("Recon net saved")
            np.save(
                '{}/{}_codebook.npy'.format(model._cfg.OUTPUT_DIR, dataset),
                model.roi_heads.recon_net.vector_dict)

        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader = cls.build_test_loader(cfg, dataset_name)
            # When evaluators are passed in as arguments,
            # implicitly assume that evaluators can be created before data_loader.
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, dataset_name)
                except NotImplementedError:
                    # Logger.warn is a deprecated alias; use warning().
                    logger.warning(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method.")
                    results[dataset_name] = {}
                    continue

            results_i = inference_on_dataset(model, data_loader, evaluator)
            results[dataset_name] = results_i
            if comm.is_main_process():
                assert isinstance(
                    results_i, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results_i)
                logger.info("Evaluation results for {} in csv format:".format(
                    dataset_name))

                print_csv_format(results_i)

        # With a single test dataset, return its metrics directly instead of a
        # one-entry mapping keyed by dataset name.
        if len(results) == 1:
            results = list(results.values())[0]

        return results
Example #15
0
def do_train(cfg, model, resume=False):
    """
    Run a plain training loop that checkpoints to two locations.

    Besides the regular checkpoints in ``cfg.OUTPUT_DIR``, a second periodic
    checkpointer writes to ``/opt/ml/checkpoints``, and the raw model weights
    are additionally saved to ``cfg.OUTPUT_DIR`` every 500 iterations.

    Args:
        cfg: config node. This function reads ``OUTPUT_DIR``, ``MODEL.WEIGHTS``,
            ``SOLVER.MAX_ITER``, ``SOLVER.CHECKPOINT_PERIOD`` and
            ``TEST.EVAL_PERIOD`` directly and forwards the whole config to the
            optimizer / scheduler / data-loader builders and to ``do_test``.
        model: the module to train. It is called with one batch and must return
            a mapping whose values are summed into the total loss.
        resume (bool): forwarded to ``checkpointer.resume_or_load``.
    """
    model.train()
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    checkpointer = DetectionCheckpointer(model,
                                         cfg.OUTPUT_DIR,
                                         optimizer=optimizer,
                                         scheduler=scheduler)
    # Second checkpointer over the same objects, writing to a fixed directory.
    checkpointer_spot = DetectionCheckpointer(model,
                                              '/opt/ml/checkpoints',
                                              optimizer=optimizer,
                                              scheduler=scheduler)
    # Only the OUTPUT_DIR checkpointer is consulted for resuming. Continue from
    # the iteration after the one it recorded (0 when none is recorded).
    start_iter = (checkpointer.resume_or_load(
        cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1)
    max_iter = cfg.SOLVER.MAX_ITER

    periodic_checkpointer = PeriodicCheckpointer(checkpointer,
                                                 cfg.SOLVER.CHECKPOINT_PERIOD,
                                                 max_iter=max_iter)
    periodic_checkpointer_spot = PeriodicCheckpointer(
        checkpointer_spot, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter)

    writers = ([
        CommonMetricPrinter(max_iter),
        JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
        TensorboardXWriter(cfg.OUTPUT_DIR),
    ] if comm.is_main_process() else [])

    # compared to "train_net.py", we do not support accurate timing and
    # precise BN here, because they are not trivial to implement
    data_loader = build_detection_train_loader(cfg)
    logger.info("Starting training from iteration {}".format(start_iter))
    with EventStorage(start_iter) as storage:
        for data, iteration in zip(data_loader, range(start_iter, max_iter)):
            # From here on ``iteration`` is 1-based: the number of the step being run.
            iteration = iteration + 1
            storage.step()

            loss_dict = model(data)
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            # Plain-float copies of the losses, used only for logging.
            loss_dict_reduced = {
                k: v.item()
                for k, v in comm.reduce_dict(loss_dict).items()
            }
            losses_reduced = sum(loss_dict_reduced.values())
            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced,
                                    **loss_dict_reduced)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            storage.put_scalar("lr",
                               optimizer.param_groups[0]["lr"],
                               smoothing_hint=False)
            scheduler.step()

            if (cfg.TEST.EVAL_PERIOD > 0
                    and iteration % cfg.TEST.EVAL_PERIOD == 0
                    and iteration != max_iter):
                do_test(cfg, model)
                # Compared to "train_net.py", the test results are not dumped to EventStorage
                comm.synchronize()

            if iteration % 500 == 0:
                # Best-effort extra snapshot of the raw weights: a failure here must
                # not stop training, but it must not be silent either. Catching
                # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
                # still terminate the run.
                # NOTE(review): there is no main-process guard here, so every rank
                # writes the same file -- confirm this is intended.
                snapshot_path = f'{cfg.OUTPUT_DIR}/model_{iteration}.pth'
                try:
                    torch.save(model.state_dict(), snapshot_path)
                except Exception:
                    logger.exception("save failed: %s", snapshot_path)

            # Skip the first few iterations after (re)start, then write every 20
            # iterations and on the last one.
            if iteration - start_iter > 5 and (iteration % 20 == 0
                                               or iteration == max_iter):
                for writer in writers:
                    writer.write()
            periodic_checkpointer.step(iteration)
            periodic_checkpointer_spot.step(iteration)
# Script-level training setup: prepare the output directory and logger, then
# build the model, optimizer, scheduler and checkpointers from ``cfg``.
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
logger = logging.getLogger("frcn")
if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
    setup_logger()

model = build_model(cfg)
logger.info("Model:\n{}".format(model))
model.train()
optimizer = build_optimizer(cfg, model)
scheduler = build_lr_scheduler(cfg, optimizer)

checkpointer = DetectionCheckpointer(
    model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler
)
# Always loads cfg.MODEL.WEIGHTS (resume=False). The start iteration is one past
# the "iteration" entry of the loaded checkpoint, or 0 when there is none.
start_iter = (
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=False).get("iteration", -1) + 1
)
max_iter = cfg.SOLVER.MAX_ITER

# Saves through ``checkpointer`` at the configured period.
periodic_checkpointer = PeriodicCheckpointer(
    checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
)

writers = (
    [
        CommonMetricPrinter(max_iter),
        JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
        TensorboardXWriter(cfg.OUTPUT_DIR),
    ]
    if comm.is_main_process()
    else []
Example #17
0
def do_relation_train(cfg, model, resume=False):
    model.train()
    for param in model.named_parameters():
        param[1].requires_grad = False
    for param in model.named_parameters():
        for trainable in cfg.MODEL.TRAINABLE:
            if param[0].startswith(trainable):
                param[1].requires_grad = True
                break

        if param[0] == "relation_heads.instance_head.semantic_embed.weight" or \
            param[0] == "relation_heads.pair_head.semantic_embed.weight" or \
            param[0] == "relation_heads.predicate_head.semantic_embed.weight" or \
            param[0] == "relation_heads.triplet_head.ins_embed.weight" or \
            param[0] == "relation_heads.triplet_head.pred_embed.weight" or \
            param[0] == "relation_heads.subpred_head.sub_embed.weight" or \
            param[0] == "relation_heads.subpred_head.pred_embed.weight" or \
            param[0] == "relation_heads.predobj_head.pred_embed.weight" or \
            param[0] == "relation_heads.predobj_head.obj_embed.weight" or \
            param[0].startswith("relation_heads.predicate_head.freq_bias.obj_baseline.weight"):
            param[1].requires_grad = False

    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)
    metrics_sum_dict = {
        'relation_cls_tp_sum': 0,
        'relation_cls_p_sum': 0.00001,
        'pred_class_tp_sum': 0,
        'pred_class_p_sum': 0.00001,
        'gt_class_tp_sum': 0,
        'gt_class_p_sum': 0.00001,
        'raw_pred_class_tp_sum': 0,
        'raw_pred_class_p_sum': 0.00001,
        'instance_tp_sum':0,
        'instance_p_sum': 0.00001,
        'instance_g_sum':0.00001,
        'subpred_tp_sum': 0,
        'subpred_p_sum': 0.00001,
        'subpred_g_sum': 0.00001,
        'predobj_tp_sum': 0,
        'predobj_p_sum': 0.00001,
        'predobj_g_sum': 0.00001,
        'pair_tp_sum':0,
        'pair_p_sum': 0.00001,
        'pair_g_sum':0.00001,
        'confidence_tp_sum': 0,
        'confidence_p_sum': 0.00001,
        'confidence_g_sum': 0.00001,
        'predicate_tp_sum': 0,
        'predicate_tp20_sum': 0,
        'predicate_tp50_sum': 0,
        'predicate_tp100_sum': 0,
        'predicate_p_sum': 0.00001,
        'predicate_p20_sum': 0.00001,
        'predicate_p50_sum': 0.00001,
        'predicate_p100_sum': 0.00001,
        'predicate_g_sum': 0.00001,
        'triplet_tp_sum': 0,
        'triplet_tp20_sum': 0,
        'triplet_tp50_sum': 0,
        'triplet_tp100_sum': 0,
        'triplet_p_sum': 0.00001,
        'triplet_p20_sum': 0.00001,
        'triplet_p50_sum': 0.00001,
        'triplet_p100_sum': 0.00001,
        'triplet_g_sum': 0.00001,
    }
    checkpointer = DetectionCheckpointer(
        model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler, metrics_sum_dict=metrics_sum_dict
    )
    start_iter = (checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1)
    # state_dict=torch.load(cfg.MODEL.WEIGHTS).pop("model")
    # model.load_state_dict(state_dict,strict=False)
    max_iter = cfg.SOLVER.MAX_ITER

    periodic_checkpointer = PeriodicCheckpointer(checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter)

    # relation_cls_state_dict=torch.load(cfg.MODEL.WEIGHTS).pop("model")
    # for param in model.named_parameters():
    #     if param[0] not in relation_cls_state_dict:
    #         print(param[0])
    # model.load_state_dict(relation_cls_state_dict,strict=False)

    writers = (
        [
            CommonMetricPrinter(max_iter),
            JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(cfg.OUTPUT_DIR),
        ]
        if comm.is_main_process()
        else []
    )
    metrics_pr_dict={}
    # compared to "train_net.py", we do not support accurate timing and
    # precise BN here, because they are not trivial to implement
    data_loader = build_detection_train_loader(cfg)
    logger.info("Starting training from iteration {}".format(start_iter))
    acumulate_losses=0
    with EventStorage(start_iter) as storage:
        for data, iteration in zip(data_loader, range(start_iter, max_iter)):
            print(iteration)
            iteration = iteration + 1
            storage.step()
            if True:
            # try:
                pred_instances, results_dict, losses_dict, metrics_dict = model(data,iteration,mode="relation",training=True)
                losses = sum(loss for loss in losses_dict.values())
                assert torch.isfinite(losses).all(), losses_dict
                #print(losses_dict)

                loss_dict_reduced = {k: v.item() for k, v in comm.reduce_dict(losses_dict).items()}
                losses_reduced = sum(loss for loss in loss_dict_reduced.values())
                acumulate_losses += losses_reduced
                if comm.is_main_process():
                    storage.put_scalars(acumulate_losses=acumulate_losses/(iteration-start_iter),total_loss=losses_reduced, **loss_dict_reduced)

                if 'relation_cls_tp' in metrics_dict:
                    metrics_sum_dict['relation_cls_tp_sum']+=metrics_dict['relation_cls_tp']
                    metrics_sum_dict['relation_cls_p_sum'] += metrics_dict['relation_cls_p']
                    metrics_pr_dict['relation_cls_precision'] = metrics_sum_dict['relation_cls_tp_sum'] / metrics_sum_dict['relation_cls_p_sum']
                if 'pred_class_tp' in metrics_dict:
                    metrics_sum_dict['pred_class_tp_sum']+=metrics_dict['pred_class_tp']
                    metrics_sum_dict['pred_class_p_sum'] += metrics_dict['pred_class_p']
                    metrics_pr_dict['pred_class_precision'] = metrics_sum_dict['pred_class_tp_sum'] / metrics_sum_dict['pred_class_p_sum']
                if 'raw_pred_class_tp' in metrics_dict:
                    metrics_sum_dict['raw_pred_class_tp_sum']+=metrics_dict['raw_pred_class_tp']
                    metrics_sum_dict['raw_pred_class_p_sum'] += metrics_dict['raw_pred_class_p']
                    metrics_pr_dict['raw_pred_class_precision'] = metrics_sum_dict['raw_pred_class_tp_sum'] / metrics_sum_dict['raw_pred_class_p_sum']
                if 'gt_class_tp' in metrics_dict:
                    metrics_sum_dict['gt_class_tp_sum']+=metrics_dict['gt_class_tp']
                    metrics_sum_dict['gt_class_p_sum'] += metrics_dict['gt_class_p']
                    metrics_pr_dict['gt_class_precision'] = metrics_sum_dict['gt_class_tp_sum'] / metrics_sum_dict['gt_class_p_sum']
                if 'instance_tp' in metrics_dict:
                    metrics_sum_dict['instance_tp_sum']+=metrics_dict['instance_tp']
                    metrics_sum_dict['instance_p_sum'] += metrics_dict['instance_p']
                    metrics_sum_dict['instance_g_sum'] += metrics_dict['instance_g']
                    metrics_pr_dict['instance_precision'] = metrics_sum_dict['instance_tp_sum'] / metrics_sum_dict['instance_p_sum']
                    metrics_pr_dict['instance_recall'] = metrics_sum_dict['instance_tp_sum'] / metrics_sum_dict['instance_g_sum']
                if 'subpred_tp' in metrics_dict:
                    metrics_sum_dict['subpred_tp_sum']+=metrics_dict['subpred_tp']
                    metrics_sum_dict['subpred_p_sum'] += metrics_dict['subpred_p']
                    metrics_sum_dict['subpred_g_sum'] += metrics_dict['subpred_g']
                    metrics_pr_dict['subpred_precision'] = metrics_sum_dict['subpred_tp_sum'] / metrics_sum_dict['subpred_p_sum']
                    metrics_pr_dict['subpred_recall'] = metrics_sum_dict['subpred_tp_sum'] / metrics_sum_dict['subpred_g_sum']
                if 'predobj_tp' in metrics_dict:
                    metrics_sum_dict['predobj_tp_sum']+=metrics_dict['predobj_tp']
                    metrics_sum_dict['predobj_p_sum'] += metrics_dict['predobj_p']
                    metrics_sum_dict['predobj_g_sum'] += metrics_dict['predobj_g']
                    metrics_pr_dict['predobj_precision'] = metrics_sum_dict['predobj_tp_sum'] / metrics_sum_dict['predobj_p_sum']
                    metrics_pr_dict['predobj_recall'] = metrics_sum_dict['predobj_tp_sum'] / metrics_sum_dict['predobj_g_sum']

                if 'pair_tp' in metrics_dict:
                    metrics_sum_dict['pair_tp_sum'] += metrics_dict['pair_tp']
                    metrics_sum_dict['pair_p_sum'] += metrics_dict['pair_p']
                    metrics_sum_dict['pair_g_sum'] += metrics_dict['pair_g']
                    metrics_pr_dict['pair_precision'] = metrics_sum_dict['pair_tp_sum'] / metrics_sum_dict['pair_p_sum']
                    metrics_pr_dict['pair_recall'] = metrics_sum_dict['pair_tp_sum'] / metrics_sum_dict['pair_g_sum']
                if 'confidence_tp' in metrics_dict:
                    metrics_sum_dict['confidence_tp_sum']+=metrics_dict['confidence_tp']
                    metrics_sum_dict['confidence_p_sum'] += metrics_dict['confidence_p']
                    metrics_sum_dict['confidence_g_sum'] += metrics_dict['confidence_g']
                    metrics_pr_dict['confidence_precision'] = metrics_sum_dict['confidence_tp_sum'] / metrics_sum_dict['confidence_p_sum']
                    metrics_pr_dict['confidence_recall'] = metrics_sum_dict['confidence_tp_sum'] / metrics_sum_dict['confidence_g_sum']
                if 'predicate_tp' in metrics_dict:
                    metrics_sum_dict['predicate_tp_sum']+=metrics_dict['predicate_tp']
                    metrics_sum_dict['predicate_tp20_sum'] += metrics_dict['predicate_tp20']
                    metrics_sum_dict['predicate_tp50_sum'] += metrics_dict['predicate_tp50']
                    metrics_sum_dict['predicate_tp100_sum'] += metrics_dict['predicate_tp100']
                    metrics_sum_dict['predicate_p_sum'] += metrics_dict['predicate_p']
                    metrics_sum_dict['predicate_p20_sum'] += metrics_dict['predicate_p20']
                    metrics_sum_dict['predicate_p50_sum'] += metrics_dict['predicate_p50']
                    metrics_sum_dict['predicate_p100_sum'] += metrics_dict['predicate_p100']
                    metrics_sum_dict['predicate_g_sum'] += metrics_dict['predicate_g']
                    metrics_pr_dict['predicate_precision'] = metrics_sum_dict['predicate_tp_sum'] / metrics_sum_dict['predicate_p_sum']
                    metrics_pr_dict['predicate_precision20'] = metrics_sum_dict['predicate_tp20_sum'] / metrics_sum_dict['predicate_p20_sum']
                    metrics_pr_dict['predicate_precision50'] = metrics_sum_dict['predicate_tp50_sum'] / metrics_sum_dict['predicate_p50_sum']
                    metrics_pr_dict['predicate_precision100'] = metrics_sum_dict['predicate_tp100_sum'] / metrics_sum_dict['predicate_p100_sum']
                    metrics_pr_dict['predicate_recall'] = metrics_sum_dict['predicate_tp_sum'] / metrics_sum_dict['predicate_g_sum']
                    metrics_pr_dict['predicate_recall20'] = metrics_sum_dict['predicate_tp20_sum'] / metrics_sum_dict['predicate_g_sum']
                    metrics_pr_dict['predicate_recall50'] = metrics_sum_dict['predicate_tp50_sum'] / metrics_sum_dict['predicate_g_sum']
                    metrics_pr_dict['predicate_recall100'] = metrics_sum_dict['predicate_tp100_sum'] / metrics_sum_dict['predicate_g_sum']
                if 'triplet_tp' in metrics_dict:
                    metrics_sum_dict['triplet_tp_sum'] += metrics_dict['triplet_tp']
                    metrics_sum_dict['triplet_tp20_sum'] += metrics_dict['triplet_tp20']
                    metrics_sum_dict['triplet_tp50_sum'] += metrics_dict['triplet_tp50']
                    metrics_sum_dict['triplet_tp100_sum'] += metrics_dict['triplet_tp100']
                    metrics_sum_dict['triplet_p_sum'] += metrics_dict['triplet_p']
                    metrics_sum_dict['triplet_p20_sum'] += metrics_dict['triplet_p20']
                    metrics_sum_dict['triplet_p50_sum'] += metrics_dict['triplet_p50']
                    metrics_sum_dict['triplet_p100_sum'] += metrics_dict['triplet_p100']
                    metrics_sum_dict['triplet_g_sum'] += metrics_dict['triplet_g']
                    metrics_pr_dict['triplet_precision'] = metrics_sum_dict['triplet_tp_sum'] / metrics_sum_dict['triplet_p_sum']
                    metrics_pr_dict['triplet_precision20'] = metrics_sum_dict['triplet_tp20_sum'] / metrics_sum_dict['triplet_p20_sum']
                    metrics_pr_dict['triplet_precision50'] = metrics_sum_dict['triplet_tp50_sum'] / metrics_sum_dict['triplet_p50_sum']
                    metrics_pr_dict['triplet_precision100'] = metrics_sum_dict['triplet_tp100_sum'] / metrics_sum_dict['triplet_p100_sum']
                    metrics_pr_dict['triplet_recall'] = metrics_sum_dict['triplet_tp_sum'] / metrics_sum_dict['triplet_g_sum']
                    metrics_pr_dict['triplet_recall20'] = metrics_sum_dict['triplet_tp20_sum'] / metrics_sum_dict['triplet_g_sum']
                    metrics_pr_dict['triplet_recall50'] = metrics_sum_dict['triplet_tp50_sum'] / metrics_sum_dict['triplet_g_sum']
                    metrics_pr_dict['triplet_recall100'] = metrics_sum_dict['triplet_tp100_sum'] / metrics_sum_dict['triplet_g_sum']

                storage.put_scalars(**metrics_pr_dict, smoothing_hint=False)

                optimizer.zero_grad()
                losses.backward()
                optimizer.step()
                storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
                scheduler.step()

                if iteration - start_iter > 5 and (iteration % 20 == 0 or iteration == max_iter):
                    for writer in writers:
                        writer.write()
                periodic_checkpointer.step(iteration)
                torch.cuda.empty_cache()
Example #18
0
class LiuyCoCoTrainer(SimpleTrainer):
    """
    A :class:`SimpleTrainer` that assembles its optimizer, LR scheduler,
    checkpointer and hooks from a config, and that can be pointed at a new
    model through :meth:`reset_model` while keeping its data loader.

    Attributes set by ``__init__``:
        data_loader: the training data loader.
        data_len: dataset size value used to scale ``max_iter``.
        scheduler: LR scheduler built from the config.
        checkpointer: ``DetectionCheckpointer`` writing to ``cfg.OUTPUT_DIR``.
        start_iter (int): always 0 (see :meth:`resume_or_load`).
        max_iter (int): ``int(270000 * data_len / 45174)``.
        cfg: the config the trainer was built from.
    """

    def __init__(self, cfg, model=None, data_loader=None):
        """
        Args:
            cfg (CfgNode): config used to build every missing component.
            model (nn.Module or None): model to train. When None it is built
                from ``cfg`` with :meth:`build_model`.
            data_loader: training data loader. When None it is built from
                ``cfg`` with :meth:`build_train_loader`.
        """
        if model is None:
            model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        if data_loader is not None:
            self.data_loader = data_loader
            # NOTE(review): `data_len` is later multiplied and divided as a number
            # when computing `max_iter`; if `_lst` is the underlying list of dataset
            # records this probably wants `len(...)` -- confirm against the loader.
            self.data_len = data_loader.dataset._dataset._lst
        else:
            # The project's loader builder is expected to return (loader, length).
            self.data_loader, self.data_len = self.build_train_loader(cfg)
        # For training, wrap with DDP. But don't need this for inference.
        if comm.get_world_size() > 1:
            model = DistributedDataParallel(
                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
            )
        super().__init__(model, self.data_loader, optimizer)

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        self.start_iter = 0
        # Scale the iteration budget linearly with the dataset size.
        # NOTE(review): 270000 / 45174 look like a reference schedule length and a
        # reference dataset size -- confirm their origin.
        self.max_iter = int((270000 * self.data_len) / 45174)
        self.cfg = cfg
        self.register_hooks(self.build_hooks())

    def reset_model(self, cfg, model):
        """
        Replace the model and rebuild everything that depends on it.

        The optimizer, LR scheduler, checkpointer and hooks are rebuilt for
        ``model``; the data loader and ``max_iter`` are left untouched and
        ``start_iter`` is reset to 0.

        Args:
            cfg (CfgNode): config used to rebuild the dependent components.
            model (nn.Module): the new model to train.
        """
        if comm.get_world_size() > 1:
            model = DistributedDataParallel(
                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
            )
        # Drop the previous objects before binding the new ones so the old
        # references are released first.
        del self.model
        self.model = model

        optimizer = self.build_optimizer(cfg, model)
        del self.optimizer
        self.optimizer = optimizer

        scheduler = self.build_lr_scheduler(cfg, optimizer)
        del self.scheduler
        self.scheduler = scheduler

        checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        del self.checkpointer
        self.checkpointer = checkpointer

        self.start_iter = 0
        # `max_iter` deliberately keeps the value computed in `__init__`.
        self.cfg = cfg

        # Hooks hold references to the old model/optimizer, so rebuild them all.
        self._hooks = []
        self.register_hooks(self.build_hooks())

    def resume_or_load(self, resume=True):
        """
        Load weights through the checkpointer, then restart iteration counting.

        If ``resume`` is True and a last checkpoint exists, the checkpointer
        resumes from it; otherwise it loads ``cfg.MODEL.WEIGHTS``. Whatever
        iteration the checkpoint recorded is ignored: ``start_iter`` is always
        set to 0 afterwards.

        Args:
            resume (bool): whether to do resume or not
        """
        self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
        self.start_iter = 0

    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.

        Returns:
            list[HookBase]: may contain a ``None`` entry where PreciseBN is
            disabled.
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
            hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                # NOTE(review): `__init__` unpacks build_train_loader() as a
                # (loader, length) pair; here the raw result is passed on -- confirm.
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None,
        ]

        # Do PreciseBN before checkpointer, because it updates the model and need to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            # Remember the latest results so `train()` can return them.
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers()))
        return ret

    def build_writers(self):
        """
        Build the list of writers to be used: one that prints common metrics
        to the screen and one that writes a tensorboard event file under
        ``cfg.OUTPUT_DIR``. No JSON metrics file is written.

        If you'd like a different list of writers, you can overwrite it in
        your trainer.

        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.
        """
        # Assume the default print/log frequency.
        return [
            # It may not always print what you want to see, since it prints "common" metrics only.
            CommonMetricPrinter(self.max_iter),
            TensorboardXWriter(self.cfg.OUTPUT_DIR),
        ]

    def train(self):
        """
        Run training from ``start_iter`` to ``max_iter``.

        Returns:
            The last evaluation results on the main process if evaluation
            ran at least once (after verifying them against the config).
            Otherwise None.
        """

        super().train(self.start_iter, self.max_iter)
        if hasattr(self, "_last_eval_results") and comm.is_main_process():
            verify_results(self.cfg, self._last_eval_results)
            return self._last_eval_results

    def train_on_single_data(self, base_model):
        """
        Placeholder; not implemented. Does nothing and returns None.
        """

    @classmethod
    def build_model(cls, cfg):
        """
        Returns:
            torch.nn.Module:
        It now calls :func:`detectron2.modeling.build_model` and logs the model.
        Overwrite it if you'd like a different model.
        """
        model = build_model(cfg)
        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))
        return model

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Returns:
            torch.optim.Optimizer:
        It now calls :func:`detectron2.solver.build_optimizer`.
        Overwrite it if you'd like a different optimizer.
        """
        return build_optimizer(cfg, model)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It now calls :func:`detectron2.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Returns:
            Whatever ``build_detection_train_loader(cfg)`` returns; ``__init__``
            unpacks it as a ``(data_loader, data_len)`` pair.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_train_loader(cfg)

    def re_build_train_loader(self, dataset_name, images_per_batch=2):
        """
        Point the config at a single training dataset and build a loader for it.

        Note that this mutates ``self.cfg.DATASETS.TRAIN`` in place.

        Args:
            dataset_name (str): the only dataset to train on from now on.
            images_per_batch (int): forwarded as the second positional argument
                of ``build_detection_train_loader``.

        Returns:
            The result of ``build_detection_train_loader``.
        """
        self.cfg.DATASETS.TRAIN = [dataset_name]
        return build_detection_train_loader(self.cfg, images_per_batch)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Returns:
            Whatever ``build_detection_test_loader(cfg, dataset_name)`` returns;
            :meth:`test` unpacks it as a ``(data_loader, data_len)`` pair.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_test_loader(cfg, dataset_name)

    @classmethod
    def test(cls, cfg, model, evaluators=None):
        """
        Evaluate ``model`` on every dataset in ``cfg.DATASETS.TEST``.

        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                `cfg.DATASETS.TEST`.

        Returns:
            OrderedDict mapping dataset name to its results, or just the
            results themselves when exactly one dataset was evaluated.
            Datasets without an evaluator map to an empty dict.
        """
        logger = logging.getLogger(__name__)
        logger.info("test")
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]
        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader, data_len = cls.build_test_loader(cfg, dataset_name)
            # When evaluators are passed in as arguments,
            # implicitly assume that evaluators can be created before data_loader.
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, dataset_name)
                except NotImplementedError:
                    # `Logger.warn` is a deprecated alias; use `warning`.
                    logger.warning(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {}
                    continue
            results_i = inference_on_dataset(model, data_loader, evaluator)
            results[dataset_name] = results_i
            if comm.is_main_process():
                assert isinstance(
                    results_i, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results_i
                )
                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
                print(results_i)
        # Unwrap the single-dataset case so callers get the results directly.
        if len(results) == 1:
            results = list(results.values())[0]
        return results

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """
        Create evaluator(s) for a given dataset.
        This uses the special metadata "evaluator_type" associated with each builtin dataset.
        For your own dataset, you can simply create an evaluator manually in your
        script and do not have to worry about the hacky if-else logic here.

        Args:
            cfg (CfgNode):
            dataset_name (str): name looked up in ``MetadataCatalog``.
            output_folder (str or None): where evaluators write their output;
                defaults to ``<cfg.OUTPUT_DIR>/inference``.

        Returns:
            A single evaluator, or a ``DatasetEvaluators`` wrapping several.

        Raises:
            NotImplementedError: if no evaluator matches the dataset's type.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluator_list = []
        evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
        if evaluator_type in ["sem_seg", "coco_panoptic_seg"]:
            evaluator_list.append(
                SemSegEvaluator(
                    dataset_name,
                    distributed=True,
                    num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
                    ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
                    output_dir=output_folder,
                )
            )
        if evaluator_type in ["coco", "coco_panoptic_seg"]:
            # COCO-style datasets get both the stock and the project evaluator.
            evaluator_list.append(COCOEvaluator(dataset_name, cfg, True, output_folder))
            evaluator_list.append(Liuy_COCOEvaluator(dataset_name, cfg, True, output_folder))
        if evaluator_type == "coco_panoptic_seg":
            evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))
        if evaluator_type == "cityscapes":
            assert (
                    torch.cuda.device_count() >= comm.get_rank()
            ), "CityscapesEvaluator currently do not work with multiple machines."
            # Uses the project's Cityscapes evaluator instead of detectron2's.
            return liuy.utils.liuy_cityscapes_evaluation.CityscapesEvaluator(dataset_name)

        if evaluator_type == "pascal_voc":
            return PascalVOCDetectionEvaluator(dataset_name)
        if evaluator_type == "lvis":
            return LVISEvaluator(dataset_name, cfg, True, output_folder)
        if len(evaluator_list) == 0:
            raise NotImplementedError(
                "no Evaluator for the dataset {} with the type {}".format(
                    dataset_name, evaluator_type
                )
            )
        if len(evaluator_list) == 1:
            return evaluator_list[0]
        return DatasetEvaluators(evaluator_list)

    @classmethod
    def test_with_TTA(cls, cfg, model):
        """
        Evaluate ``model`` wrapped in ``GeneralizedRCNNWithTTA``.

        Evaluators write to ``<cfg.OUTPUT_DIR>/inference_TTA``.

        Returns:
            OrderedDict: the results of :meth:`test` with ``"_TTA"`` appended
            to every key.
        """
        logger = logging.getLogger("detectron2.trainer")
        # In the end of training, run an evaluation with TTA
        # Only support some R-CNN models.
        logger.info("Running inference with test-time augmentation ...")
        model = GeneralizedRCNNWithTTA(cfg, model)
        evaluators = [
            cls.build_evaluator(
                cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
            )
            for name in cfg.DATASETS.TEST
        ]
        res = cls.test(cfg, model, evaluators)
        res = OrderedDict({k + "_TTA": v for k, v in res.items()})
        return res
def do_train(cfg, model, cat_heatmap_file, resume=False):
    """
    Train ``model`` while tracking training and validation metrics separately.

    Each iteration runs one optimizer step on a training batch. Every 20
    iterations the loss on one validation batch is recorded, and every
    ``cfg.TEST.EVAL_PERIOD`` iterations both the training and validation
    sets are evaluated. The checkpoint with the best validation AP50
    ("segm" results, falling back to "bbox") is saved as ``best_AP50``.
    Training stops early if a file named ``stop_running`` appears in the
    current working directory; the file is removed when detected.

    Args:
        cfg (CfgNode): config; the first entries of ``DATASETS.TRAIN`` and
            ``DATASETS.TEST`` are used as training and validation sets.
        model (nn.Module): model to train.
        cat_heatmap_file: forwarded unchanged to ``get_evaluator``.
        resume (bool): resume from the last checkpoint in ``cfg.OUTPUT_DIR``
            if one exists, otherwise load ``cfg.MODEL.WEIGHTS``.
    """
    model.train()

    # select optimizer and learning rate scheduler based on the config
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    # create checkpointer
    checkpointer = DetectionCheckpointer(
        model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler
    )
    # The checkpoint stores the iteration that just finished, so start at the
    # next one (or at 0 when there is no checkpoint).
    start_iter = (
        checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1
    )
    max_iter = cfg.SOLVER.MAX_ITER

    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
    )

    # create output writers. Separate TensorBoard writers are created
    # for train and validation sets. This allows easy overlaying of graphs
    # in TensorBoard.
    train_tb_dir = os.path.join(cfg.OUTPUT_DIR, 'train')
    val_tb_dir = os.path.join(cfg.OUTPUT_DIR, 'val')
    train_writers = (
        [
            CommonMetricPrinter(max_iter),
            JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(train_tb_dir),
        ]
        if comm.is_main_process()
        else []
    )
    val_writers = [TensorboardXWriter(val_tb_dir)]

    train_dataset_name = cfg.DATASETS.TRAIN[0]
    train_data_loader = build_detection_train_loader(cfg)
    train_eval_data_loader = build_detection_test_loader(cfg, train_dataset_name)
    val_dataset_name = cfg.DATASETS.TEST[0]
    # The validation loader uses a training-mode mapper; presumably so that
    # batches keep their annotations and the model can return losses -- confirm.
    val_eval_data_loader = build_detection_test_loader(cfg, val_dataset_name, DatasetMapper(cfg, True))
    logger.info("Starting training from iteration {}".format(start_iter))
    train_storage = EventStorage(start_iter)
    val_storage = EventStorage(start_iter)

    # Create the training and validation evaluator objects.
    train_evaluator = get_evaluator(
        cfg, train_dataset_name, os.path.join(cfg.OUTPUT_DIR, "train_inference", train_dataset_name),
        cat_heatmap_file
    )
    val_evaluator = get_evaluator(
        cfg, val_dataset_name, os.path.join(cfg.OUTPUT_DIR, "val_inference", val_dataset_name),
        cat_heatmap_file
    )

    # initialize the best AP50 value
    best_AP50 = 0
    start_time = time.time()
    for train_data, iteration in zip(train_data_loader, range(start_iter, max_iter)):
        # stop if the file stop_running exists in the running directory
        if os.path.isfile('stop_running'):
            os.remove('stop_running')
            break

        # Iterations are reported 1-based.
        iteration = iteration + 1

        # run a step with the training data
        with train_storage as storage:
            model.train()
            storage.step()

            loss_dict = model(train_data)
            losses = sum(loss for loss in loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            # Reduce across workers for logging only; the backward pass uses
            # the local `losses`.
            loss_dict_reduced = {k: v.item() for k, v in comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
            scheduler.step()

            # periodically evaluate the training set and write the results
            if (cfg.TEST.EVAL_PERIOD > 0
                    and iteration % cfg.TEST.EVAL_PERIOD == 0
                    and iteration != max_iter):
                train_eval_results = inference_on_dataset(model, train_eval_data_loader,
                                                          train_evaluator)
                flat_results = flatten_results(train_eval_results)
                storage.put_scalars(**flat_results)
                comm.synchronize()

            if iteration - start_iter > 5 and (iteration % 20 == 0 or iteration == max_iter):
                for writer in train_writers:
                    writer.write()
            periodic_checkpointer.step(iteration)

        # run a step with the validation set
        with val_storage as storage:
            storage.step()

            # every 20 iterations evaluate the dataset to collect the loss
            if iteration % 20 == 0 or iteration == max_iter:
                with torch.set_grad_enabled(False):
                    # Take a single validation batch. `range(1)` comes first so
                    # zip stops before pulling (and discarding) a second batch.
                    for _, val_data in zip(range(1), val_eval_data_loader):
                        loss_dict = model(val_data)
                        losses = sum(loss for loss in loss_dict.values())
                        assert torch.isfinite(losses).all(), loss_dict

                        loss_dict_reduced = {k: v.item() for k, v in comm.reduce_dict(loss_dict).items()}
                        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

                if comm.is_main_process():
                    storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced)

            # periodically evaluate the validation set and write the results
            # check the results against the best results seen and save the parameters for
            # the best result. The final iteration is always evaluated.
            if (cfg.TEST.EVAL_PERIOD > 0
                    and iteration % cfg.TEST.EVAL_PERIOD == 0
                    or iteration == max_iter):
                val_eval_results = inference_on_dataset(model, val_eval_data_loader,
                                                        val_evaluator)
                # logging formats lazily with %-style placeholders, not `{}`.
                logger.info('val_eval_results %s', val_eval_results)
                results = val_eval_results.get('segm', None)
                if results is None:
                    results = val_eval_results.get('bbox', None)
                if results is not None and results.get('AP50', -1) > best_AP50:
                    best_AP50 = results['AP50']
                    logger.info('saving best results ({}), iter {}'.format(best_AP50, iteration))
                    checkpointer.save("best_AP50")

                flat_results = flatten_results(val_eval_results)
                storage.put_scalars(**flat_results)
                comm.synchronize()

            if iteration - start_iter > 5 and (iteration % 20 == 0):
                for writer in val_writers:
                    writer.write()
                # Linear ETA estimate from the average time per iteration so far.
                elapsed = time.time() - start_time
                time_per_iter = elapsed / (iteration - start_iter)
                time_left = time_per_iter * (max_iter - iteration)
                logger.info("ETA: {}".format(str(datetime.timedelta(seconds=time_left))))
Example #20
0
def do_train(cfg_source, cfg_target, model, resume=False):
    """
    Train ``model`` on paired batches from a source and a target dataset.

    Every iteration the model is called once on a source batch (second
    argument ``False``) and once on a target batch (second argument
    ``True``). The ``loss_r3`` / ``loss_r4`` / ``loss_r5`` terms of both
    calls are averaged; all other losses come from the source call only.
    Three coefficients derived from a sigmoid-shaped ramp of training
    progress are passed to the model, capped at 0.5, 0.5 and 0.1.

    Args:
        cfg_source (CfgNode): config for the source data loader; also supplies
            the optimizer, scheduler, output directory, weights and max_iter.
        cfg_target (CfgNode): config used only to build the target data loader.
        model (nn.Module): model to train.
        resume (bool): resume from the last checkpoint if one exists,
            otherwise load ``cfg_source.MODEL.WEIGHTS``.
    """
    model.train()
    print(model)

    optimizer = build_optimizer(cfg_source, model)
    scheduler = build_lr_scheduler(cfg_source, optimizer)

    checkpointer = DetectionCheckpointer(
        model, cfg_source.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler
    )
    # Checkpoints record the last finished iteration; continue from the next.
    checkpoint = checkpointer.resume_or_load(cfg_source.MODEL.WEIGHTS, resume=resume)
    start_iter = checkpoint.get("iteration", -1) + 1
    max_iter = cfg_source.SOLVER.MAX_ITER

    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer, cfg_source.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
    )

    # Only the main process writes metrics.
    if comm.is_main_process():
        writers = [
            CommonMetricPrinter(max_iter),
            JSONWriter(os.path.join(cfg_source.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(cfg_source.OUTPUT_DIR),
        ]
    else:
        writers = []

    # Hard-coded schedule constants.
    max_epoch = 41.27  # max iter / min(data_len(data_source, data_target))
    data_len = 1502
    # NOTE(review): both counters start fresh even when `start_iter` > 0, so the
    # ramp restarts from zero after a resume -- confirm this is intended.
    current_epoch = 0
    step_in_epoch = 1

    # Keys whose source and target values are averaged.
    shared_loss_keys = ("loss_r3", "loss_r4", "loss_r5")

    data_loader_source = build_detection_train_loader(cfg_source)
    data_loader_target = build_detection_train_loader(cfg_target)
    logger.info("Starting training from iteration {}".format(start_iter))

    with EventStorage(start_iter) as storage:
        batches = zip(data_loader_source, data_loader_target, range(start_iter, max_iter))
        for data_source, data_target, iteration in batches:
            iteration += 1
            storage.step()

            if iteration % data_len == 0:
                current_epoch += 1
                step_in_epoch = 1

            # Fraction of the nominal schedule completed, mapped to [0, 1).
            progress = float(step_in_epoch + current_epoch * data_len) / max_epoch / data_len
            alpha = 2. / (1. + np.exp(-10 * progress)) - 1
            step_in_epoch += 1

            # Per-level caps on the ramped coefficient.
            alpha3 = min(alpha, 0.5)
            alpha4 = min(alpha, 0.5)
            alpha5 = min(alpha, 0.1)

            loss_dict = model(data_source, False, alpha3, alpha4, alpha5)
            loss_dict_target = model(data_target, True, alpha3, alpha4, alpha5)
            # Average the shared terms over the two passes (in place, as before).
            for key in shared_loss_keys:
                loss_dict[key] += loss_dict_target[key]
                loss_dict[key] *= 0.5

            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            # Cross-worker reduction is for logging only.
            reduced = comm.reduce_dict(loss_dict)
            loss_dict_reduced = {name: value.item() for name, value in reduced.items()}
            losses_reduced = sum(loss_dict_reduced.values())
            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            current_lr = optimizer.param_groups[0]["lr"]
            storage.put_scalar("lr", current_lr, smoothing_hint=False)
            scheduler.step()

            # Skip the first few iterations, then flush every 20 and at the end.
            past_warmup = iteration - start_iter > 5
            if past_warmup and (iteration % 20 == 0 or iteration == max_iter):
                for writer in writers:
                    writer.write()
            periodic_checkpointer.step(iteration)
def do_train(cfg, model, resume=False):
    """Train ``model`` with a hand-written loop driven by ``cfg``.

    Builds the optimizer, LR scheduler, checkpointers and event writers, then
    consumes the training data loader from the resumed iteration up to
    ``cfg.SOLVER.MAX_ITER``.

    Args:
        cfg: config object; this function reads ``OUTPUT_DIR``,
            ``MODEL.WEIGHTS``, ``SOLVER.MAX_ITER``,
            ``SOLVER.CHECKPOINT_PERIOD`` and ``TEST.EVAL_PERIOD``.
        model: module that is switched to training mode and called as
            ``model(data)``; the result is used as a dict of loss tensors.
        resume (bool): forwarded to ``checkpointer.resume_or_load``.
    """
    model.train()
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)
    print(f"Scheduler: {scheduler}")

    checkpointer = DetectionCheckpointer(
        model,
        save_dir=cfg.OUTPUT_DIR,
        optimizer=optimizer,
        scheduler=scheduler,
    )

    # Start one past the "iteration" stored in the loaded checkpoint; when the
    # key is absent this evaluates to 0.
    loaded_state = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS,
                                               resume=resume)
    start_iter = loaded_state.get("iteration", -1) + 1
    max_iter = cfg.SOLVER.MAX_ITER

    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer=checkpointer,
        period=cfg.SOLVER.CHECKPOINT_PERIOD,
        max_iter=max_iter,
    )

    # Only the main process gets event writers (console, JSON, TensorBoard).
    if comm.is_main_process():
        writers = [
            CommonMetricPrinter(max_iter=max_iter),
            JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(cfg.OUTPUT_DIR),
        ]
    else:
        writers = []

    # Compared to "train_net.py", accurate timing and precise BN are not
    # supported here, because they are not trivial to implement.
    data_loader = build_detection_train_loader(cfg)

    logger.info("Starting training from iteration {}".format(start_iter))

    with EventStorage(start_iter) as storage:
        # Iterations are numbered 1-based relative to the range the original
        # counter would produce, i.e. start_iter + 1 .. max_iter inclusive.
        step_numbers = range(start_iter + 1, max_iter + 1)
        for data, iteration in zip(data_loader, step_numbers):
            storage.step()

            loss_dict = model(data)
            losses = sum(loss_dict.values())
            # Abort as soon as the summed loss stops being finite.
            assert torch.isfinite(losses).all(), loss_dict

            # Python floats of the losses after comm.reduce_dict, used only
            # for logging.
            reduced = comm.reduce_dict(loss_dict)
            loss_dict_reduced = {
                name: value.item()
                for name, value in reduced.items()
            }
            losses_reduced = sum(loss_dict_reduced.values())

            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced,
                                    **loss_dict_reduced)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            current_lr = optimizer.param_groups[0]["lr"]
            storage.put_scalar("lr", current_lr, smoothing_hint=False)
            # Without this call the learning rate never follows the schedule.
            scheduler.step()

            eval_period = cfg.TEST.EVAL_PERIOD
            due_for_eval = eval_period > 0 and iteration % eval_period == 0
            if due_for_eval and iteration != max_iter:
                do_test(cfg, model)
                # Compared to "train_net.py", the test results are not dumped
                # to EventStorage.
                comm.synchronize()

            # Flush metrics every 20 iterations (and at the very end), but
            # skip the first few iterations after (re)starting.
            at_log_step = iteration % 20 == 0 or iteration == max_iter
            if iteration - start_iter > 5 and at_log_step:
                for writer in writers:
                    writer.write()

            periodic_checkpointer.step(iteration)
def do_train(cfg, model, resume=False):
    """Train ``model`` and checkpoint/log an averaged loss once per epoch.

    An "epoch" here is a fixed number of iterations derived from the size of
    the training sampler and ``cfg.SOLVER.IMS_PER_BATCH``. At every epoch
    boundary (and at the last iteration) the model is saved, the mean of the
    logged total loss since the previous boundary is stored as
    ``epoch_loss`` and the event writers are flushed.

    Args:
        cfg: config object; this function reads ``OUTPUT_DIR``,
            ``MODEL.WEIGHTS``, ``SOLVER.MAX_ITER``, ``SOLVER.IMS_PER_BATCH``,
            ``DATALOADER.ASPECT_RATIO_GROUPING`` and ``DATASETS.TEST``.
        model: module that is switched to training mode and called as
            ``model(data)``; the result is used as a dict of loss tensors.
        resume (bool): forwarded to ``checkpointer.resume_or_load``.
    """
    model.train()
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    # Only the model is registered with the checkpointer on purpose: the
    # optimizer and scheduler states are neither saved nor restored.
    checkpointer = DetectionCheckpointer(model, cfg.OUTPUT_DIR)
    start_iter = (checkpointer.resume_or_load(
        cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1)

    max_iter = cfg.SOLVER.MAX_ITER

    writers = ([
        CommonMetricPrinter(max_iter),
        JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
        TensorboardXWriter(cfg.OUTPUT_DIR),
    ] if comm.is_main_process() else [])

    # compared to "train_net.py", we do not support accurate timing and
    # precise BN here, because they are not trivial to implement
    train_data_loader = build_detection_train_loader(
        cfg, mapper=PathwayDatasetMapper(cfg, True))

    # NOTE(review): the validation loader is built but not consumed by this
    # loop; the call is kept so that any construction-time effects (and
    # errors) stay the same as before.
    val_data_loader = build_detection_validation_loader(
        cfg=cfg,
        dataset_name=cfg.DATASETS.TEST[0],
        mapper=PathwayDatasetMapper(cfg, False))

    # Number of iterations treated as one epoch. One extra iteration is added
    # when aspect-ratio grouping is enabled (as in the original code).
    iters_per_epoch = (train_data_loader.dataset.sampler._size //
                       cfg.SOLVER.IMS_PER_BATCH)
    if cfg.DATALOADER.ASPECT_RATIO_GROUPING:
        iters_per_epoch += 1
    # A dataset smaller than one batch would give 0 and make the modulo
    # below raise ZeroDivisionError.
    iters_per_epoch = max(1, iters_per_epoch)

    logger.info("Starting training from iteration {}".format(start_iter))
    # Per-loss multipliers applied to the *logged* values only; the loss that
    # is back-propagated is the unweighted sum. Unknown keys default to 1.
    loss_weights = {'loss_cls': 1, 'loss_box_reg': 1}
    with EventStorage(start_iter) as storage:
        loss_per_epoch = 0.0
        # Iterations accumulated into loss_per_epoch since the last boundary.
        # This differs from iters_per_epoch for the final partial epoch and
        # for the first window after resuming mid-epoch.
        iters_in_epoch = 0
        for data, iteration in zip(train_data_loader,
                                   range(start_iter, max_iter)):
            iteration = iteration + 1
            storage.step()

            loss_dict = model(data)
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {
                k: v.item() * loss_weights.get(k, 1)
                for k, v in comm.reduce_dict(loss_dict).items()
            }
            losses_reduced = sum(loss_dict_reduced.values())

            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced,
                                    **loss_dict_reduced)

            optimizer.zero_grad()
            losses.backward()
            # Clip the global gradient norm to 1 to prevent gradient
            # explosion.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()
            storage.put_scalar("lr",
                               optimizer.param_groups[0]["lr"],
                               smoothing_hint=False)
            scheduler.step()

            loss_per_epoch += losses_reduced
            iters_in_epoch += 1
            if iteration % iters_per_epoch == 0 or iteration == max_iter:
                # One complete (or final, possibly partial) epoch: average
                # over the iterations that were actually accumulated.
                epoch_loss = loss_per_epoch / iters_in_epoch
                checkpointer.save("model_{:07d}".format(iteration),
                                  iteration=iteration)
                storage.put_scalar("epoch_loss",
                                   epoch_loss,
                                   smoothing_hint=False)

                for writer in writers:
                    writer.write()

                # Reset the accumulators for the next epoch.
                loss_per_epoch = 0.0
                iters_in_epoch = 0
            # Drop references to this iteration's tensors before asking the
            # CUDA caching allocator to release unused memory.
            del loss_dict, losses, losses_reduced, loss_dict_reduced
            torch.cuda.empty_cache()
def start_train(al_cfg, cfg, model, resume=False):
    """Train ``model`` with periodic evaluation and AP-based early stopping.

    The loop mirrors the plain training loop but, every
    ``cfg.TEST.EVAL_PERIOD`` iterations, runs ``do_test``, feeds
    ``1 - AP / 100`` of the ``'bbox'`` results to an ``EarlyStopping``
    instance, and stops as soon as that instance sets ``early_stop``.

    Args:
        al_cfg: config object providing ``EARLY_STOP.PATIENCE`` and
            ``EARLY_STOP.DELTA``.
        cfg: config object; this function reads ``OUTPUT_DIR``,
            ``MODEL.WEIGHTS``, ``SOLVER.MAX_ITER``,
            ``SOLVER.CHECKPOINT_PERIOD`` and ``TEST.EVAL_PERIOD``.
        model: module that is switched to training mode and called as
            ``model(data)``; the result is used as a dict of loss tensors.
        resume (bool): forwarded to ``checkpointer.resume_or_load``.
    """
    early_stopping = EarlyStopping(patience=al_cfg.EARLY_STOP.PATIENCE,
                                   delta=al_cfg.EARLY_STOP.DELTA,
                                   verbose=True)
    model.train()
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    checkpointer = DetectionCheckpointer(model,
                                         cfg.OUTPUT_DIR,
                                         optimizer=optimizer,
                                         scheduler=scheduler)
    # Start one past the "iteration" stored in the loaded checkpoint; when the
    # key is absent this evaluates to 0.
    start_iter = (checkpointer.resume_or_load(
        cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1)
    max_iter = cfg.SOLVER.MAX_ITER

    periodic_checkpointer = PeriodicCheckpointer(checkpointer,
                                                 cfg.SOLVER.CHECKPOINT_PERIOD,
                                                 max_iter=max_iter)

    writers = ([
        CommonMetricPrinter(max_iter),
        JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
        TensorboardXWriter(cfg.OUTPUT_DIR),
    ] if comm.is_main_process() else [])

    # compared to "train_net.py", we do not support accurate timing and
    # precise BN here, because they are not trivial to implement
    data_loader = build_detection_train_loader(cfg)
    logger.info("Starting training from iteration {}".format(start_iter))
    with EventStorage(start_iter) as storage:
        for data, iteration in zip(data_loader, range(start_iter, max_iter)):
            iteration = iteration + 1
            storage.step()

            loss_dict = model(data)
            losses = sum(loss for loss in loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {
                k: v.item()
                for k, v in comm.reduce_dict(loss_dict).items()
            }
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced,
                                    **loss_dict_reduced)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            storage.put_scalar("lr",
                               optimizer.param_groups[0]["lr"],
                               smoothing_hint=False)
            scheduler.step()

            if (cfg.TEST.EVAL_PERIOD > 0
                    and iteration % cfg.TEST.EVAL_PERIOD == 0
                    and iteration != max_iter):
                results = do_test(cfg, model)
                bbox_results = results['bbox']
                AP = bbox_results['AP']
                comm.synchronize()
                # Early stopping is fed a value that decreases as AP grows.
                stop_metric = 1 - (AP / 100)
                print('AP:', AP, '\tValue:', stop_metric)
                early_stopping(stop_metric)
                storage.put_scalars(**bbox_results)
                # NOTE(review): counter < 1 presumably means the metric just
                # improved - confirm against EarlyStopping.
                if early_stopping.counter < 1:
                    # Store the iteration as well, so that resuming from this
                    # checkpoint does not restart the counter at 0 (start_iter
                    # above is derived from the saved "iteration").
                    checkpointer.save('model_final', iteration=iteration)

            # Flush metrics every 20 iterations (and at the very end), but
            # skip the first few iterations after (re)starting.
            if iteration - start_iter > 5 and (iteration % 20 == 0
                                               or iteration == max_iter):
                for writer in writers:
                    writer.write()
            periodic_checkpointer.step(iteration)

            if early_stopping.early_stop:
                print("EARLY STOPPING INITIATED AT ITERATION:", iteration)
                break
Example #24
0
class DefaultTrainer(TrainerBase):
    """
    A trainer with default training logic. It does the following:

    1. Create a :class:`SimpleTrainer` using model, optimizer, dataloader
       defined by the given config. Create a LR scheduler defined by the config.
    2. Load the last checkpoint or `cfg.MODEL.WEIGHTS`, if exists, when
       `resume_or_load` is called.
    3. Register a few common hooks defined by the config.

    It is created to simplify the **standard model training workflow** and reduce code boilerplate
    for users who only need the standard training workflow, with standard features.
    It means this class makes *many assumptions* about your training logic that
    may easily become invalid in a new research. In fact, any assumptions beyond those made in the
    :class:`SimpleTrainer` are too much for research.

    The code of this class has been annotated about restrictive assumptions it makes.
    When they do not work for you, you're encouraged to:

    1. Overwrite methods of this class, OR:
    2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
       nothing else. You can then add your own hooks if needed. OR:
    3. Write your own training loop similar to `tools/plain_train_net.py`.

    See the :doc:`/tutorials/training` tutorials for more details.

    Note that the behavior of this class, like other functions/classes in
    this file, is not stable, since it is meant to represent the "common default behavior".
    It is only guaranteed to work well with the standard models and training workflow in detectron2.
    To obtain more stable behavior, write your own training logic with other public APIs.

    Examples:
    ::
        trainer = DefaultTrainer(cfg)
        trainer.resume_or_load()  # load last checkpoint or MODEL.WEIGHTS
        trainer.train()

    Attributes:
        scheduler:
        checkpointer (DetectionCheckpointer):
        cfg (CfgNode):
    """
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        super().__init__()
        logger = logging.getLogger("detectron2")
        if not logger.isEnabledFor(
                logging.INFO):  # setup_logger is not called for d2
            setup_logger()
        cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())

        # Assume these objects must be constructed in this order.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

        # For training, wrap with DDP. But don't need this for inference.
        if comm.get_world_size() > 1:
            model = DistributedDataParallel(model,
                                            device_ids=[comm.get_local_rank()],
                                            broadcast_buffers=False)
        # The actual optimization step is delegated to this inner trainer
        # (see run_step()).
        self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else
                         SimpleTrainer)(model, data_loader, optimizer)

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

    def resume_or_load(self, resume=True):
        """
        If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by
        a `last_checkpoint` file), resume from the file. Resuming means loading all
        available states (eg. optimizer and scheduler) and update iteration counter
        from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used.

        Otherwise, this is considered as an independent training. The method will load model
        weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start
        from iteration 0.

        Args:
            resume (bool): whether to do resume or not
        """
        checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS,
                                                      resume=resume)
        if resume and self.checkpointer.has_checkpoint():
            self.start_iter = checkpoint.get("iteration", -1) + 1
            # The checkpoint stores the training iteration that just finished, thus we start
            # at the next iteration (or iter zero if there's no checkpoint).
        if isinstance(self.model, DistributedDataParallel):
            # broadcast loaded data/model from the first rank, because other
            # machines may not have access to the checkpoint file
            if TORCH_VERSION >= (1, 7):
                self.model._sync_params_and_buffers()
            self.start_iter = comm.all_gather(self.start_iter)[0]

    def build_hooks(self):
        """
        Build a list of default hooks: timing, lr scheduling, precise BN,
        checkpointing and writing events.

        The evaluation hook is intentionally not registered in this variant
        (see the "Changes for COSMOS model" note in the body).

        Returns:
            list[HookBase]:
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(),
            hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            ) if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None,
        ]

        # Do PreciseBN before checkpointer, because it updates the model and need to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(
                hooks.PeriodicCheckpointer(self.checkpointer,
                                           cfg.SOLVER.CHECKPOINT_PERIOD))

        # Changes for COSMOS model: we don't need to save the predictions
        # def test_and_save_results():
        #     self._last_eval_results = self.test(self.cfg, self.model)
        #     return self._last_eval_results
        #
        # # Do evaluation after checkpointer, because then if it fails,
        # # we can use the saved checkpoint to debug.
        # ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            # Here the default print/log frequency of each writer is used.
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
        return ret

    def build_writers(self):
        """
        Build a list of writers to be used using :func:`default_writers()`.
        If you'd like a different list of writers, you can overwrite it in
        your trainer.

        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.
        """
        return default_writers(self.cfg.OUTPUT_DIR, self.max_iter)

    def train(self):
        """
        Run training.

        Returns:
            The last evaluation results, but only on the main process and only
            when ``cfg.TEST.EXPECTED_RESULTS`` is non-empty. Otherwise None.
        """
        super().train(self.start_iter, self.max_iter)
        # NOTE: build_hooks() of this class does not register an evaluation
        # hook, so `_last_eval_results` exists only if something else (e.g. a
        # subclass) sets it; with non-empty EXPECTED_RESULTS the assertion
        # below fails otherwise.
        if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
            assert hasattr(self, "_last_eval_results"
                           ), "No evaluation results obtained during training!"
            verify_results(self.cfg, self._last_eval_results)
            return self._last_eval_results

    def run_step(self):
        """
        Run one training step by delegating to the wrapped trainer, after
        copying the current iteration counter onto it.
        """
        self._trainer.iter = self.iter
        self._trainer.run_step()

    @classmethod
    def build_model(cls, cfg):
        """
        Returns:
            torch.nn.Module:

        It now calls :func:`detectron2.modeling.build_model`.
        Overwrite it if you'd like a different model.
        """
        model = build_model(cfg)
        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))
        return model

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Returns:
            torch.optim.Optimizer:

        It now calls :func:`detectron2.solver.build_optimizer`.
        Overwrite it if you'd like a different optimizer.
        """
        return build_optimizer(cfg, model)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It now calls :func:`detectron2.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_train_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_train_loader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_test_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_test_loader(cfg, dataset_name)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        """
        Returns:
            DatasetEvaluator or None

        It is not implemented by default.
        """
        raise NotImplementedError("""
If you want DefaultTrainer to automatically run evaluation,
please implement `build_evaluator()` in subclasses (see train_net.py for example).
Alternatively, you can call evaluation functions yourself (see Colab balloon tutorial for example).
""")

    @classmethod
    def test(cls, cfg, model, evaluators=None):
        """
        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                ``cfg.DATASETS.TEST``.

        Returns:
            dict: a dict of result metrics
        """
        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]
        if evaluators is not None:
            assert len(
                cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                    len(cfg.DATASETS.TEST), len(evaluators))

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader = cls.build_test_loader(cfg, dataset_name)
            # When evaluators are passed in as arguments,
            # implicitly assume that evaluators can be created before data_loader.
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, dataset_name)
                except NotImplementedError:
                    # `Logger.warn` is a deprecated alias; use `warning`.
                    logger.warning(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method.")
                    results[dataset_name] = {}
                    continue
            results_i = inference_on_dataset(model, data_loader, evaluator)
            results[dataset_name] = results_i
            if comm.is_main_process():
                assert isinstance(
                    results_i, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results_i)
                logger.info("Evaluation results for {} in csv format:".format(
                    dataset_name))
                print_csv_format(results_i)

        # With a single test dataset, return its metrics directly instead of
        # a one-entry mapping keyed by dataset name.
        if len(results) == 1:
            results = list(results.values())[0]
        return results

    @staticmethod
    def auto_scale_workers(cfg, num_workers: int):
        """
        When the config is defined for certain number of workers (according to
        ``cfg.SOLVER.REFERENCE_WORLD_SIZE``) that's different from the number of
        workers currently in use, returns a new cfg where the total batch size
        is scaled so that the per-GPU batch size stays the same as the
        original ``IMS_PER_BATCH // REFERENCE_WORLD_SIZE``.

        Other config options are also scaled accordingly:
        * training steps and warmup steps are scaled inverse proportionally.
        * learning rate are scaled proportionally, following :paper:`ImageNet in 1h`.

        For example, with the original config like the following:

        .. code-block:: yaml

            IMS_PER_BATCH: 16
            BASE_LR: 0.1
            REFERENCE_WORLD_SIZE: 8
            MAX_ITER: 5000
            STEPS: (4000,)
            CHECKPOINT_PERIOD: 1000

        When this config is used on 16 GPUs instead of the reference number 8,
        calling this method will return a new config with:

        .. code-block:: yaml

            IMS_PER_BATCH: 32
            BASE_LR: 0.2
            REFERENCE_WORLD_SIZE: 16
            MAX_ITER: 2500
            STEPS: (2000,)
            CHECKPOINT_PERIOD: 500

        Note that both the original config and this new config can be trained on 16 GPUs.
        It's up to user whether to enable this feature (by setting ``REFERENCE_WORLD_SIZE``).

        Returns:
            CfgNode: a new config. Same as original if ``cfg.SOLVER.REFERENCE_WORLD_SIZE==0``.
        """
        old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE
        if old_world_size == 0 or old_world_size == num_workers:
            return cfg
        # Work on a copy and remember the frozen state so it can be restored.
        cfg = cfg.clone()
        frozen = cfg.is_frozen()
        cfg.defrost()

        assert (cfg.SOLVER.IMS_PER_BATCH %
                old_world_size == 0), "Invalid REFERENCE_WORLD_SIZE in config!"
        scale = num_workers / old_world_size
        bs = cfg.SOLVER.IMS_PER_BATCH = int(
            round(cfg.SOLVER.IMS_PER_BATCH * scale))
        lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale
        max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER /
                                                   scale))
        warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(
            round(cfg.SOLVER.WARMUP_ITERS / scale))
        cfg.SOLVER.STEPS = tuple(
            int(round(s / scale)) for s in cfg.SOLVER.STEPS)
        cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale))
        cfg.SOLVER.CHECKPOINT_PERIOD = int(
            round(cfg.SOLVER.CHECKPOINT_PERIOD / scale))
        cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers  # maintain invariant
        logger = logging.getLogger(__name__)
        logger.info(
            f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, "
            f"max_iter={max_iter}, warmup={warmup_iter}.")

        if frozen:
            cfg.freeze()
        return cfg
Example #25
0
class ApexTrainer(SimpleTrainer):
    """
    A trainer with apex mixed-precision training logic.

    It builds the model, optimizer, data loader and LR scheduler from ``cfg``,
    creates a checkpointer and registers a default list of hooks. The model
    and optimizer are passed through ``amp.initialize`` and the backward pass
    in :meth:`run_step` goes through ``amp.scale_loss``.

    NOTE(review): ``amp`` is presumably ``apex.amp`` (the calls used here match
    its API) -- confirm against the imports of this file.

    Attributes:
        scheduler: the LR scheduler returned by :meth:`build_lr_scheduler`.
        checkpointer (DetectionCheckpointer): checkpoints model, optimizer and scheduler.
        start_iter (int): first iteration to run; updated by :meth:`resume_or_load`.
        max_iter (int): taken from ``cfg.SOLVER.MAX_ITER``.
        cfg (CfgNode): the config the trainer was built from.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        logger = logging.getLogger("detectron2")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
            setup_logger()
        # Assume these objects must be constructed in this order.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

        # amp has to be initialized for every world size: run_step() always
        # enters ``amp.scale_loss``, which needs the amp state to exist.
        # It must happen before the model is wrapped with DDP.
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

        # For training, wrap with DDP. But don't need this for inference.
        if comm.get_world_size() > 1:
            model = DistributedDataParallel(
                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
            )
        super().__init__(model, data_loader, optimizer)

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        # build_hooks() reads the attributes assigned above, so it goes last.
        self.register_hooks(self.build_hooks())

    def resume_or_load(self, resume=True):
        """
        If `resume==True`, and last checkpoint exists, resume from it.

        Otherwise, load a model specified by the config.

        Args:
            resume (bool): whether to do resume or not
        """
        # The checkpoint stores the training iteration that just finished, thus we start
        # at the next iteration (or iter zero if there's no checkpoint).
        checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
        self.start_iter = checkpoint.get("iteration", -1) + 1

    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.

        Returns:
            list[HookBase]: may contain ``None`` in place of a disabled hook.
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
            hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None,
        ]

        # Do PreciseBN before checkpointer, because it updates the model and need to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            # Keep the latest results so that train() can verify and return them.
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers()))
        return ret

    def build_writers(self):
        """
        Build a list of writers to be used. By default it contains
        writers that write metrics to the screen,
        a json file, and a tensorboard event file respectively.
        If you'd like a different list of writers, you can overwrite it in
        your trainer.

        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.

        It is now implemented by:

        .. code-block:: python

            return [
                CommonMetricPrinter(self.max_iter),
                JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
                TensorboardXWriter(self.cfg.OUTPUT_DIR),
            ]

        """
        # Assume the default print/log frequency.
        return [
            # It may not always print what you want to see, since it prints "common" metrics only.
            CommonMetricPrinter(self.max_iter),
            JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(self.cfg.OUTPUT_DIR),
        ]

    def train(self):
        """
        Run training from ``self.start_iter`` up to ``self.max_iter``.

        The ``after_train`` callbacks run even if an iteration raises.

        Returns:
            The results stored by the last evaluation, if an evaluation has run
            and this is the main process. Otherwise None.
        """
        logger = logging.getLogger(__name__)
        logger.info("Starting training from iteration {}".format(self.start_iter))

        with EventStorage(self.start_iter) as self.storage:
            try:
                self.before_train()
                for self.iter in range(self.start_iter, self.max_iter):
                    self.before_step()
                    self.run_step()
                    self.after_step()
            finally:
                self.after_train()
        if hasattr(self, "_last_eval_results") and comm.is_main_process():
            verify_results(self.cfg, self._last_eval_results)
            return self._last_eval_results
        return None

    def run_step(self):
        """
        Run one training iteration: fetch a batch, compute the losses, log the
        metrics, then do a loss-scaled backward pass and an optimizer step.
        """
        assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
        start = time.perf_counter()
        # If you want to do something with the data, you can wrap the dataloader.
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        # If you want to do something with the losses, you can wrap the model.
        loss_dict = self.model(data)
        losses = sum(loss_dict.values())
        self._detect_anomaly(losses, loss_dict)

        # The total loss is computed above, so adding the timing entry to the
        # same dict does not affect it.
        metrics_dict = loss_dict
        metrics_dict["data_time"] = data_time
        self._write_metrics(metrics_dict)

        # If you need to accumulate gradients or something similar, you can
        # wrap the optimizer with your custom `zero_grad()` method.
        self.optimizer.zero_grad()
        # Backpropagate through the amp-scaled loss instead of the raw loss.
        with amp.scale_loss(losses, self.optimizer) as scaled_loss:
            scaled_loss.backward()

        # If you need gradient clipping/scaling or other processing, you can
        # wrap the optimizer with your custom `step()` method.
        self.optimizer.step()

    @classmethod
    def build_model(cls, cfg):
        """
        Returns:
            torch.nn.Module:

        It now calls :func:`detectron2.modeling.build_model`.
        Overwrite it if you'd like a different model.
        """
        model = build_model(cfg)
        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))
        return model

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Returns:
            torch.optim.Optimizer:

        It now calls :func:`detectron2.solver.build_optimizer`.
        Overwrite it if you'd like a different optimizer.
        """
        return build_optimizer(cfg, model)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It now calls :func:`detectron2.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_train_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_train_loader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_test_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_test_loader(cfg, dataset_name)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        """
        Returns:
            DatasetEvaluator

        It is not implemented by default.

        Raises:
            NotImplementedError: always, unless overridden in a subclass.
        """
        raise NotImplementedError(
            "Please either implement `build_evaluator()` in subclasses, or pass "
            "your evaluator as arguments to `DefaultTrainer.test()`."
        )

    @classmethod
    def test(cls, cfg, model, evaluators=None):
        """
        Evaluate ``model`` on every dataset listed in ``cfg.DATASETS.TEST``.

        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                `cfg.DATASETS.TEST`. A single evaluator is also accepted.

        Returns:
            dict: a dict of result metrics, keyed by dataset name. When there is
            exactly one dataset, its results are returned directly. A dataset
            for which no evaluator can be built gets an empty dict.
        """
        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]
        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader = cls.build_test_loader(cfg, dataset_name)
            # When evaluators are passed in as arguments,
            # implicitly assume that evaluators can be created before data_loader.
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, dataset_name)
                except NotImplementedError:
                    # `Logger.warn` is a deprecated alias of `Logger.warning`.
                    logger.warning(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {}
                    continue
            results_i = inference_on_dataset(model, data_loader, evaluator)
            results[dataset_name] = results_i
            if comm.is_main_process():
                assert isinstance(
                    results_i, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results_i
                )
                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
                print_csv_format(results_i)

        if len(results) == 1:
            results = list(results.values())[0]
        return results
Example #26
0
def do_train(cfg, args):
    """
    Build the trainer described by ``cfg.start`` and run its training loop.

    The function overrides the training dataset, batch size and number of
    data-loader workers in ``cfg`` from ``cfg.start``, builds the data loader
    and the trainer, optionally resumes from a checkpoint, and then either
    calls an alternative entry point of the trainer (``cfg.start.run_func``)
    and exits, or iterates over the data loader calling ``model.train_func``
    and checkpointing periodically.

    Args:
        cfg: config node; reads ``cfg.start.{run_func, dataset_name, IMS_PER_BATCH,
            NUM_WORKERS, dataset_mapper, max_epoch, checkpoint_period, resume_cfg}``
            and ``cfg.OUTPUT_DIR``. ``cfg`` is defrosted, modified and frozen again.
        args: forwarded to ``build_trainer``; ``args.tl_time_str`` is shown in the
            progress bar description.

    Raises:
        SystemExit: with code 0, after running a ``run_func`` other than
            ``'train_func'``.
    """
    import operator

    # fmt: off
    run_func = cfg.start.get('run_func', 'train_func')
    dataset_name = cfg.start.dataset_name
    # cfg.start.IMS_PER_BATCH is multiplied by the world size to get the total batch size.
    IMS_PER_BATCH = cfg.start.IMS_PER_BATCH * comm.get_world_size()
    NUM_WORKERS = cfg.start.NUM_WORKERS
    dataset_mapper = cfg.start.dataset_mapper

    max_epoch = cfg.start.max_epoch
    checkpoint_period = cfg.start.checkpoint_period

    resume_cfg = get_attr_kwargs(cfg.start, 'resume_cfg', default=None)

    cfg.defrost()
    cfg.DATASETS.TRAIN = (dataset_name, )
    cfg.SOLVER.IMS_PER_BATCH = IMS_PER_BATCH
    cfg.DATALOADER.NUM_WORKERS = NUM_WORKERS
    cfg.freeze()
    # fmt: on

    # build dataset
    mapper = build_dataset_mapper(dataset_mapper)
    data_loader = build_detection_train_loader(cfg, mapper=mapper)
    metadata = MetadataCatalog.get(dataset_name)
    num_samples = metadata.get('num_samples')
    # The trailing partial batch of an epoch is not counted.
    iter_every_epoch = num_samples // IMS_PER_BATCH
    max_iter = iter_every_epoch * max_epoch

    model = build_trainer(cfg=cfg,
                          args=args,
                          iter_every_epoch=iter_every_epoch,
                          batch_size=IMS_PER_BATCH,
                          max_iter=max_iter,
                          metadata=metadata,
                          max_epoch=max_epoch,
                          data_loader=data_loader)
    model.train()

    optims_dict = model.build_optimizer()
    scheduler = model.build_lr_scheduler()

    # Optimizers and schedulers are handed to the checkpointer as keyword arguments.
    checkpointer = DetectionCheckpointer(model.get_saved_model(),
                                         cfg.OUTPUT_DIR, **optims_dict,
                                         **scheduler)
    start_iter = 0
    if resume_cfg and resume_cfg.resume:
        resume_ckpt_dir = model._get_ckpt_path(
            ckpt_dir=resume_cfg.ckpt_dir,
            ckpt_epoch=resume_cfg.ckpt_epoch,
            iter_every_epoch=resume_cfg.iter_every_epoch)
        # The checkpoint stores the iteration that just finished; continue with the next one.
        loaded = checkpointer.resume_or_load(resume_ckpt_dir)
        start_iter = loaded.get("iteration", -1) + 1
        if get_attr_kwargs(resume_cfg, 'finetune', default=False):
            # Fine-tuning reuses the weights but restarts the iteration counter.
            start_iter = 0
        model.after_resume()

    if run_func != 'train_func':
        # Look the entry point up by (possibly dotted) attribute name instead of
        # evaluating a string built from the config.
        operator.attrgetter(run_func)(model)()
        raise SystemExit(0)

    # NOTE(security): `checkpoint_period` is an expression taken from the config and
    # evaluated with `iter_every_epoch` in scope. Only use this with trusted configs.
    checkpoint_period = eval(checkpoint_period,
                             dict(iter_every_epoch=iter_every_epoch))
    periodic_checkpointer = PeriodicCheckpointer(checkpointer,
                                                 checkpoint_period,
                                                 max_iter=max_iter)
    logger.info("Starting training from iteration {}".format(start_iter))
    with EventStorage(start_iter) as storage:
        # zip() stops at whichever is exhausted first: the loader or the iteration range.
        pbar = zip(data_loader, range(start_iter, max_iter))
        if comm.is_main_process():
            pbar = tqdm.tqdm(
                pbar,
                desc=f'do_train, {args.tl_time_str}, '
                f'iters {iter_every_epoch} * bs {IMS_PER_BATCH} = '
                f'imgs {iter_every_epoch*IMS_PER_BATCH}',
                initial=start_iter,
                total=max_iter)

        for data, iteration in pbar:
            comm.synchronize()
            storage.step()

            # train_func receives the 0-based iteration, the checkpointer the 1-based one.
            model.train_func(data, iteration, pbar=pbar)

            periodic_checkpointer.step(iteration + 1)
    comm.synchronize()
class DefaultTrainer(SimpleTrainer):
    """
    Config-driven trainer built on top of :class:`SimpleTrainer`.

    On top of what `SimpleTrainer` does, this class:

    1. Creates the model, optimizer, scheduler and dataloader from the given config.
    2. Loads a checkpoint or `cfg.MODEL.WEIGHTS`, if exists.
    3. Registers a few common hooks.

    It exists to cut boilerplate for the **standard model training workflow**.
    That convenience comes from *many assumptions* about the training logic,
    which may easily stop holding in new research; any assumption beyond those
    of :class:`SimpleTrainer` is too much for research.

    The restrictive assumptions are annotated in the code. When they do not
    work for you, you're encouraged to write your own training logic.

    The behavior of this class, like other functions/classes in this file, is
    not stable, since it is meant to represent the "common default behavior".
    It is only guaranteed to work well with the standard models and training
    workflow in detectron2. For more stable behavior, write your own training
    logic with other public APIs.

    Attributes:
        scheduler:
        checkpointer (DetectionCheckpointer):
        cfg (CfgNode):
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        # These objects are assumed to require this construction order.
        net = self.build_model(cfg)
        optim = self.build_optimizer(cfg, net)
        loader = self.build_train_loader(cfg)

        # DDP is only needed for multi-process training, not for inference.
        if comm.get_world_size() > 1:
            net = DistributedDataParallel(
                net,
                device_ids=[comm.get_local_rank()],
                broadcast_buffers=False,
            )
        super().__init__(net, loader, optim)

        self.scheduler = self.build_lr_scheduler(cfg, optim)
        # Nothing else is assumed to need checkpointing (stateful hooks could be
        # added later). Checkpoints go next to the logs/statistics.
        self.checkpointer = DetectionCheckpointer(
            net, cfg.OUTPUT_DIR, optimizer=optim, scheduler=self.scheduler
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        default_hooks = self.build_hooks()
        self.register_hooks(default_hooks)

    def resume_or_load(self, resume=True):
        """
        Resume from the last checkpoint when `resume==True` and one exists;
        otherwise load the model specified by the config.

        Args:
            resume (bool): whether to do resume or not
        """
        loaded = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
        # A checkpoint records the iteration that just finished, so training
        # continues at the following one (iteration zero without a checkpoint).
        self.start_iter = loaded.get("iteration", -1) + 1

    def build_hooks(self):
        """
        Assemble the default hook list.

        Returns:
            list[HookBase]:
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        hook_list = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
        ]

        if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model):
            # Same frequency as evaluation, but registered ahead of it. A separate
            # data loader is built so that the training loader is not affected.
            hook_list.append(
                hooks.PreciseBN(
                    cfg.TEST.EVAL_PERIOD,
                    self.model,
                    self.build_train_loader(cfg),
                    cfg.TEST.PRECISE_BN.NUM_ITER,
                )
            )
        else:
            hook_list.append(None)

        # PreciseBN updates the model, so it precedes the checkpointer that saves it.
        # With a different checkpointing frequency, some checkpoints may carry
        # more precise statistics than others.
        if comm.is_main_process():
            hook_list.append(
                hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD)
            )

        def run_eval_and_remember():
            outcome = self.test(self.cfg, self.model)
            self._last_eval_results = outcome
            return self._last_eval_results

        # Evaluation follows checkpointing: if it fails, the saved checkpoint
        # is available for debugging.
        hook_list.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, run_eval_and_remember))

        if comm.is_main_process():
            # Writers come last so that evaluation metrics get written.
            hook_list.append(hooks.PeriodicWriter(self.build_writers()))
        return hook_list

    def build_writers(self):
        """
        Create the default writers: one for the screen, one for a json file and
        one for a tensorboard event file.

        Returns:
            list[Writer]: a list of objects that have a ``.write`` method.
        """
        # The default print/log frequency is assumed. The printer only shows
        # "common" metrics, which may not be everything you want to see.
        screen_writer = CommonMetricPrinter(self.max_iter)
        json_writer = JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json"))
        tensorboard_writer = TensorboardXWriter(self.cfg.OUTPUT_DIR)
        return [screen_writer, json_writer, tensorboard_writer]

    def train(self):
        """
        Run training.

        Returns:
            OrderedDict of results, if evaluation is enabled. Otherwise None.
        """
        super().train(self.start_iter, self.max_iter)
        if not hasattr(self, "_last_eval_results") or not comm.is_main_process():
            return None
        verify_results(self.cfg, self._last_eval_results)
        return self._last_eval_results

    @classmethod
    def build_model(cls, cfg):
        """
        Build the model from the config and log its structure.

        Returns:
            torch.nn.Module:
        """
        net = build_model(cfg)
        logging.getLogger(__name__).info("Model:\n{}".format(net))
        return net

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Returns:
            torch.optim.Optimizer:
        """
        optim = build_optimizer(cfg, model)
        return optim

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        Build the learning-rate scheduler for ``optimizer`` from the config.
        """
        sched = build_lr_scheduler(cfg, optimizer)
        return sched

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable
        """
        loader = build_detection_train_loader(cfg)
        return loader

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Returns:
            iterable
        """
        loader = build_detection_test_loader(cfg, dataset_name)
        return loader

    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        """
        Not implemented here; subclasses are expected to provide it.

        Returns:
            DatasetEvaluator
        """
        raise NotImplementedError

    @classmethod
    def test(cls, cfg, model, evaluators=None):
        """
        Run inference on each dataset in ``cfg.DATASETS.TEST`` and collect the metrics.

        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                `cfg.DATASETS.TEST`.

        Returns:
            dict: a dict of result metrics
        """
        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]
        if evaluators is not None:
            n_datasets = len(cfg.DATASETS.TEST)
            n_evaluators = len(evaluators)
            assert n_datasets == n_evaluators, "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )

        collected = OrderedDict()
        for position, dataset_name in enumerate(cfg.DATASETS.TEST):
            loader = cls.build_test_loader(cfg, dataset_name)
            # Evaluators passed in as arguments are implicitly assumed to be
            # creatable before the data loader.
            if evaluators is None:
                evaluator = cls.build_evaluator(cfg, dataset_name)
            else:
                evaluator = evaluators[position]
            metrics = inference_on_dataset(model, loader, evaluator)
            collected[dataset_name] = metrics
            if not comm.is_main_process():
                continue
            assert isinstance(metrics, dict), (
                "Evaluator must return a dict on the main process. Got {} instead.".format(
                    metrics
                )
            )
            logger.info("Evaluation results for {} in csv format:".format(dataset_name))
            print_csv_format(metrics)

        # A single dataset yields its metrics directly instead of a one-entry mapping.
        if len(collected) == 1:
            return next(iter(collected.values()))
        return collected
Example #28
0
def do_train(cfg, model, resume=False):
    """
    Train ``model`` with augmentation, monitor the validation loss and keep the best weights.

    The loop runs up to ``cfg.SOLVER.MAX_ITER`` iterations. Every
    ``cfg.TEST.EVAL_PERIOD`` iterations (except the last one) the validation
    loss is computed with ``do_val_monitor`` and the weights with the lowest
    validation loss are remembered. At the end those weights are loaded back
    into the model and a checkpoint named ``model_<MLFLOW_EXPERIMENT_NAME>``
    is saved. If no validation ran, the final trained weights are kept.

    Args:
        cfg (CfgNode): training config.
        model (nn.Module): the model to train; modified in place.
        resume (bool): forwarded to ``checkpointer.resume_or_load``.

    Returns:
        nn.Module: the same ``model`` object.
    """
    model.train()
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    checkpointer = DetectionCheckpointer(
        model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler
    )

    # The checkpoint stores the iteration that just finished; continue with the next one.
    start_iter = (
        checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1
    )
    max_iter = cfg.SOLVER.MAX_ITER

    # Only the main process writes metrics.
    writers = (
        [
            CommonMetricPrinter(max_iter),
            JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(cfg.OUTPUT_DIR),
        ]
        if comm.is_main_process()
        else []
    )

    min_size = cfg.INPUT.MIN_SIZE_TRAIN
    # No trailing comma here: it would turn the value into a 1-tuple.
    max_size = cfg.INPUT.MAX_SIZE_TRAIN
    sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
    augmentations = [
        T.ResizeShortestEdge(min_size, max_size, sample_style),
        T.RandomApply(T.RandomFlip(prob=1, vertical=False), prob=0.5),
        T.RandomApply(T.RandomRotation(angle=[180], sample_style='choice'), prob=0.1),
        T.RandomApply(T.RandomRotation(angle=[-10, 10], sample_style='range'), prob=0.9),
        T.RandomApply(T.RandomBrightness(0.5, 1.5), prob=0.5),
        T.RandomApply(T.RandomContrast(0.5, 1.5), prob=0.5),
    ]
    data_loader = build_detection_train_loader(
        cfg, mapper=DatasetMapper(cfg, is_train=True, augmentations=augmentations)
    )

    # Best weights seen so far; stays None until a validation pass has run, so
    # that the initial weights are never mistaken for the "best" ones.
    best_model_weight = None
    best_val_loss = None
    # The validation loader uses a mapper created with the second argument set to True.
    data_val_loader = build_detection_test_loader(cfg,
                                                  cfg.DATASETS.TEST[0],
                                                  mapper=DatasetMapper(cfg, True))
    logger.info("Starting training from iteration {}".format(start_iter))

    with EventStorage(start_iter) as storage:
        for data, iteration in zip(data_loader, range(start_iter, max_iter)):
            iteration += 1
            storage.step()

            loss_dict = model(data)
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            # Reduce the losses across processes for logging only.
            loss_dict_reduced = {k: v.item() for k, v in comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss_dict_reduced.values())
            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
            scheduler.step()

            if (
                cfg.TEST.EVAL_PERIOD > 0
                and iteration % cfg.TEST.EVAL_PERIOD == 0
                and iteration != max_iter
            ):
                # Silence the logger while validating, then restore the level it
                # had before instead of forcing a fixed one.
                previous_level = logger.level
                logger.setLevel(logging.CRITICAL)
                try:
                    print('validating')
                    val_total_loss = do_val_monitor(cfg, model, data_val_loader)
                finally:
                    logger.setLevel(previous_level)
                logger.info(f"validation loss of iteration {iteration}th: {val_total_loss}")
                storage.put_scalar(name='val_total_loss', value=val_total_loss)

                if best_val_loss is None or val_total_loss < best_val_loss:
                    best_val_loss = val_total_loss
                    best_model_weight = copy.deepcopy(model.state_dict())

                comm.synchronize()

            # TODO: add a checkpointer that saves the best model based on the validation loss.
            if iteration - start_iter > 5 and (iteration % 20 == 0 or iteration == max_iter):
                for writer in writers:
                    writer.write()

    if best_model_weight is not None:
        model.load_state_dict(best_model_weight)
    # If the environment variable is unset, the checkpoint is named "model_None".
    experiment_name = os.getenv('MLFLOW_EXPERIMENT_NAME')
    checkpointer.save(f'model_{experiment_name}')
    return model
Example #29
0
class DefaultTrainer(SimpleTrainer):
    """
    A trainer with default training logic. Compared to `SimpleTrainer`, it
    contains the following logic in addition:

    1. Create model, optimizer, scheduler, dataloader from the given config.
    2. Load a checkpoint or `cfg.MODEL.WEIGHTS`, if exists, when
       `resume_or_load` is called.
    3. Register a few common hooks.

    It is created to simplify the **standard model training workflow** and reduce code boilerplate
    for users who only need the standard training workflow, with standard features.
    It means this class makes *many assumptions* about your training logic that
    may easily become invalid in new research. In fact, any assumptions beyond those made in the
    :class:`SimpleTrainer` are too much for research.

    The code of this class has been annotated about the restrictive assumptions it makes.
    When they do not work for you, you're encouraged to:

    1. Overwrite methods of this class, OR:
    2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
       nothing else. You can then add your own hooks if needed. OR:
    3. Write your own training loop similar to `tools/plain_train_net.py`.

    Also note that the behavior of this class, like other functions/classes in
    this file, is not stable, since it is meant to represent the "common default behavior".
    It is only guaranteed to work well with the standard models and training workflow in detectron2.
    To obtain more stable behavior, write your own training logic with other public APIs.

    Examples:

    .. code-block:: python

        trainer = DefaultTrainer(cfg)
        trainer.resume_or_load()  # load last checkpoint or MODEL.WEIGHTS
        trainer.train()

    Attributes:
        scheduler: the LR scheduler returned by :meth:`build_lr_scheduler`.
        checkpointer (DetectionCheckpointer): saves/loads the model together
            with the optimizer and the scheduler.
        cfg (CfgNode): the config this trainer was built from.
        start_iter (int): iteration to start from; 0 until
            :meth:`resume_or_load` is called.
        max_iter (int): taken from `cfg.SOLVER.MAX_ITER`.
    """

    def __init__(self, cfg):
        """
        Build the model, optimizer, train loader, scheduler and checkpointer
        from `cfg`, then register the hooks returned by :meth:`build_hooks`.

        Args:
            cfg (CfgNode):
        """
        logger = logging.getLogger("detectron2")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
            setup_logger()
        # Assume these objects must be constructed in this order.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

        # For training, wrap with DDP. But don't need this for inference.
        if comm.get_world_size() > 1:
            model = DistributedDataParallel(
                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
            )
        super().__init__(model, data_loader, optimizer)

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        # build_hooks() reads self.cfg / self.scheduler / self.checkpointer,
        # so it must run after the attributes above are set.
        self.register_hooks(self.build_hooks())

    def resume_or_load(self, resume=True):
        """
        If `resume==True`, and last checkpoint exists, resume from it and load all
        checkpointables (eg. optimizer and scheduler).

        Otherwise, load the model specified by the config (skip all checkpointables).

        After this call `self.start_iter` is the checkpoint's stored
        "iteration" plus one when resuming, and 0 otherwise.

        Args:
            resume (bool): whether to do resume or not
        """
        checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
        self.start_iter = checkpoint.get("iteration", -1) if resume else -1
        # The checkpoint stores the training iteration that just finished, thus we start
        # at the next iteration (or iter zero if there's no checkpoint).
        self.start_iter += 1

    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.

        The checkpointer and writer hooks are only added on the main process.
        The PreciseBN slot is `None` when precise BN is disabled in the config
        or the model has no BN modules.

        Returns:
            list[HookBase]:
        """
        # Work on a private copy so the worker-count override below does not
        # leak into self.cfg.
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
            hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None,
        ]

        # Do PreciseBN before checkpointer, because it updates the model and need to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            # Remember the latest results so train() can verify/return them.
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
        return ret

    def build_writers(self):
        """
        Build a list of writers to be used. By default it contains
        writers that write metrics to the screen,
        a json file, and a tensorboard event file respectively.
        If you'd like a different list of writers, you can overwrite it in
        your trainer.

        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.

        It is now implemented by:

        .. code-block:: python

            return [
                CommonMetricPrinter(self.max_iter),
                JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
                TensorboardXWriter(self.cfg.OUTPUT_DIR),
            ]

        """
        # Here the default print/log frequency of each writer is used.
        return [
            # It may not always print what you want to see, since it prints "common" metrics only.
            CommonMetricPrinter(self.max_iter),
            JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(self.cfg.OUTPUT_DIR),
        ]

    def train(self):
        """
        Run training from `self.start_iter` to `self.max_iter`.

        Returns:
            The last evaluation results, but only on the main process and only
            when `cfg.TEST.EXPECTED_RESULTS` is non-empty (in which case the
            results are also checked with `verify_results`). Otherwise None.
        """
        super().train(self.start_iter, self.max_iter)
        if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
            assert hasattr(
                self, "_last_eval_results"
            ), "No evaluation results obtained during training!"
            verify_results(self.cfg, self._last_eval_results)
            return self._last_eval_results

    @classmethod
    def build_model(cls, cfg):
        """
        Build the model and log its structure.

        Returns:
            torch.nn.Module:

        It now calls :func:`detectron2.modeling.build_model`.
        Overwrite it if you'd like a different model.
        """
        model = build_model(cfg)
        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))
        return model

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Returns:
            torch.optim.Optimizer:

        It now calls :func:`detectron2.solver.build_optimizer`.
        Overwrite it if you'd like a different optimizer.
        """
        return build_optimizer(cfg, model)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        Returns:
            the scheduler object produced for `optimizer`.

        It now calls :func:`detectron2.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_train_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_train_loader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_test_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_test_loader(cfg, dataset_name)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        """
        Returns:
            DatasetEvaluator or None

        It is not implemented by default.

        Raises:
            NotImplementedError: always, unless overridden in a subclass.
        """
        raise NotImplementedError(
            """
If you want DefaultTrainer to automatically run evaluation,
please implement `build_evaluator()` in subclasses (see train_net.py for example).
Alternatively, you can call evaluation functions yourself (see Colab balloon tutorial for example).
"""
        )

    @classmethod
    def test(cls, cfg, model, evaluators=None):
        """
        Evaluate `model` on every dataset listed in `cfg.DATASETS.TEST`.

        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                `cfg.DATASETS.TEST`. A single `DatasetEvaluator` is also
                accepted and treated as a one-element list.

        Returns:
            dict: a dict of result metrics. With several test datasets this is
            an OrderedDict keyed by dataset name; with exactly one dataset the
            metrics of that dataset are returned directly. A dataset for which
            no evaluator could be built maps to an empty dict.
        """
        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]
        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader = cls.build_test_loader(cfg, dataset_name)
            # When evaluators are passed in as arguments,
            # implicitly assume that evaluators can be created before data_loader.
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, dataset_name)
                except NotImplementedError:
                    # `Logger.warn` is a deprecated alias; use `warning`.
                    logger.warning(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {}
                    continue
            results_i = inference_on_dataset(model, data_loader, evaluator)
            results[dataset_name] = results_i
            if comm.is_main_process():
                assert isinstance(
                    results_i, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results_i
                )
                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
                print_csv_format(results_i)

        # Unwrap the per-dataset mapping when there is only one entry.
        if len(results) == 1:
            results = list(results.values())[0]
        return results
def do_train(cfg, model, resume=False):
    """
    Run a plain training loop on `model` driven by `cfg`.

    Each iteration computes the loss dict, logs the cross-process reduced
    losses and the learning rate to the event storage, takes one optimizer
    and scheduler step, optionally runs `do_test`, flushes the writers every
    20 iterations, and steps the periodic checkpointer.

    Args:
        cfg: config node; this function reads `OUTPUT_DIR`, `MODEL.WEIGHTS`,
            `SOLVER.MAX_ITER`, `SOLVER.CHECKPOINT_PERIOD` and
            `TEST.EVAL_PERIOD` (the builders it calls may read more).
        model: the model to train; called as `model(data)` and expected to
            return a dict of loss tensors.
        resume (bool): forwarded to `checkpointer.resume_or_load`.

    Returns:
        None
    """
    model.train()
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)
    checkpointer = DetectionCheckpointer(model,
                                         cfg.OUTPUT_DIR,
                                         optimizer=optimizer,
                                         scheduler=scheduler)

    # The checkpoint stores the last finished iteration; start at the next
    # one, or at 0 when the loaded checkpoint has no "iteration" entry.
    loaded = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume)
    start_iter = loaded.get('iteration', -1) + 1

    max_iter = cfg.SOLVER.MAX_ITER
    eval_period = cfg.TEST.EVAL_PERIOD
    periodic_checkpointer = PeriodicCheckpointer(checkpointer,
                                                 cfg.SOLVER.CHECKPOINT_PERIOD,
                                                 max_iter=max_iter)

    # Only the main process writes metrics.
    writers = ([
        CommonMetricPrinter(max_iter),
        JSONWriter(os.path.join(cfg.OUTPUT_DIR, 'metric.json')),
        TensorboardXWriter(cfg.OUTPUT_DIR),
    ] if comm.is_main_process() else [])

    data_loader = build_detection_train_loader(cfg)
    logger.info(" Starting training from iteration {}".format(start_iter))

    with EventStorage(start_iter) as storage:
        for data, iteration in zip(data_loader, range(start_iter, max_iter)):
            # Work with a 1-based count of finished iterations below.
            iteration = iteration + 1
            storage.step()

            loss_dict = model(data)
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            # Reduce across processes for logging only; the backward pass
            # below uses the local, unreduced `losses`.
            loss_dict_reduced = {
                k: v.item()
                for k, v in comm.reduce_dict(loss_dict).items()
            }
            losses_reduced = sum(loss_dict_reduced.values())

            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced,
                                    **loss_dict_reduced)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            storage.put_scalar("lr",
                               optimizer.param_groups[0]["lr"],
                               smoothing_hint=False)
            scheduler.step()

            # Fixed: the period was previously read from the misspelled
            # `cfg.TEST_EVAL_PERIOC`, so this branch could never run.
            if (eval_period > 0
                    and iteration % eval_period == 0
                    and iteration != max_iter):
                do_test(cfg, model)
                comm.synchronize()

            # Skip the first few iterations after (re)start, then flush every
            # 20 iterations and once more at the very end.
            if iteration - start_iter > 5 and (iteration % 20 == 0
                                               or iteration == max_iter):
                for writer in writers:
                    writer.write()
            periodic_checkpointer.step(iteration)