Beispiel #1
0
 def __init__(self, cfg, pretrained=True, is_train=True):
     """Assemble the detector: ResNet backbone, RPN, and box RoI head.

     Args:
         cfg: project config node (reads cfg.TRAIN.MIX_LEVEL).
         pretrained: forwarded to the ResNet backbone.
         is_train: forwarded to the RPN and the RoI head.
     """
     super(Model, self).__init__()
     self.backbone = ResNet(cfg, pretrained)
     self.rpn = RPN(cfg, is_train)
     self.roi = BoxRoI(cfg, is_train)
     # Under apex O1 mixed precision the RoI pooling op is not auto-cast,
     # so its forward must be wrapped by hand to run in half precision.
     if cfg.TRAIN.MIX_LEVEL == 'O1':
         pool = self.roi.head.pool
         pool.forward = amp.half_function(pool.forward)
Beispiel #2
0
    def __init__(self, K=3):
        """Build a dilated-VGG16 backbone with WSDDN-style classification heads.

        Args:
            K: number of refinement branches, kept on the instance as self.K.
        """
        super(Combined_VGG16, self).__init__()
        self.K = K
        vgg = torchvision.models.vgg16(pretrained=True)
        # conv1_1 .. conv4_3 taken verbatim from the pretrained VGG16.
        self.pretrained_features = nn.Sequential(
            *list(vgg.features._modules.values())[:23])

        # conv5 block rebuilt with dilation=2 (padding=2 keeps spatial size);
        # the three conv/relu pairs are identical, so build them in a loop.
        conv5 = OrderedDict()
        for idx in (1, 2, 3):
            conv5['conv5_%d' % idx] = nn.Conv2d(512,
                                                512,
                                                kernel_size=3,
                                                stride=1,
                                                padding=2,
                                                dilation=2)
            conv5['relu5_%d' % idx] = nn.ReLU(inplace=True)
        self.new_features = nn.Sequential(conv5)

        self.roi_size = (7, 7)
        self.roi_spatial_scale = 0.125
        self.roi_pool = RoIPool(self.roi_size, self.roi_spatial_scale)
        # Seed the dilated conv5 layers with the pretrained VGG16 weights.
        for name, src in zip(('conv5_1', 'conv5_2', 'conv5_3'), (24, 26, 28)):
            copy_parameters(getattr(self.new_features, name),
                            vgg.features[src])

        # fc6/fc7 reused from VGG16's classifier (final fc layer dropped).
        self.fc67 = nn.Sequential(*list(vgg.classifier._modules.values())[:-1])
        self.fc8c = nn.Linear(4096, 20)
        self.fc8d = nn.Linear(4096, 20)
        self.c_softmax = nn.Softmax(dim=1)  # softmax over classes
        self.d_softmax = nn.Softmax(dim=0)  # softmax over proposals

        self.ic_score1 = nn.Linear(4096, 21)
        self.ic_score2 = nn.Linear(4096, 21)
        self.ic_score3 = nn.Linear(4096, 21)

        self.ic_prob1 = nn.Softmax(dim=1)
        self.ic_prob2 = nn.Softmax(dim=1)
        self.ic_prob3 = nn.Softmax(dim=1)

        # apex O1 does not auto-cast the RoI pooling extension op.
        self.roi_pool.forward = amp.half_function(self.roi_pool.forward)
Beispiel #3
0
        dropout_mask = output[-1]
        ctx.p = p
        return output[0]

    @staticmethod
    @custom_bwd
    def backward(ctx, *grad_o):
        """Backward pass delegated to the fused ReLU-MLP CUDA kernel."""
        grads = fused_mlp_relu.backward(ctx.p, grad_o[0], ctx.outputs,
                                        ctx.saved_tensors)
        # Release the cached forward outputs as soon as they are consumed.
        del ctx.outputs
        # The first forward argument (dropout prob p) gets no gradient.
        return (None, *grads)


# Expose the fused op only when its CUDA extension is available.
mlp_relu_function = (half_function(MlpReluFunction.apply)
                     if fused_mlp_relu else None)


class MlpSiluFunction(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, p, *args):
        """Run the fused SiLU-MLP kernel and stash state for backward.

        Args:
            ctx: autograd context.
            p: dropout probability, saved on ctx for the backward pass.
            *args: tensors forwarded to the kernel and saved for backward.

        Returns:
            The first element of the kernel's output tuple.
        """
        output = fused_mlp_silu.forward(p, args)
        ctx.save_for_backward(*args)
        # Keep the full kernel output tuple (incl. dropout mask) for backward.
        ctx.outputs = output
        ctx.p = p
        return output[0]
Beispiel #4
0
        dropout_mask = outputs[-1]
        ctx.p = p
        return outputs[0], dropout_mask

    @staticmethod
    @custom_bwd
    def backward(ctx, *grad_o):
        """Backward pass delegated to the fused GeLU-MLP CUDA kernel."""
        grads = fused_mlp_gelu.backward(ctx.p, grad_o[0], ctx.outputs,
                                        ctx.saved_tensors)
        # Release the cached forward outputs as soon as they are consumed.
        del ctx.outputs
        # The first forward argument (dropout prob p) gets no gradient.
        return (None, *grads)


# Expose each fused op only when its CUDA extension is available.
mlp_agelu_function = (half_function(MlpAGeLUFunction.apply)
                      if fused_mlp_agelu else None)

mlp_gelu_function = (half_function(MlpGeLUFunction.apply)
                     if fused_mlp_gelu else None)


class SwishFunction(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, inp):
        """Apply the fused CUDA SiLU/Swish kernel; cache input for backward."""
        ctx.save_for_backward(inp)
        result = silu_cuda.forward(inp)
        return result
Beispiel #5
0
def get_trainer(args, model, train_loader, optimizer, lr_scheduler, device,
                tfboard):
    """Build and return an ignite Engine that trains the detector.

    Sets up apex O1 mixed precision (wrapping the RoI pooling forwards by
    hand), optional distributed data parallelism, checkpoint resume, and
    the per-iteration/per-epoch event handlers (LR warmup, console and
    tensorboard logging, checkpoint saving).

    Args:
        args: config namespace; reads args.apex, args.distributed,
            args.resume, args.path, args.local_rank and args.train.*.
        model: detection model returning a loss dict in training mode.
        train_loader: training DataLoader (its length drives warmup and
            per-epoch step counting).
        optimizer: optimizer over the model parameters.
        lr_scheduler: stepped once per completed epoch.
        device: device each batch is shipped to.
        tfboard: tensorboard writer, used when args.train.use_tfboard.

    Returns:
        The configured ignite Engine.
    """

    if args.apex:
        # apex O1 does not auto-cast the RoI pooling extension ops, so wrap
        # their forwards with half_function before amp.initialize.
        model.roi_heads.box_roi_pool.forward = \
            amp.half_function(model.roi_heads.box_roi_pool.forward)
        if hasattr(model.roi_heads, 'pose_attention_net'):
            model.roi_heads.pose_attention_net.pool.forward = \
                amp.half_function(model.roi_heads.pose_attention_net.pool.forward)
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # Keep a reference to the unwrapped model for checkpoint save/resume.
    model_without_ddp = model
    if args.distributed:
        if args.apex:
            model = apex.parallel.convert_syncbn_model(model)
            model = apex.parallel.DistributedDataParallel(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[args.local_rank],
                find_unused_parameters=True)
        model_without_ddp = model.module

    if args.resume is not None:
        args, model_without_ddp, optimizer, lr_scheduler = resume_from_checkpoint(
            args, model_without_ddp, optimizer, lr_scheduler)

    def _update_model(engine, data):
        """Run one optimization step on a batch.

        :param engine: the ignite Engine driving the loop.
        :param data: one batch from train_loader.
        :return: logging state dict on display iterations, else None
            (ignite stores it in engine.state.output).
        """
        images, targets = ship_data_to_cuda(data, device)

        loss_dict = model(images, targets)

        # Weighted sum of the individual detector losses.
        losses = args.train.w_RPN_loss_cls * loss_dict['loss_objectness'] \
            + args.train.w_RPN_loss_box * loss_dict['loss_rpn_box_reg'] \
            + args.train.w_RCNN_loss_bbox * loss_dict['loss_box_reg'] \
            + args.train.w_RCNN_loss_cls * loss_dict['loss_detection'] \
            + args.train.w_OIM_loss_oim * loss_dict['loss_reid']

        # reduce losses over all GPUs for logging purposes
        if engine.state.iteration % args.train.disp_interval == 0:
            loss_dict_reduced = reduce_dict(loss_dict)
            losses_reduced = args.train.w_RPN_loss_cls * loss_dict_reduced['loss_objectness'] \
                + args.train.w_RPN_loss_box * loss_dict_reduced['loss_rpn_box_reg'] \
                + args.train.w_RCNN_loss_bbox * loss_dict_reduced['loss_box_reg'] \
                + args.train.w_RCNN_loss_cls * loss_dict_reduced['loss_detection'] \
                + args.train.w_OIM_loss_oim * loss_dict_reduced['loss_reid']
            loss_value = losses_reduced.item()
            state = dict(loss_value=loss_value,
                         lr=optimizer.param_groups[0]['lr'])
            state.update(loss_dict_reduced)
        else:
            state = None

        optimizer.zero_grad()
        if args.apex:
            # amp scales the loss to avoid fp16 gradient underflow.
            with amp.scale_loss(losses, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            losses.backward()
        if args.train.clip_gradient > 0:
            clip_grad_norm_(model.parameters(), args.train.clip_gradient)
        optimizer.step()

        return state

    trainer = Engine(_update_model)

    @trainer.on(Events.STARTED)
    def _init_run(engine):
        # Fast-forward the counters so a resumed run reports correctly.
        engine.state.epoch = args.train.start_epoch
        engine.state.iteration = args.train.start_epoch * len(train_loader)

    @trainer.on(Events.EPOCH_STARTED)
    def _init_epoch(engine):
        # Linear LR warmup spanning (almost) the whole first epoch.
        if engine.state.epoch == 1 and args.train.lr_warm_up:
            warmup_factor = 1. / 1000
            warmup_iters = len(train_loader) - 1
            engine.state.sub_scheduler = warmup_lr_scheduler(
                optimizer, warmup_iters, warmup_factor)
        lucky_bunny(engine.state.epoch)
        engine.state.metric_logger = MetricLogger()

    @trainer.on(Events.ITERATION_STARTED)
    def _init_iter(engine):
        if engine.state.iteration % args.train.disp_interval == 0:
            engine.state.start = time.time()  # mark batch start for timing

    @trainer.on(Events.ITERATION_COMPLETED)
    def _post_iter(engine):
        if engine.state.epoch == 1 and args.train.lr_warm_up:  # epoch start from 1
            engine.state.sub_scheduler.step()

        if engine.state.iteration % args.train.disp_interval == 0:
            # Update logger
            batch_time = time.time() - engine.state.start
            engine.state.metric_logger.update(batch_time=batch_time)
            engine.state.metric_logger.update(**engine.state.output)
            if hasattr(engine.state, 'debug_info'):
                engine.state.metric_logger.update(**engine.state.debug_info)
            # Print log on console
            step = (engine.state.iteration - 1) % len(train_loader) + 1
            engine.state.metric_logger.print_log(engine.state.epoch, step,
                                                 len(train_loader))
            # Record log on tensorboard
            if args.train.use_tfboard and is_main_process():
                for k, v in engine.state.metric_logger.meters.items():
                    if 'loss' in k:
                        k = k.replace('loss_', 'Loss/')
                    if 'num' in k:
                        tfboard.add_scalars('Debug/fg_bg_ratio', {k: v.avg},
                                            engine.state.iteration)
                    else:
                        tfboard.add_scalar(k, v.avg, engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def _post_epoch(engine):
        lr_scheduler.step()
        # Only rank 0 writes the checkpoint.
        if is_main_process():
            save_name = osp.join(args.path, 'checkpoint.pth')
            save_checkpoint(
                {
                    'epoch': engine.state.epoch,
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict()
                }, save_name)
            # print(hue.good('save model: {}'.format(save_name)))
            print(
                '===============save model: {}======================='.format(
                    save_name))

    return trainer
Beispiel #6
0
        del ctx.outputs
        return (None, *grads)


# if fused_relu_mlp is not None:
#     mlp_relu_function = half_function(MlpReluFunction.apply)
# else:
#     mlp_relu_function = None
#
# if fused_gelu_mlp is not None:
#     mlp_gelu_function = half_function(MlpGeluFunction.apply)
# else:
#     mlp_gelu_function = None

# Expose each fused op only when its CUDA extension is available.
mlp_relu_function = (half_function(MlpReluFunction.apply)
                     if fused_mlp_relu else None)

mlp_silu_function = (half_function(MlpSiluFunction.apply)
                     if fused_mlp_silu else None)


class SwishFunction(torch.autograd.Function):
    """Autograd wrapper around the fused CUDA SiLU (Swish) kernel."""

    @staticmethod
    @custom_fwd
    def forward(ctx, inp):
        # Cache the input tensor for the backward pass.
        ctx.save_for_backward(inp)
        result = silu_cuda.forward(inp)
        return result
def main(args):
    """Entry point: configure (optionally distributed / fp16) COCO training.

    Reads args.* (device, data root, batch size, fp16/apex options, arch),
    builds the COCO detection loaders, the model and SGD optimizer, then
    runs the training epochs via batch_loop.
    """

    # Single-process defaults; overwritten below when --distributed is set.
    args.gpu = 0
    args.world_size = 1

    if args.fp16:
        # NOTE(review): assert is stripped under `python -O`; an explicit
        # raise would be safer for input validation.
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    # Distributed setup: one process per GPU, NCCL backend, env:// rendezvous.
    if args.distributed:
        args.gpu = args.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    args.total_batch_size = args.world_size * args.batch_size

    if args.static_loss_scale != 1.0:
        if not args.fp16:
            print(
                "Warning:  if --fp16 is not used, static_loss_scale will be ignored."
            )

    device = torch.device(args.device)
    log_interval = 20  # NOTE(review): defined but unused in the visible code

    train_transforms = Compose(
        [CocoDetectProcessor(),
         ToTensor(),
         RandomHorizontalFlip(0.5)])
    val_transforms = Compose([CocoDetectProcessor(), ToTensor()])

    # COCO 2017 train/val datasets with their annotation files.
    train_set = CocoDetection(
        os.path.join(args.data, 'train2017'),
        os.path.join(args.data, 'annotations', 'instances_train2017.json'),
        train_transforms)
    val_set = CocoDetection(
        os.path.join(args.data, 'val2017'),
        os.path.join(args.data, 'annotations', 'instances_val2017.json'),
        val_transforms)

    train_set = coco_remove_images_without_annotations(train_set)

    # Samplers: shuffled training batches, sequential evaluation.
    train_sampler = torch.utils.data.RandomSampler(train_set)
    test_sampler = torch.utils.data.SequentialSampler(val_set)
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler,
                                                        args.batch_size,
                                                        drop_last=True)

    # DataLoaders; custom collate_fn handles variable image resolutions.
    # cannot increase batch size till we sort the resolutions
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_sampler=train_batch_sampler,
        num_workers=args.workers,
        collate_fn=collate_fn)

    test_loader = torch.utils.data.DataLoader(val_set,
                                              batch_size=1,
                                              sampler=test_sampler,
                                              num_workers=args.workers,
                                              collate_fn=collate_fn)

    # Instantiate the model from torchvision or the local model zoo.
    # NOTE(review): if args.arch is in neither list, `model` is unbound and
    # the next line raises NameError — consider an explicit error message.
    if args.arch in model_names:
        model = models.__dict__[args.arch](pretrained=False)
    elif args.arch in local_model_names:
        model = local_models.__dict__[args.arch](pretrained=False)

    model.to(device)

    # Optimize only the trainable parameters.
    params = [p for p in model.parameters() if p.requires_grad]

    optimizer = torch.optim.SGD(params,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.fp16:
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level=args.opt_level,
            keep_batchnorm_fp32=args.keep_batchnorm_fp32,
            loss_scale="dynamic"
            if args.dynamic_loss_scale else args.static_loss_scale)

    # NOTE(review): this half_function patch is applied even when args.fp16
    # is False — confirm it is intended outside mixed-precision runs.
    model.roi_heads.box_roi_pool.forward = \
    amp.half_function(model.roi_heads.box_roi_pool.forward)

    if args.distributed:
        model = DDP(model)

    #wandb.watch(model)

    # Training loop: one batch_loop call per epoch.

    for epoch in range(10):

        # train one epoch
        batch_loop(model, optimizer, train_loader, device, epoch, args.fp16)
        return outputs[0], dropout_mask, residual_mask

    @staticmethod
    @custom_bwd
    def backward(ctx, *grad_o):
        """Backward via the fused GeLU + dropout + residual-add kernel."""
        grads = fused_mlp_gelu_dropout_add.backward(ctx.p, ctx.r_p, grad_o[0],
                                                    ctx.outputs,
                                                    ctx.saved_tensors)
        # Release the cached forward outputs as soon as they are consumed.
        del ctx.outputs
        # No gradients for the two dropout probabilities (p, r_p).
        return (None, None, *grads)


# Expose the fused op only when its CUDA extension is available.
mlp_gelu_dropout_add_function = (
    half_function(MlpGeLUDropoutAddFunction.apply)
    if fused_mlp_gelu_dropout_add else None)


class SwishFunction(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, inp):
        """Apply the fused CUDA SiLU/Swish kernel; cache input for backward."""
        ctx.save_for_backward(inp)
        result = silu_cuda.forward(inp)
        return result

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        inp, = ctx.saved_tensors
Beispiel #9
0
        p = ctx.p
        if not ctx.recompute:
            grads = fused_mlp_relu.backward(p, grad_o[0], ctx.outputs,
                                            ctx.saved_tensors)
            del ctx.outputs
        else:
            grads = fused_mlp_relu.backward_recompute(p, grad_o[0],
                                                      ctx.dropout_mask,
                                                      ctx.saved_tensors)
            del ctx.dropout_mask

        return (None, None, *grads)


# Expose the fused op only when its CUDA extension is available.
mlp_relu_function = (half_function(MlpReluFunction.apply)
                     if fused_mlp_relu else None)


class MlpSiluFunction(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, p, recompute, *args):
        # Run the fused SiLU-MLP kernel; p is the dropout probability.
        output = fused_mlp_silu.forward(p, args)
        ctx.save_for_backward(*args)
        # Keep the full kernel output tuple for the backward pass.
        ctx.outputs = output
        # NOTE(review): `recompute` is accepted but never stored on ctx, and
        # `dropout_mask` is extracted but unused — confirm whether backward
        # expects ctx.recompute / ctx.dropout_mask (cf. the recompute branch
        # in the ReLU variant's backward).
        dropout_mask = output[-1]
        ctx.p = p
        return output[0]
Beispiel #10
0
 def half_function(fn):
     """Wrap *fn* with apex.amp so it is executed in half precision."""
     wrapped = amp.half_function(fn)
     return wrapped