예제 #1
0
    def __init__(self, cfg):
        """Build the backbone and the semantic-segmentation head from ``cfg``.

        Args:
            cfg: configuration object. ``cfg.MODEL.DEVICE`` selects the torch
                device; the whole object is forwarded to ``build_backbone``
                and ``build_sem_seg_head``.
        """
        super().__init__()

        self.device = torch.device(cfg.MODEL.DEVICE)

        # The head is sized from the output shape reported by the backbone.
        self.backbone = build_backbone(cfg)
        backbone_shape = self.backbone.output_shape()
        self.sem_seg_head = build_sem_seg_head(cfg, backbone_shape)

        # Relocate the freshly built model onto the configured device.
        self.to(self.device)
예제 #2
0
    def __init__(self, backbone='resnet', output_stride=8, num_classes=10,
                 sync_bn=False, freeze_bn=False):
        """Create the backbone, the ASPP module and the input preprocessing.

        Args:
            backbone: backbone identifier forwarded to the builders.
            output_stride: output stride forwarded to the builders.
            num_classes: accepted but not used in this constructor.
            sync_bn: accepted but not used; ``nn.BatchNorm2d`` is always used.
            freeze_bn: accepted but not used in this constructor.
        """
        super(DeepLab, self).__init__()

        norm_layer = nn.BatchNorm2d
        self.backbone = build_backbone(backbone, output_stride, norm_layer)
        self.aspp = build_aspp(backbone, output_stride, norm_layer)

        # Preprocessing: resize to 512x512, convert to a tensor, then
        # normalise each channel with the fixed statistics below.
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        pipeline = [
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]
        self.transform = transforms.Compose(pipeline)
예제 #3
0
    def __init__(self,
                 backbone='resnet',
                 output_stride=16,
                 num_classes=21,
                 sync_bn=True,
                 freeze_bn=False):
        """Create the backbone, ASPP module and decoder of the network.

        Args:
            backbone: backbone identifier forwarded to the builders; the
                value ``'drn'`` forces ``output_stride`` to 8.
            output_stride: output stride forwarded to the backbone and ASPP
                builders.
            num_classes: number of classes forwarded to the decoder builder.
            sync_bn: when truthy, ``SynchronizedBatchNorm2d`` is used as the
                normalisation layer; otherwise ``nn.BatchNorm2d``.
            freeze_bn: stored on the instance as ``self.freeze_bn``.
        """
        super(DeepLab, self).__init__()
        if backbone == 'drn':
            output_stride = 8

        # Plain truthiness test instead of comparing against ``True``.
        if sync_bn:
            BatchNorm = SynchronizedBatchNorm2d
        else:
            BatchNorm = nn.BatchNorm2d

        self.backbone = build_backbone(backbone, output_stride, BatchNorm)
        self.aspp = build_aspp(backbone, output_stride, BatchNorm)
        self.decoder = build_decoder(num_classes, backbone, BatchNorm)

        self.freeze_bn = freeze_bn
예제 #4
0
    def __init__(self,
                 nclass=21,
                 output_stride=None,
                 backbone='resnet101',
                 norm_layer=None,
                 loss_fn=None,
                 detach_backbone=True):
        """Create the backbone, ASPP, decoder and variable-mapping modules.

        Args:
            nclass: number of classes; stored as ``self.nclass`` and passed
                to the decoder.
            output_stride: output stride for the backbone and ASPP; replaced
                by 8 when ``backbone`` is ``'drn'``.
            backbone: backbone identifier forwarded to every sub-builder.
            norm_layer: normalisation layer forwarded to every sub-builder.
            loss_fn: stored as ``self.loss_fn``.
            detach_backbone: forwarded to ``build_backbone`` as ``detach``.
        """
        super(DeepLabV3P, self).__init__()

        # The 'drn' backbone always runs with an output stride of 8.
        stride = 8 if backbone == 'drn' else output_stride

        self.loss_fn = loss_fn
        self._up_kwargs = up_kwargs  # `up_kwargs` is resolved outside this method
        self.nclass = nclass

        self.backbone = build_backbone(
            backbone, stride, norm_layer, detach=detach_backbone)
        self.aspp = ASPP(backbone, stride, norm_layer)
        self.decoder = Decoder(nclass, backbone, norm_layer)
        self.varmaping = VarMapping(300, 128, BatchNorm=norm_layer)
    def __init__(self, cfg):
        """Assemble the network layers from ``cfg`` and initialise weights.

        Reads ``MODEL_ASPP_OUTDIM``, ``MODEL_OUTPUT_STRIDE``,
        ``TRAIN_BN_MOM``, ``MODEL_SHORTCUT_DIM``, ``MODEL_SHORTCUT_KERNEL``,
        ``num_classes`` and ``MODEL_BACKBONE`` from ``cfg``.
        """
        super(DeeplabV3plus, self).__init__()
        self.backbone = None
        self.backbone_layers = None

        aspp_dim = cfg.MODEL_ASPP_OUTDIM
        shortcut_dim = cfg.MODEL_SHORTCUT_DIM
        shortcut_kernel = cfg.MODEL_SHORTCUT_KERNEL
        stride = cfg.MODEL_OUTPUT_STRIDE

        self.aspp = ASPP(dim_in=2048,
                         dim_out=aspp_dim,
                         rate=16 // stride,
                         bn_mom=cfg.TRAIN_BN_MOM)
        self.dropout1 = nn.Dropout(0.5)
        self.upsample_sub = nn.Upsample(scale_factor=stride // 4)

        # Shortcut branch: a 256-channel input is projected to `shortcut_dim`
        # channels. Plain BatchNorm2d is used here; the SynchronizedBatchNorm2d
        # variant with a custom momentum is disabled.
        self.shortcut_conv = nn.Sequential(
            nn.Conv2d(256,
                      shortcut_dim,
                      kernel_size=shortcut_kernel,
                      stride=1,
                      padding=shortcut_kernel // 2,
                      bias=True),
            nn.BatchNorm2d(shortcut_dim),
            nn.ReLU(inplace=True),
        )

        def fusion_stage(in_channels, drop_prob):
            # One 3x3 conv -> BN -> ReLU -> dropout stage producing `aspp_dim`
            # channels.
            return [
                nn.Conv2d(in_channels,
                          aspp_dim,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=True),
                nn.BatchNorm2d(aspp_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(drop_prob),
            ]

        # Two stages applied to the concatenated ASPP + shortcut channels.
        fusion_layers = fusion_stage(aspp_dim + shortcut_dim, 0.5)
        fusion_layers += fusion_stage(aspp_dim, 0.1)
        self.cat_conv = nn.Sequential(*fusion_layers)

        # 1x1 classifier producing one channel per class.
        self.cls_conv = nn.Conv2d(aspp_dim,
                                  cfg.num_classes,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
        self.backbone = build_backbone(cfg.MODEL_BACKBONE, os=stride)
        self.pointhead = PointRendSemSegHead()

        # NOTE(review): this loop runs after the backbone and point head are
        # attached, so (assuming this class is an nn.Module) their Conv2d and
        # BatchNorm2d layers are re-initialised as well - confirm this is
        # intended if the backbone ships with pretrained weights.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
예제 #6
0
 def __init__(self, setting_dict):
     """Build the detector's backbone and box head from ``setting_dict``.

     Args:
         setting_dict: mapping stored as ``self.setting_dict``; its
             ``"backbone"`` and ``"boxhead"`` entries are handed to
             ``build_backbone`` and ``build_boxhead`` respectively.
     """
     super(SSDdetector, self).__init__()
     self.setting_dict = setting_dict

     backbone_settings = self.setting_dict["backbone"]
     self.backbone = build_backbone(backbone_settings)

     boxhead_settings = self.setting_dict["boxhead"]
     self.box_head = build_boxhead(boxhead_settings)
# Feature channels passed to `build_backbone` for each supported ResNet
# variant; also used as the list of recognised architecture names.
_RESNET_NUM_CHANNELS = {
    'ResNet10': 512,
    'ResNet18': 512,
    'ResNet34': 512,
    'ResNet50': 2048,
    'ResNet101': 2048,
    'ResNet152': 2048,
}


def _build_resnet(arch, suffix, role):
    """Instantiate ``rn.<arch><suffix>`` or raise ValueError for an unknown arch."""
    if arch not in _RESNET_NUM_CHANNELS:
        raise ValueError(
            "Unrecognized {} architecture for the backbone {}".format(arch, role))
    # Resolved lazily so only the requested class has to exist in `rn`.
    return getattr(rn, arch + suffix)()


def _load_checkpoint(path, opts, loading_msg, missing_msg):
    """Load the checkpoint at ``path`` onto the device selected by ``opts``."""
    if not os.path.isfile(path):
        raise FileNotFoundError(missing_msg.format(path))

    print(loading_msg.format(path))
    # NOTE(review): torch.load unpickles arbitrary objects (the resume
    # checkpoint stores a whole meter object under 'total_time'); only load
    # files from trusted sources.
    if opts.cpu:
        return torch.load(path, map_location='cpu')
    return torch.load(path,
                      map_location=lambda storage, loc: storage.cuda(opts.gpu))


def _split_hms(seconds):
    """Split a duration in seconds into integer (hours, minutes, seconds)."""
    remainder = seconds % 3600
    return int(seconds / 3600), int(remainder / 60), int(remainder % 60)


def main():
    """Build a DETR detector on top of two restored ResNet models and train it.

    Steps, in order: parse the command line, set up optional distributed
    execution, build the pulse counter and the feature predictor and restore
    their weights from ``args.counter`` / ``args.predictor``, assemble the
    backbone, transformer, DETR model, matcher, criterion, optimizer and
    learning-rate scheduler, optionally resume from ``args.resume``, open the
    HDF5 data sets, and finally either run the model (``args.run``), plot the
    stored training history (``args.plot_training_history``) or train from
    ``args.start_epoch`` to ``args.epochs``, saving a checkpoint after every
    epoch on local rank 0.

    Raises:
        Exception: if no data set, counter path or predictor path is given.
        ValueError: if a requested ResNet architecture is not supported.
        FileNotFoundError: if a checkpoint path does not point to a file.
    """
    global best_precision, args
    best_precision = 0
    args = parse()

    if not args.data:
        raise Exception("error: No data set provided")

    # Distributed execution is enabled when WORLD_SIZE announces more than
    # one process.
    args.distributed = int(os.environ.get('WORLD_SIZE', 1)) > 1
    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank

        if not args.cpu:
            torch.cuda.set_device(args.gpu)

        torch.distributed.init_process_group(backend='gloo',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    args.total_batch_size = args.world_size * args.batch_size

    # Set the device
    device = torch.device('cpu' if args.cpu else 'cuda:' + str(args.gpu))

    # Progress messages are printed by local rank 0 only, and only on request.
    verbose_master = args.local_rank == 0 and args.verbose

    def wrap_ddp(module):
        # Wrap `module` in DistributedDataParallel for the active device.
        if args.cpu:
            return DDP(module)
        return DDP(module, device_ids=[args.gpu], output_device=args.gpu)

    #######################################################################
    #   Start DETR construction
    #######################################################################

    # create backbone pulse counter (test runs always use the smallest net)
    if args.test:
        args.pulse_counter_arch = 'ResNet10'

    if verbose_master:
        print("=> creating backbone pulse counter '{}'".format(args.pulse_counter_arch))

    backbone_pulse_counter = _build_resnet(
        args.pulse_counter_arch, '_Counter', 'pulse counter').to(device)

    # create backbone feature predictor
    if args.test:
        args.feature_predictor_arch = 'ResNet10'

    if verbose_master:
        print("=> creating backbone feature predictor '{}'".format(args.feature_predictor_arch))

    backbone_feature_predictor = _build_resnet(
        args.feature_predictor_arch, '_Custom', 'feature predictor').to(device)

    # For distributed training, wrap the models with DistributedDataParallel.
    if args.distributed:
        backbone_pulse_counter = wrap_ddp(backbone_pulse_counter)
        backbone_feature_predictor = wrap_ddp(backbone_feature_predictor)

        if args.verbose:
            print('Since we are in a distributed setting the backbone componets are replicated here in local rank {}'
                                    .format(args.local_rank))

    # bring counter from a checkpoint
    if not args.counter:
        raise Exception("error: No counter path provided")

    checkpoint = _load_checkpoint(args.counter, args,
                                  "=> loading backbone pulse counter '{}'",
                                  "=> no counter found at '{}'")
    backbone_pulse_counter.load_state_dict(checkpoint['state_dict'])
    print("=> loaded counter '{}' (epoch {})"
                    .format(args.counter, checkpoint['epoch']))
    print("Counter best precision saved was {}" .format(checkpoint['best_error']))
    del checkpoint  # do not keep the loaded tensors alive

    # bring predictor from a checkpoint
    if not args.predictor:
        raise Exception("error: No predictor path provided")

    checkpoint = _load_checkpoint(args.predictor, args,
                                  "=> loading backbone feature predictor '{}'",
                                  "=> no predictor found at '{}'")
    backbone_feature_predictor.load_state_dict(checkpoint['state_dict'])
    print("=> loaded predictor '{}' (epoch {})"
                    .format(args.predictor, checkpoint['epoch']))
    print("Predictor best precision saved was {}" .format(checkpoint['best_error']))
    del checkpoint

    # create backbone
    if verbose_master:
        print("=> creating backbone")

    # The architecture name was validated when the predictor was built.
    backbone = build_backbone(
        pulse_counter=backbone_pulse_counter,
        feature_predictor=backbone_feature_predictor,
        num_channels=_RESNET_NUM_CHANNELS[args.feature_predictor_arch])
    backbone = backbone.to(device)

    # create DETR transformer
    if verbose_master:
        print("=> creating transformer")

    if args.test:
        # Small transformer for test runs.
        args.transformer_hidden_dim = 64
        args.transformer_num_heads = 2
        args.transformer_dim_feedforward = 256
        args.transformer_num_enc_layers = 2
        args.transformer_num_dec_layers = 2

    args.transformer_pre_norm = True
    transformer = build_transformer(hidden_dim=args.transformer_hidden_dim,
                                    dropout=args.transformer_dropout,
                                    nheads=args.transformer_num_heads,
                                    dim_feedforward=args.transformer_dim_feedforward,
                                    enc_layers=args.transformer_num_enc_layers,
                                    dec_layers=args.transformer_num_dec_layers,
                                    pre_norm=args.transformer_pre_norm)

    # create DETR in itself
    if verbose_master:
        print("=> creating DETR")

    detr = DT.DETR(backbone=backbone,
                   transformer=transformer,
                   num_classes=args.num_classes,
                   num_queries=args.num_queries)
    detr = detr.to(device)

    # For distributed training, wrap the model with DistributedDataParallel.
    if args.distributed:
        detr = wrap_ddp(detr)

        if args.verbose:
            print('Since we are in a distributed setting DETR model is replicated here in local rank {}'
                                    .format(args.local_rank))

    # Set matcher
    if verbose_master:
        print("=> set Hungarian Matcher")

    matcher = mtchr.HungarianMatcher(cost_class=args.cost_class,
                                     cost_bsegment=args.cost_bsegment,
                                     cost_giou=args.cost_giou)

    # Set criterion
    if verbose_master:
        print("=> set criterion for the loss")

    weight_dict = {'loss_ce': args.loss_ce,
                   'loss_bsegment': args.loss_bsegment,
                   'loss_giou': args.loss_giou}

    losses = ['labels', 'segments', 'cardinality']

    criterion = DT.SetCriterion(num_classes=args.num_classes,
                                matcher=matcher,
                                weight_dict=weight_dict,
                                eos_coef=args.eos_coef,
                                losses=losses)
    criterion = criterion.to(device)

    # Set optimizer
    optimizer = Model_Util.get_DETR_optimizer(detr, args)
    if verbose_master:
        print('Optimizer used for this run is {}'.format(args.optimizer))

    # Set learning rate scheduler
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lrsp,
                                                   args.lrm)

    total_time = Utilities.AverageMeter()
    loss_history = []
    precision_history = []

    # Optionally resume from a checkpoint
    if args.resume:
        checkpoint = _load_checkpoint(args.resume, args,
                                      "=> loading checkpoint '{}'",
                                      "=> no checkpoint found at '{}'")
        loss_history = checkpoint['loss_history']
        precision_history = checkpoint['precision_history']
        args.start_epoch = checkpoint['epoch']
        best_precision = checkpoint['best_precision']
        detr.load_state_dict(checkpoint['state_dict'])
        criterion.load_state_dict(checkpoint['criterion'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        total_time = checkpoint['total_time']
        print("=> loaded checkpoint '{}' (epoch {})"
                        .format(args.resume, checkpoint['epoch']))
        del checkpoint

    # Data loading code: one path means <path>/train and <path>/val,
    # otherwise the first two paths are the train and validation directories.
    if len(args.data) == 1:
        traindir = os.path.join(args.data[0], 'train')
        valdir = os.path.join(args.data[0], 'val')
    else:
        traindir = args.data[0]
        valdir = args.data[1]

    sampling_rate = 10000                   # number of samples per second of the signals in the dataset
    window = 0.5                            # time window in seconds
    training_length = 20                    # time of a complete training signal for a given concentration and duration
    validation_length = 10                  # time of a complete validation signal for a given concentration and duration
    if args.test:
        train_name, val_name = 'train_toy.h5', 'validation_toy.h5'
        number_of_concentrations = 2        # number of different concentrations in the dataset
        number_of_durations = 2             # number of different translocation durations per concentration
        number_of_diameters = 4             # number of different diameters in the dataset
    else:
        train_name, val_name = 'train.h5', 'validation.h5'
        number_of_concentrations = 20       # number of different concentrations in the dataset
        number_of_durations = 5             # number of different translocation durations per concentration
        number_of_diameters = 15            # number of different diameters in the dataset

    train_path = os.path.join(traindir, train_name)
    val_path = os.path.join(valdir, val_name)

    # Keep both HDF5 files open only while the loaders are in use so they
    # are closed on every exit path.
    with h5py.File(train_path, 'r') as training_f, \
            h5py.File(val_path, 'r') as validation_f:
        # Training Artificial Data Loader
        TADL = Artificial_DataLoader(args.world_size, args.local_rank, device, training_f, sampling_rate,
                                     number_of_concentrations, number_of_durations, number_of_diameters,
                                     window, training_length, args.batch_size)

        # Validating Artificial Data Loader
        VADL = Artificial_DataLoader(args.world_size, args.local_rank, device, validation_f, sampling_rate,
                                     number_of_concentrations, number_of_durations, number_of_diameters,
                                     window, validation_length, args.batch_size)

        if args.verbose:
            print('From rank {} training shard size is {}'. format(args.local_rank, TADL.get_number_of_avail_windows()))
            print('From rank {} validation shard size is {}'. format(args.local_rank, VADL.get_number_of_avail_windows()))

        if args.run:
            arguments = {'model': detr,
                         'device': device,
                         'epoch': 0,
                         'VADL': VADL}

            if args.local_rank == 0:
                run_model(args, arguments)

            return

        # NOTE: the `args.statistics` and `args.evaluate` branches that used
        # to live here were commented out (they used a `model` name that is
        # not defined in this function) and have been dropped.

        if args.plot_training_history and args.local_rank == 0:
            Model_Util.plot_detector_stats(loss_history, precision_history)
            print('The total training time was {} hours {} minutes and {} seconds'
                  .format(*_split_hms(total_time.sum)))
            print('while the average time during one epoch of training was {} hours {} minutes and {} seconds'
                  .format(*_split_hms(total_time.avg)))
            return

        for epoch in range(args.start_epoch, args.epochs):

            arguments = {'detr': detr,
                         'criterion': criterion,
                         'optimizer': optimizer,
                         'device': device,
                         'epoch': epoch,
                         'TADL': TADL,
                         'VADL': VADL,
                         'loss_history': loss_history,
                         'precision_history': precision_history}

            # train for one epoch
            epoch_time, avg_batch_time = train(args, arguments)
            total_time.update(epoch_time)

            # validate every val_freq epochs (never on epoch 0)
            if epoch%args.val_freq == 0 and epoch != 0:
                # evaluate on validation set
                print("\nValidating ...\nComputing mean average precision (mAP) for epoch {}" .format(epoch))
                precision = validate(args, arguments)
            else:
                # No validation this epoch: reuse the best value seen so far.
                precision = best_precision

            lr_scheduler.step()
            # remember the best detr and save checkpoint
            if args.local_rank == 0:
                if epoch%args.val_freq == 0:
                    print('From validation we have precision is {} while best_precision is {}'.format(precision, best_precision))

                is_best = precision > best_precision
                best_precision = max(precision, best_precision)
                Model_Util.save_checkpoint({
                        'arch': 'DETR_' + args.feature_predictor_arch,
                        'epoch': epoch + 1,
                        'best_precision': best_precision,
                        'state_dict': detr.state_dict(),
                        'criterion': criterion.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'loss_history': loss_history,
                        'precision_history': precision_history,
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'total_time': total_time
                }, is_best)

                print('##Detector precision {0}\n'
                      '##Perf {1}'.format(
                      precision,
                      args.total_batch_size / avg_batch_time))