コード例 #1
0
def test(cfg, model, distributed):
    """Run inference on every dataset named in ``cfg.DATASETS.TEST``.

    Args:
        cfg: Configuration object. Reads ``MODEL.MASK_ON``,
            ``MODEL.RPN_ONLY``, ``MODEL.DEVICE``, ``DATASETS.TEST``,
            ``OUTPUT_DIR``, ``TENSORBOARD_EXPERIMENT`` and
            ``TEST.EXPECTED_RESULTS_SIGMA_TOL``.
        model: Model to evaluate. When ``distributed`` is true the wrapped
            network is taken from ``model.module``.
        distributed: Whether the run is distributed; forwarded to
            ``make_data_loader``.

    Returns:
        None. Results are produced by ``inference`` (and written below
        ``<OUTPUT_DIR>/inference/<dataset_name>`` when ``OUTPUT_DIR`` is set).
    """
    if distributed:
        model = model.module
    torch.cuda.empty_cache()  # TODO check if it helps

    # Box evaluation is always on; mask evaluation only if the model has masks.
    iou_types = ("bbox",)
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm",)

    dataset_names = cfg.DATASETS.TEST
    # One output folder per test dataset; stays None when OUTPUT_DIR is empty.
    output_folders = [None] * len(dataset_names)
    if cfg.OUTPUT_DIR:
        for idx, dataset_name in enumerate(dataset_names):
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
            mkdir(output_folder)
            output_folders[idx] = output_folder

    # Fix: the loaders are paired with cfg.DATASETS.TEST below and handed to
    # inference(), so they must be the evaluation loaders, not the training one.
    # NOTE(review): assumes make_data_loader(is_train=False) returns one loader
    # per entry of cfg.DATASETS.TEST -- confirm against its definition.
    data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
    meters = TensorboardLogger(
        log_dir=cfg.TENSORBOARD_EXPERIMENT,
        delimiter="  ")
    for output_folder, dataset_name, data_loader_val in zip(output_folders, dataset_names, data_loaders_val):
        inference(
            model,
            cfg,
            data_loader_val,
            dataset_name=dataset_name,
            meters=meters,
            iou_types=iou_types,
            box_only=cfg.MODEL.RPN_ONLY,
            device=cfg.MODEL.DEVICE,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        # Barrier between datasets (project helper).
        synchronize()
コード例 #2
0
def train(cfg, local_rank, distributed):
    """Build the model and its training state, then hand off to ``do_train``.

    Args:
        cfg: Configuration object. Reads ``MODEL.DEVICE``, ``MODEL.WEIGHT``,
            ``OUTPUT_DIR``, ``SOLVER.CHECKPOINT_PERIOD`` and
            ``USE_TENSORBOARD_LOGS``.
        local_rank: Device index used for ``DistributedDataParallel``.
        distributed: Whether to wrap the model in ``DistributedDataParallel``.

    Returns:
        The model that was passed to ``do_train`` (the DDP wrapper when
        ``distributed`` is true).
    """
    net = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    net.to(device)

    optimizer = make_optimizer(cfg, net)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        ddp_options = {
            "device_ids": [local_rank],
            "output_device": local_rank,
            # this should be removed if we update BatchNorm stats
            "broadcast_buffers": False,
        }
        net = torch.nn.parallel.DistributedDataParallel(net, **ddp_options)

    # Training state saved with each checkpoint; whatever checkpointer.load()
    # returns is merged on top of the default starting iteration.
    arguments = {"iteration": 0}
    output_dir = cfg.OUTPUT_DIR
    checkpointer = DetectronCheckpointer(
        cfg, net, optimizer, scheduler, output_dir, get_rank() == 0)
    arguments.update(checkpointer.load(cfg.MODEL.WEIGHT))

    train_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    if not cfg.USE_TENSORBOARD_LOGS:
        meters = MetricLogger(delimiter="  ")
    else:
        meters = TensorboardLogger(
            log_dir=os.path.join(output_dir, 'tensorboard_logs'),
            start_iter=arguments['iteration'],
            delimiter="  ",
        )

    do_train(net, train_loader, optimizer, scheduler, checkpointer, device,
             checkpoint_period, arguments, meters)
    return net
コード例 #3
0
def do_train(model,
             data_loader,
             optimizer,
             scheduler,
             checkpointer,
             device,
             checkpoint_period,
             arguments,
             tb_log_dir,
             use_tensorboard=False):
    """Run one pass over ``data_loader``, optimising ``model`` and logging.

    Args:
        model: Called as ``model(images, targets)``; expected to return a
            ``(loss_dict, preds)`` pair.
        data_loader: Iterable yielding ``(images, targets, _)`` triples; its
            ``len()`` is used as the iteration count for ETA and logging.
        optimizer: Optimizer stepped once per iteration.
        scheduler: LR scheduler stepped once per iteration (before the
            forward pass).
        checkpointer: Object whose ``save(name, **arguments)`` is called
            periodically and once after the loop.
        device: Target passed to ``.to(...)`` for images and targets.
        checkpoint_period: An image summary and a checkpoint are written
            whenever ``iteration % checkpoint_period == 0``.
        arguments: Dict holding at least ``"iteration"``; it is updated in
            place every iteration and forwarded to ``checkpointer.save``.
        tb_log_dir: ``log_dir`` for ``TensorboardLogger`` (only used when
            ``use_tensorboard`` is true).
        use_tensorboard: Select ``TensorboardLogger`` instead of
            ``MetricLogger``.

    NOTE(review): ``meters.update`` is called with ``iteration`` as a
    positional argument and ``meters.update_image`` is called below; confirm
    that ``MetricLogger`` supports both when ``use_tensorboard`` is false.
    """
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    meters = TensorboardLogger(log_dir=tb_log_dir,
                               delimiter="  ") \
             if use_tensorboard else MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    start_training_time = time.time()
    end = time.time()
    # Iterations are numbered from start_iter (no +1 offset here).
    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
        # Time spent waiting for the batch since the previous iteration ended.
        data_time = time.time() - end
        arguments["iteration"] = iteration

        scheduler.step()

        images = images.to(device)
        targets = [target.to(device) for target in targets]
        loss_dict, preds = model(images, targets)
        # Total loss that is back-propagated: sum of all entries in loss_dict.
        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(iteration, loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        losses.backward()

        # Sum of the per-parameter L2 gradient norms (not the global L2 norm),
        # used only as the trigger for clipping below.
        accum_grad = 0
        for p in list(filter(lambda p: p.grad is not None,
                             model.parameters())):
            accum_grad += p.grad.data.norm(2).item()

        # After iteration 500, clip gradients to a max norm of 200 whenever
        # the accumulated norm above exceeds 200.
        if iteration > 500 and accum_grad > 200:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 200)
        optimizer.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(iteration, time=batch_time, data=data_time)

        # ETA = average batch time so far * iterations left.
        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        # Log every 20 iterations and at iteration max_iter - 1.
        if iteration % 20 == 0 or iteration == (max_iter - 1):

            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    # bytes -> MiB
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))

        # Periodic snapshot: log the first image of the batch with its
        # prediction and target, then save a checkpoint. Note that
        # iteration 0 also satisfies this condition.
        if iteration % checkpoint_period == 0:
            meters.update_image(iteration, images.tensors[0], preds[0],
                                targets[0])
            checkpointer.save("model_{:07d}".format(iteration), **arguments)

    # Final checkpoint, named after the last iteration that ran.
    # NOTE(review): `iteration` is unbound here if data_loader yielded nothing.
    checkpointer.save("model_{:07d}".format(iteration), **arguments)
    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))
コード例 #4
0
def train(cfg,
          local_rank,
          distributed,
          model_config=None,
          use_tensorboard=True):
    """Build the (possibly search-space) model and run ``do_train`` on it.

    Args:
        cfg: Configuration object (model, solver, dtype, output and
            tensorboard settings are read from it).
        local_rank: Device index used for ``DistributedDataParallel``.
        distributed: Whether to wrap the model in ``DistributedDataParallel``.
        model_config: Forwarded unchanged to ``build_detection_model``.
        use_tensorboard: Use ``TensorboardLogger`` instead of ``MetricLogger``.

    Returns:
        The model that was passed to ``do_train``.
    """
    model = build_detection_model(cfg, model_config)

    # Rank 0 prints a summary of the architecture / search spaces in use.
    if get_rank() == 0:
        if 'search' not in cfg.MODEL.BACKBONE.CONV_BODY:
            print('backbone:', cfg.MODEL.BACKBONE)
        else:
            print('backbone search space:', blocks_key)

        head_is_searched = (
            'search' in cfg.MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR
            or 'search' in cfg.MODEL.SEG_BRANCH.SEGMENT_BRANCH)
        if not head_is_searched:
            print('head:', cfg.MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR,
                  cfg.MODEL.SEG_BRANCH.SEGMENT_BRANCH)
        else:
            print('head search space:', head_ss_keys)

        if 'search' not in cfg.MODEL.INTER_MODULE.NAME:
            print('inter:', cfg.MODEL.INTER_MODULE.NAME)
        else:
            print('inter search space:', inter_ss_keys)
        print(model)

    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    optimizer, lr_dict = make_optimizer(cfg, model)
    if get_rank() == 0:
        for lr_entry in lr_dict:
            print(lr_entry)
    scheduler = make_lr_scheduler(cfg, optimizer)

    # Initialize mixed-precision training (skipped whenever the backbone or
    # either head component is a search space).
    any_search = ('search' in cfg.MODEL.BACKBONE.CONV_BODY
                  or 'search' in cfg.MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR
                  or 'search' in cfg.MODEL.SEG_BRANCH.SEGMENT_BRANCH)
    if not any_search:
        amp_opt_level = 'O1' if cfg.DTYPE == "float16" else 'O0'
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=amp_opt_level)

    if distributed:
        ddp_options = {
            "device_ids": [local_rank],
            "output_device": local_rank,
            # this should be removed if we update BatchNorm stats
            "broadcast_buffers": False,
            "find_unused_parameters": True,
        }
        model = torch.nn.parallel.DistributedDataParallel(model, **ddp_options)

    # Disabled experiment, kept for reference:
    # if 'search' in cfg.MODEL.BACKBONE.CONV_BODY:
    #     def forward_hook(module: Module, inp: (Tensor,)):
    #         if module.weight is not None:
    #             module.weight.requires_grad = True
    #         if module.bias is not None:
    #             module.bias.requires_grad = True

    #     all_modules = (nn.Conv2d, nn.Linear, nn.BatchNorm2d, nn.GroupNorm, ) # update group norm!!
    #     for m in model.modules():
    #         if isinstance(m, all_modules):
    #             m.register_forward_pre_hook(forward_hook)

    # Training state saved with each checkpoint; whatever checkpointer.load()
    # returns is merged on top of the default starting iteration.
    arguments = {"iteration": 0}
    output_dir = cfg.OUTPUT_DIR
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, get_rank() == 0)
    arguments.update(checkpointer.load(cfg.MODEL.WEIGHT))

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    # A validation loader is only built when periodic testing is enabled.
    test_period = cfg.SOLVER.TEST_PERIOD
    data_loader_val = (
        make_data_loader(cfg,
                         is_train=False,
                         is_distributed=distributed,
                         is_for_period=True)
        if test_period > 0 else None
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    if not use_tensorboard:
        meters = MetricLogger(delimiter="  ")
    else:
        meters = TensorboardLogger(cfg=cfg,
                                   log_dir=cfg.TENSORBOARD_EXPERIMENT,
                                   start_iter=arguments['iteration'],
                                   delimiter="  ")

    do_train(cfg, model, data_loader, data_loader_val, optimizer, scheduler,
             checkpointer, device, checkpoint_period, test_period, arguments,
             meters)
    return model
コード例 #5
0
def main():
    """Command-line entry point: load a config and checkpoint, run inference.

    Parses ``--config-file``, ``--local_rank`` and trailing config overrides,
    initialises distributed mode when ``WORLD_SIZE`` > 1, builds the model,
    loads ``cfg.MODEL.WEIGHT`` and calls ``inference`` once per dataset in
    ``cfg.DATASETS.TEST``.
    """
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Inference")
    parser.add_argument(
        "--config-file",
        default="/homes/maskrcnn/configs/caffe2/e2e_mask_rcnn_R_101_FPN_1x_caffe2.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    # WORLD_SIZE is read from the environment; absent means a single process.
    num_gpus = int(os.environ.get("WORLD_SIZE", 1))
    distributed = num_gpus > 1
    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # An empty save_dir is passed to setup_logger.
    save_dir = ""
    logger = setup_logger("maskrcnn_benchmark", save_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(cfg)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    model = build_detection_model(cfg)
    model.to(cfg.MODEL.DEVICE)

    output_dir = cfg.OUTPUT_DIR
    checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)

    iou_types = ("bbox", )
    if cfg.MODEL.MASK_ON:
        iou_types += ("segm", )

    # One output folder per test dataset, or None for each when OUTPUT_DIR
    # is empty.
    dataset_names = cfg.DATASETS.TEST
    if cfg.OUTPUT_DIR:
        output_folders = []
        for name in dataset_names:
            folder = os.path.join(cfg.OUTPUT_DIR, "inference", name)
            mkdir(folder)
            output_folders.append(folder)
    else:
        output_folders = [None] * len(dataset_names)

    data_loaders_val = make_data_loader(cfg,
                                        is_train=False,
                                        is_distributed=distributed)
    meters = TensorboardLogger(log_dir=cfg.TENSORBOARD_EXPERIMENT,
                               delimiter="  ")

    # The returned extra checkpoint data is not needed for inference.
    _ = checkpointer.load(cfg.MODEL.WEIGHT)

    per_dataset = zip(output_folders, dataset_names, data_loaders_val)
    for output_folder, dataset_name, data_loader_val in per_dataset:
        inference(
            model,
            cfg,
            data_loader_val,
            dataset_name=dataset_name,
            iou_types=iou_types,
            box_only=cfg.MODEL.RPN_ONLY,
            device=cfg.MODEL.DEVICE,
            meters=meters,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        synchronize()
コード例 #6
0
def train(cfg, local_rank, distributed, use_tensorboard=False):
    """Set up model, optimizer, loaders and loggers, then call ``do_train``.

    Args:
        cfg: Configuration object (device, dtype, weights, solver periods,
            output and tensorboard settings are read from it).
        local_rank: Device index used for ``DistributedDataParallel``.
        distributed: Whether to wrap the model in ``DistributedDataParallel``.
        use_tensorboard: Use ``TensorboardLogger`` (separate 'train' and
            'val' stages) instead of ``MetricLogger``.

    Returns:
        The model that was passed to ``do_train``.
    """
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    # Initialize mixed-precision training: 'O1' for float16, otherwise 'O0'.
    amp_opt_level = 'O1' if cfg.DTYPE == "float16" else 'O0'
    model, optimizer = amp.initialize(model, optimizer, opt_level=amp_opt_level)

    if distributed:
        ddp_options = {
            "device_ids": [local_rank],
            "output_device": local_rank,
            # this should be removed if we update BatchNorm stats
            "broadcast_buffers": False,
        }
        model = torch.nn.parallel.DistributedDataParallel(model, **ddp_options)

    # Training state saved with each checkpoint; whatever checkpointer.load()
    # returns is merged on top of the default starting iteration.
    arguments = {"iteration": 0}
    output_dir = cfg.OUTPUT_DIR
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, get_rank() == 0
    )
    arguments.update(checkpointer.load(cfg.MODEL.WEIGHT))

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    # A validation loader is only built when periodic testing is enabled.
    test_period = cfg.SOLVER.TEST_PERIOD
    data_loader_val = (
        make_data_loader(cfg, is_train=False, is_distributed=distributed, is_for_period=True)
        if test_period > 0 else None
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    if not use_tensorboard:
        meters = MetricLogger(delimiter="  ")
        meters_val = MetricLogger(delimiter="  ")
    else:
        # Same settings for both loggers; only the stage label differs.
        meters, meters_val = [
            TensorboardLogger(
                log_dir=cfg.TENSORBOARD_EXPERIMENT,
                stage=stage,
                start_iter=arguments['iteration'],
                delimiter="  ")
            for stage in ('train', 'val')
        ]

    do_train(cfg, model, data_loader, data_loader_val, optimizer, scheduler,
             checkpointer, device, checkpoint_period, test_period, arguments,
             meters, meters_val)
    return model
コード例 #7
0
def do_train(
    model,
    data_loader,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    arguments,
):
    """Run one pass over ``data_loader``, optimising ``model`` and logging.

    Args:
        model: Called as ``model(images, targets)``; expected to return a
            dict of losses.
        data_loader: Iterable yielding ``(images, targets, _)`` triples; its
            ``len()`` is used as ``max_iter``.
        optimizer: Optimizer stepped once per iteration.
        scheduler: LR scheduler stepped once per iteration.
        checkpointer: Object whose ``save(name, **arguments)`` is called
            every ``checkpoint_period`` iterations and at ``max_iter``.
        device: Target passed to ``.to(...)`` for images and targets.
        checkpoint_period: Interval, in iterations, between checkpoints.
        arguments: Dict holding at least ``"iteration"``. If it also holds
            both ``"tb_log_dir"`` and ``"tb_exp_name"``, a
            ``TensorboardLogger`` is used instead of ``MetricLogger``. The
            whole dict (including those keys) is forwarded to
            ``checkpointer.save``.
    """
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    # meters = MetricLogger(delimiter="  ")
    # Tensorboard logging is opt-in via two keys in `arguments`.
    if "tb_log_dir" in arguments and "tb_exp_name" in arguments:
        meters = TensorboardLogger(log_dir=arguments["tb_log_dir"],
                                   exp_name=arguments["tb_exp_name"],
                                   start_iter=arguments['iteration'],
                                   delimiter="  ")
    else:
        meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    start_training_time = time.time()
    end = time.time()
    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
        # Time spent waiting for the batch since the previous iteration ended.
        data_time = time.time() - end
        # Shift to 1-based numbering so the last iteration equals max_iter
        # when starting from 0.
        iteration = iteration + 1
        arguments["iteration"] = iteration

        # NOTE(review): the scheduler is stepped before optimizer.step();
        # PyTorch >= 1.1 documents the opposite order -- confirm the targeted
        # PyTorch version before changing, as it shifts the LR schedule.
        scheduler.step()

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        loss_dict = model(images, targets)

        # Total loss that is back-propagated: sum of all entries in loss_dict.
        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        # ETA = average batch time so far * iterations left.
        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        # Log every 20 iterations and on the final iteration.
        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    # bytes -> MiB
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        # The last iteration additionally writes a fixed-name checkpoint.
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))
コード例 #8
0
def main():
    """Command-line entry point: configure, train and optionally test.

    Parses the CLI, initialises distributed mode when ``WORLD_SIZE`` > 1,
    loads and freezes the config, sets up logging, builds the metric logger,
    calls ``train`` and, unless ``--skip-test`` is given, ``run_test``.
    """
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "--use-tensorboard",
        dest="use_tensorboard",
        help="Use tensorboardX logger (Requires tensorboardX installed)",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    # WORLD_SIZE is read from the environment; absent means a single process.
    num_gpus = int(os.environ.get("WORLD_SIZE", 1))
    # Stored on args so it appears when args is logged below.
    args.distributed = num_gpus > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    # Echo the raw config file, then the merged configuration.
    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as config_file:
        logger.info("\n" + config_file.read())
    logger.info("Running with config:\n{}".format(cfg))

    if not args.use_tensorboard:
        meters = MetricLogger(delimiter="  ")
    else:
        # The last '/'-separated component of OUTPUT_DIR is passed as
        # extra_name.
        extra_name = cfg.OUTPUT_DIR.rsplit('/', 1)[-1]
        meters = TensorboardLogger(log_dir=cfg.TENSORBOARD_EXPERIMENT,
                                   start_iter=0,
                                   delimiter="  ",
                                   extra_name=extra_name)

    model = train(cfg, args.local_rank, args.distributed, meters)

    if args.skip_test:
        return
    run_test(cfg, model, args.distributed, meters)
コード例 #9
0
ファイル: trainer.py プロジェクト: witwitchayakarn/6DVNET
def do_train(model,
             data_loader,
             optimizer,
             scheduler,
             checkpointer,
             device,
             checkpoint_period,
             arguments,
             snapshot,
             tb_log_dir,
             tb_exp_name,
             use_tensorboard=False):
    """Run one pass over ``data_loader``, optimising ``model`` and logging.

    Args:
        model: Called as ``model(images, targets)``; expected to return a
            dict that mixes losses with other metrics.
        data_loader: Iterable yielding ``(images, targets, img_idx)``
            triples; its ``len()`` is used as ``max_iter``.
        optimizer: Optimizer stepped once per iteration.
        scheduler: LR scheduler stepped once per iteration, after the
            optimizer.
        checkpointer: Object whose ``save(name, **arguments)`` is called
            every ``checkpoint_period`` iterations and at ``max_iter``.
        device: Target passed to ``.to(...)`` for images and targets.
        checkpoint_period: Interval, in iterations, between checkpoints.
        arguments: Dict holding at least ``"iteration"``; updated in place
            and forwarded to ``checkpointer.save``.
        snapshot: Interval, in iterations, between log lines.
        tb_log_dir: ``log_dir`` for ``TensorboardLogger``.
        tb_exp_name: ``exp_name`` for ``TensorboardLogger``.
        use_tensorboard: Select ``TensorboardLogger`` instead of
            ``MetricLogger``.
    """
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    meters = TensorboardLogger(log_dir=tb_log_dir,
                               exp_name=tb_exp_name,
                               start_iter=arguments['iteration'],
                               delimiter="  ") \
        if use_tensorboard else MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    start_training_time = time.time()
    end = time.time()
    # img_idx is unpacked but not used in this loop.
    for iteration, (images, targets,
                    img_idx) in enumerate(data_loader, start_iter):
        # Time spent waiting for the batch since the previous iteration ended.
        data_time = time.time() - end
        # Shift to 1-based numbering so the last iteration equals max_iter
        # when starting from 0.
        iteration = iteration + 1
        arguments["iteration"] = iteration

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        optimizer.zero_grad()
        loss_dict = model(images, targets)
        # we only sum up the loss but no other metrics
        # (an entry counts as a loss when the last '/'-separated part of its
        # key starts with 'loss')
        losses = sum([
            v for (k, v) in loss_dict.items() if k.split('/')[-1][:4] == 'loss'
        ])
        losses.backward()
        optimizer.step()
        scheduler.step()

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum([
            v for (k, v) in loss_dict_reduced.items()
            if k.split('/')[-1][:4] == 'loss'
        ])

        # Prepend the key name so as to allow better visualisation in tensorboard
        loss_dict_reduced['detector_losses/total_loss'] = losses_reduced
        meters.update(**loss_dict_reduced)
        # The meter records the LR of the last param group, while the log
        # line below prints the LR of the first one.
        meters.update(lr=optimizer.param_groups[-1]['lr'])

        batch_time = time.time() - end
        end = time.time()
        process_time = {
            'time/batch_time': batch_time,
            'time/data_time': data_time
        }
        meters.update(**process_time)

        # ETA = average batch time so far * iterations left.
        eta_seconds = meters.meters['time/batch_time'].global_avg * (max_iter -
                                                                     iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        # Log every `snapshot` iterations and on the final iteration.
        if iteration % snapshot == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join([
                    "\n \t eta: {eta}",
                    "iter: {iter}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                    "\n \t {meters}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    lr=optimizer.param_groups[0]["lr"],
                    # bytes -> MiB
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                    meters=str(meters),
                ))
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        # The last iteration additionally writes a fixed-name checkpoint.
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))
コード例 #10
0
def train(cfg, local_rank, distributed, use_tensorboard=False):
    """Build the training state and run plain or domain-adaptation training.

    Args:
        cfg: Configuration object. ``MODEL.DOMAIN_ADAPTATION_ON`` selects
            between ``do_da_train`` (source + target loaders) and
            ``do_train`` (single loader).
        local_rank: Device index used for ``DistributedDataParallel``.
        distributed: Whether to wrap the model in ``DistributedDataParallel``.
        use_tensorboard: Use ``TensorboardLogger`` instead of ``MetricLogger``.

    Returns:
        The model that was trained.
    """
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        ddp_options = {
            "device_ids": [local_rank],
            "output_device": local_rank,
            # this should be removed if we update BatchNorm stats
            "broadcast_buffers": False,
        }
        model = torch.nn.parallel.DistributedDataParallel(model, **ddp_options)

    # Training state saved with each checkpoint; whatever checkpointer.load()
    # returns is merged on top of the default starting iteration.
    arguments = {"iteration": 0}
    output_dir = cfg.OUTPUT_DIR
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, get_rank() == 0)
    arguments.update(checkpointer.load(cfg.MODEL.WEIGHT))

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    if not use_tensorboard:
        meters = MetricLogger(delimiter="  ")
    else:
        meters = TensorboardLogger(log_dir=cfg.TENSORBOARD_EXPERIMENT,
                                   start_iter=arguments['iteration'],
                                   delimiter="  ")

    # Loader options common to every training loader built below.
    loader_options = {
        "is_train": True,
        "is_distributed": distributed,
        "start_iter": arguments["iteration"],
    }

    if not cfg.MODEL.DOMAIN_ADAPTATION_ON:
        data_loader = make_data_loader(cfg, **loader_options)
        do_train(model, data_loader, optimizer, scheduler, checkpointer,
                 device, checkpoint_period, arguments, meters)
    else:
        # Domain adaptation: one loader per domain, source first.
        source_data_loader = make_data_loader(
            cfg, is_source=True, **loader_options)
        target_data_loader = make_data_loader(
            cfg, is_source=False, **loader_options)
        do_da_train(model, source_data_loader, target_data_loader, optimizer,
                    scheduler, checkpointer, device, checkpoint_period,
                    arguments, cfg, meters)

    return model
コード例 #11
0
def train(cfg, local_rank, distributed, use_tensorboard=False):
    """Set up model, optimizer, checkpoint state and logger; run ``do_train``.

    Args:
        cfg: Configuration object. ``MODEL.LOAD_ONLY_WEIGHTS`` controls
            whether the extra data returned by the checkpoint load is
            merged into the training state.
        local_rank: Device index used for ``DistributedDataParallel``.
        distributed: Whether to wrap the model in ``DistributedDataParallel``.
        use_tensorboard: Use ``TensorboardLogger`` instead of ``MetricLogger``.

    Returns:
        The model that was passed to ``do_train``.
    """
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    # Initialize mixed-precision training: 'O1' for float16, otherwise 'O0'.
    amp_opt_level = 'O1' if cfg.DTYPE == "float16" else 'O0'
    model, optimizer = amp.initialize(
        model, optimizer, opt_level=amp_opt_level)

    if distributed:
        ddp_options = {
            "device_ids": [local_rank],
            "output_device": local_rank,
            # this should be removed if we update BatchNorm stats
            "broadcast_buffers": False,
        }
        model = torch.nn.parallel.DistributedDataParallel(model, **ddp_options)

    arguments = {"iteration": 0}
    output_dir = cfg.OUTPUT_DIR
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, get_rank() == 0)

    # load_scheduler_only_epoch will prefer the scheduler specified in the
    # config rather than the one in the checkpoint, and will load only the
    # last_epoch from the checkpoint.
    weights_only = cfg.MODEL.LOAD_ONLY_WEIGHTS
    extra_checkpoint_data = checkpointer.load(
        cfg.MODEL.WEIGHT,
        load_model_only=weights_only,
        load_scheduler_only_epoch=True)
    # With weights-only loading the training state keeps iteration 0.
    if not weights_only:
        arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    if not use_tensorboard:
        meters = MetricLogger(delimiter="  ")
    else:
        meters = TensorboardLogger(log_dir=output_dir,
                                   exp_name=cfg.TENSORBOARD_EXP_NAME,
                                   start_iter=arguments['iteration'],
                                   delimiter="  ")

    do_train(model, data_loader, optimizer, scheduler, checkpointer, device,
             checkpoint_period, arguments, meters)
    return model