Example No. 1
def train(cfg, local_rank, distributed):
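    # Build the detector from the config, optionally convert BatchNorm to SyncBatchNorm,
    # wrap the model for distributed training, restore any checkpoint, and hand everything
    # to the do_train loop.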
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0  # only the rank-0 process writes checkpoints to disk
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler,
                                         output_dir, save_to_disk)
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
Example No. 2
def train(cfg, local_rank, distributed):
    BACKBONE = build_backbone_model(cfg)
    head = build_head_model(cfg)
    model = face_trainer(BACKBONE, head)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
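    # load() without an explicit weight file resumes from the last checkpoint found in
    # output_dir, if any; otherwise the model keeps its fresh initialization.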
    extra_checkpoint_data = checkpointer.load()
    arguments.update(extra_checkpoint_data)


    #### init transforms #####
    RGB_MEAN = [0.5, 0.5, 0.5]
    RGB_STD = [0.5, 0.5, 0.5]
    transforms = T.Compose(
        [
            T.RandomCrop((cfg.INPUT.MIN_SIZE_TRAIN[0], cfg.INPUT.MIN_SIZE_TRAIN[1])),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize(mean=RGB_MEAN, std=RGB_STD),
        ]
    )
    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
        transforms=transforms,
    )


    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_face_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
Example No. 3
def do_train(
    cfg,
    total_model,
    data_loader,
    data_loader_val,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    test_period,
    arguments,
    args,
):
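    # total_model holds either [student] or [teacher, student]; the student is the model
    # being optimized and the teacher only provides distillation targets.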
    if len(total_model) > 1:
        model = total_model[1]
        t_model = total_model[0]
    else:
        model = total_model[0]
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()

    start_training_time = time.time()
    end = time.time()

    iou_types = ("bbox", )
    if cfg[0].MODEL.MASK_ON:
        iou_types = iou_types + ("segm", )
    if cfg[0].MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints", )
    dataset_names = cfg[0].DATASETS.TEST

    pytorch_1_1_0_or_later = is_pytorch_1_1_0_or_later()
    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        # in pytorch >= 1.1.0, scheduler.step() should be run after optimizer.step()
        if not pytorch_1_1_0_or_later:
            scheduler.step()

        images = images.to(device)
        targets = [target.to(device) for target in targets]

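        # the forward pass returns the task losses together with intermediate outputs
        # (e.g. the box_cls feature maps) used by the distillation terms below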
        loss_dict, features_dict = model(images, targets)
        if len(total_model) > 1:
            with torch.no_grad():
                t_loss_dict, t_features_dict = t_model(images, targets)
            # with torch.no_grad():
            #     # teacher_model = t_model
            #     t_weight = torch.load('./weights/centermask-V-19-eSE-FPN-ms-3x.pth')
            #     t_weight = t_weight['model']
            #     new_tweight = OrderedDict()
            #     for k, v in t_weight.items():
            #         name = k[7:]  # remove `module.`
            #         new_tweight[name] = v
            #     t_model.load_state_dict(new_tweight)
            #     t_loss_dict, t_features_dict = t_model(images, targets)

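        # head-level distillation: match the teacher's regression, centerness and
        # classification outputs (requires the teacher forward pass above)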
        if args.loss_head:
            loss_regression = new_box_loss(t_loss_dict['loss_reg'],
                                           loss_dict['loss_reg'])
            loss_center = new_center_loss(t_loss_dict['loss_centerness'],
                                          loss_dict['loss_centerness'])
            mode = 'KL'  # mode = 'KL' or 'cross-entropy'
            loss_pixel_wise = pixel_wise_loss(features_dict['box_cls'],
                                              t_features_dict['box_cls'], mode)
            loss_head = (loss_regression + loss_center + loss_pixel_wise)
            loss_dict.setdefault('loss_head', loss_head)
            del loss_dict['loss_reg']
            del loss_dict['loss_centerness']

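        # feature-level distillation terms (correlation / raw feature maps) are only
        # added once the LR warmup phase is over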
        if iteration > cfg[0].SOLVER.WARMUP_ITERS:
            if args.loss_correlation:
                correlation = True
                loss_corr = get_feature(t_model, model, images, targets,
                                        correlation)
                loss_dict.setdefault('loss_corr', loss_corr)
            if args.loss_featuremap:
                correlation = False
                loss_featuremap = get_feature(t_model, model, images, targets,
                                              correlation)
                loss_dict.setdefault('loss_featuremap', loss_featuremap)

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if pytorch_1_1_0_or_later:
            scheduler.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
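        # periodic validation: run COCO-style inference on the validation split, then
        # compute validation losses with gradients disabled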
        if data_loader_val is not None and test_period > 0 and iteration % test_period == 0 and iteration != 0:
            meters_val = MetricLogger(delimiter="  ")
            synchronize()
            _ = inference(  # the result could also be used for extra logging, e.g. TensorBoard
                model,
                # evaluation changes the segmentation mask format inside the data loader,
                # so a fresh data loader is created for every validation run:
                make_data_loader(cfg[0],
                                 is_train=False,
                                 is_distributed=(get_world_size() > 1),
                                 is_for_period=True),
                dataset_name="[Validation]",
                iou_types=iou_types,
                box_only=False if cfg[0].MODEL.MASK_ON else cfg[0].MODEL.RPN_ONLY,
                device=cfg[0].MODEL.DEVICE,
                expected_results=cfg[0].TEST.EXPECTED_RESULTS,
                expected_results_sigma_tol=cfg[0].TEST.EXPECTED_RESULTS_SIGMA_TOL,
                output_folder=None,
            )
            synchronize()
            model.train()
            with torch.no_grad():
                # Should be one image for each GPU:
                for iteration_val, (images_val, targets_val,
                                    _) in enumerate(tqdm(data_loader_val)):
                    images_val = images_val.to(device)
                    targets_val = [target.to(device) for target in targets_val]
                    loss_dict = model(images_val, targets_val)
                    # the student model returns (loss_dict, features_dict); keep only the losses
                    if isinstance(loss_dict, tuple):
                        loss_dict = loss_dict[0]
                    losses = sum(loss for loss in loss_dict.values())
                    loss_dict_reduced = reduce_loss_dict(loss_dict)
                    losses_reduced = sum(
                        loss for loss in loss_dict_reduced.values())
                    meters_val.update(loss=losses_reduced, **loss_dict_reduced)
            synchronize()
            logger.info(
                meters_val.delimiter.join([
                    "[Validation]: ",
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters_val),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))
Example No. 4
def train(total_cfg, local_rank, distributed):
    total_model = []
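    # iterate the configs in reverse so the model built from total_cfg[0] is appended
    # last; do_train treats total_model[0] as the teacher and total_model[1] as the student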
    for i in reversed(range(len(total_cfg))):
        model = build_detection_model(total_cfg[i])
        device = torch.device(total_cfg[i].MODEL.DEVICE)
        model.to(device)
        if total_cfg[i].MODEL.USE_SYNCBN:
            assert is_pytorch_1_1_0_or_later(), \
                "SyncBatchNorm is only available in pytorch >= 1.1.0"
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        optimizer = make_optimizer(total_cfg[i], model)
        scheduler = make_lr_scheduler(total_cfg[i], optimizer)

        if distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank,
                # this should be removed if we update BatchNorm stats
                broadcast_buffers=False,
            )

        arguments = {}
        arguments["iteration"] = 0

        output_dir = total_cfg[i].OUTPUT_DIR

        save_to_disk = get_rank() == 0
        checkpointer = DetectronCheckpointer(
            total_cfg[i], model, optimizer, scheduler, output_dir, save_to_disk
        )
        extra_checkpoint_data = checkpointer.load(total_cfg[i].MODEL.WEIGHT)
        if i == 0:
            arguments.update(extra_checkpoint_data)
        total_model.append(model)

    data_loader = make_data_loader(
        total_cfg[0],
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    test_period = total_cfg[0].SOLVER.TEST_PERIOD
    if test_period > 0:
        data_loader_val = make_data_loader(
            total_cfg[0], is_train=False, is_distributed=distributed, is_for_period=True
        )
    else:
        data_loader_val = None

    checkpoint_period = total_cfg[0].SOLVER.CHECKPOINT_PERIOD
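    # log parameter counts for the student and, when distilling, the teacher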
    if len(total_model) > 1:
        params = sum(np.prod(p.size()) for p in total_model[1].parameters())
        print('Number of Parameters: {:.5f}M'.format(params / 1e6))
        params = sum(np.prod(p.size()) for p in total_model[0].parameters())
        print('teacher_model Number of Parameters: {:.5f}M'.format(params / 1e6))
    else:
        params = sum(np.prod(p.size()) for p in total_model[0].parameters())
        print('Number of Parameters: {:.5f}M'.format(params / 1e6))

    do_train(
        total_cfg,
        total_model,
        data_loader,
        data_loader_val,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        test_period,
        arguments,
        args,
    )

    return total_model[1]  # the student model (assumes a teacher/student pair was built)
Example No. 5
def do_train(
    cfg,
    model,
    data_loader,
    data_loader_val,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    test_period,
    arguments,
):
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    start_training_time = time.time()
    end = time.time()

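    # IoU types used for periodic evaluation depend on which prediction heads are enabled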
    iou_types = ("bbox", )
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm", )
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints", )
    dataset_names = cfg.DATASETS.TEST

    pytorch_1_1_0_or_later = is_pytorch_1_1_0_or_later()
    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        # in pytorch >= 1.1.0, scheduler.step() should be run after optimizer.step()
        if not pytorch_1_1_0_or_later:
            scheduler.step()

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if pytorch_1_1_0_or_later:
            scheduler.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
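        # periodic validation: COCO-style inference followed by a no-grad loss pass over
        # the validation loader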
        if data_loader_val is not None and test_period > 0 and iteration % test_period == 0 and iteration != 0:
            meters_val = MetricLogger(delimiter="  ")
            synchronize()
            _ = inference(  # the result could also be used for extra logging, e.g. TensorBoard
                model,
                # evaluation changes the segmentation mask format inside the data loader,
                # so a fresh data loader is created for every validation run:
                make_data_loader(cfg,
                                 is_train=False,
                                 is_distributed=(get_world_size() > 1),
                                 is_for_period=True),
                dataset_name="[Validation]",
                iou_types=iou_types,
                box_only=False if cfg.MODEL.MASK_ON else cfg.MODEL.RPN_ONLY,
                device=cfg.MODEL.DEVICE,
                expected_results=cfg.TEST.EXPECTED_RESULTS,
                expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
                output_folder=None,
            )
            synchronize()
            model.train()
            with torch.no_grad():
                # Should be one image for each GPU:
                for iteration_val, (images_val, targets_val,
                                    _) in enumerate(tqdm(data_loader_val)):
                    images_val = images_val.to(device)
                    targets_val = [target.to(device) for target in targets_val]
                    loss_dict = model(images_val, targets_val)
                    losses = sum(loss for loss in loss_dict.values())
                    loss_dict_reduced = reduce_loss_dict(loss_dict)
                    losses_reduced = sum(
                        loss for loss in loss_dict_reduced.values())
                    meters_val.update(loss=losses_reduced, **loss_dict_reduced)
            synchronize()
            logger.info(
                meters_val.delimiter.join([
                    "[Validation]: ",
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters_val),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))
Example No. 6
def train(cfg, local_rank, distributed):
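    # Fine-tuning variant: a COCO-pretrained checkpoint (cfg.MODEL.WEIGHT) is loaded,
    # but its iteration counter is reset below so training restarts from iteration 0.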
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    # pdb.set_trace()

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    # pdb.set_trace()
    # (Pdb) optimizer.param_groups[0]["lr"]
    # 0.0016666666666666666


    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    # pdb.set_trace()
    # (Pdb) optimizer.param_groups[0]["lr"]
    # 0.0016666666666666666

    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )

    # pdb.set_trace()
    # (Pdb) cfg.MODEL.WEIGHT
    # 'coco_P2_8.pth'
    # (Pdb) optimizer.param_groups[0]["lr"]
    # 0.0016666666666666666
    
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    # pdb.set_trace()
    # (Pdb) optimizer.param_groups[0]["lr"]
    # 0.00010000000000000002


    # pdb.set_trace()
    # (Pdb) extra_checkpoint_data
    # {'iteration': 80000}
    # (Pdb) arguments
    # {'iteration': 80000}

    # coco_pretrained_P2: the checkpoint resumes at iteration 80000, so reset to start=0
    arguments["iteration"] = 0

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    # pdb.set_trace()
    # (Pdb) cfg.SOLVER.CHECKPOINT_PERIOD
    # 2500

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    # pdb.set_trace()
    # optimizer.param_groups[0]["lr"]
    # 0.00010000000000000002

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
Example No. 7
def train(cfg, local_rank, distributed, meters):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

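    # freeze any modules named in FEW_SHOT.UNTRAINED_KEYWORD so they are excluded from
    # gradient updates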
    if 'rpn' in cfg.FEW_SHOT.UNTRAINED_KEYWORD:
        for param in model.rpn.parameters():
            param.requires_grad = False
    if 'backbone' in cfg.FEW_SHOT.UNTRAINED_KEYWORD:
        for param in model.backbone.parameters():
            param.requires_grad = False

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler,
                                         output_dir, save_to_disk)
    print(cfg.MODEL.WEIGHT)
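    # FSS_LOAD: restore weights in two passes; FEW_SHOT.UNLOAD_KEYWORD appears to control
    # which modules each checkpointer.load() call skips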
    if cfg.MODEL.FSS_LOAD:
        cfg.defrost()
        cfg.FEW_SHOT.UNLOAD_KEYWORD = ('rpn', )
        cfg.freeze()
        extra_checkpoint_data = checkpointer.load(cfg.MODEL.FSS_WEIGHT)
        cfg.defrost()
        cfg.FEW_SHOT.UNLOAD_KEYWORD = ('backbone', 'roi_head')
        cfg.freeze()
        extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    else:
        extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    if not cfg.FEW_SHOT.RESUME:
        arguments["iteration"] = 0

    if isinstance(meters, TensorboardLogger):
        meters.iteration = arguments["iteration"]

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    # torch.cuda.empty_cache()
    # print(model)

    do_train(
        cfg,
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        meters,
    )

    return model
Example No. 8
def do_train(
    model,
    data_loader,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    arguments,
):
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
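    # max_iter comes from the length of the iteration-based training loader; iteration
    # numbering resumes from the checkpointed value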
    start_iter = arguments["iteration"]
    model.train()
    start_training_time = time.time()
    end = time.time()
    pytorch_1_1_0_or_later = is_pytorch_1_1_0_or_later()

    # pdb.set_trace()
    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        # in pytorch >= 1.1.0, scheduler.step() should be run after optimizer.step()
        if not pytorch_1_1_0_or_later:
            scheduler.step()

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        # pdb.set_trace()

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if pytorch_1_1_0_or_later:
            scheduler.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        # pdb.set_trace()
        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))
Example No. 9
def do_train(cfg, model, data_loader, optimizer, scheduler, checkpointer,
             device, checkpoint_period, arguments, meters):
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    # meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    start_training_time = time.time()
    show_load_time = 0
    show_comp_time = 0
    end = time.time()
    pytorch_1_1_0_or_later = is_pytorch_1_1_0_or_later()
    # print('running with param decay')
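    # each batch provides query images plus positive and negative support images for
    # few-shot training, along with image and target ids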
    for iteration, (images, images_support, images_neg_support, targets,
                    img_ids, target_ids) in enumerate(data_loader, start_iter):
        load_time = time.time() - end
        show_load_time += load_time
        end = time.time()

        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        # in pytorch >= 1.1.0, scheduler.step() should be run after optimizer.step()
        if not pytorch_1_1_0_or_later:
            scheduler.step()

        images = images.to(device)
        images_support = images_support.to(device)  # [item.to(device) for item in images_support]
        images_neg_support = images_neg_support.to(device)  # [item.to(device) for item in images_neg_support]
        targets = [target.to(device) for target in targets]
        loss_dict = model(images,
                          images_support,
                          targets,
                          images_neg_supp=images_neg_support,
                          device=device)

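        # when the RPN is frozen, only back-propagate the classifier, box-regression and
        # reversal ('rev') losses; otherwise use all losses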
        if 'rpn' in cfg.FEW_SHOT.UNTRAINED_KEYWORD:
            backward_losses = sum(
                loss for key, loss in loss_dict.items()
                if 'classifier' in key or 'box_reg' in key or 'rev' in key)
        else:
            # change to backward_losses for partial training
            backward_losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        backward_losses.backward()
        optimizer.step()

        if pytorch_1_1_0_or_later:
            scheduler.step()

        batch_time = time.time() - end
        show_comp_time += batch_time
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 100 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join([
                    "\neta: {eta}",
                    "load: {load_time}",
                    "comp: {comp_time}",
                    "iter: {iter}/{max_iter}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                    "\n{meters}",
                ]).format(
                    eta=eta_string,
                    load_time=show_load_time,
                    comp_time=show_comp_time,
                    iter=iteration,
                    max_iter=max_iter,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))
            show_load_time = 0
            show_comp_time = 0
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))