Пример #1
0
    def test_check_size_average(self):
        """Summed focal loss must equal the averaged loss times the batch size."""
        criterion = losses.FocalLoss(size_average=True)
        loss_mean = criterion(self.outputs, self.targets)

        criterion = losses.FocalLoss(size_average=False)
        loss_sum = criterion(self.outputs, self.targets)

        # Compare with a tolerance: scaling a floating-point mean back up by the
        # batch size is not guaranteed to reproduce the float sum bit-for-bit.
        self.assertTrue(torch.allclose(loss_mean * len(self.targets), loss_sum))
    def __init__(self, labelfile, imagesdir):
        """Prepare everything needed to train DBFace.

        Builds the DBFace model (weights initialised, wrapped in
        ``nn.DataParallel`` over ``self.gpus``, moved to the first of those
        GPUs and put in train mode), the focal / GIoU / Wing losses, the
        ``LDataset`` training set with a shuffling ``DataLoader``, and an Adam
        optimizer over the model parameters.

        Args:
            labelfile: label file passed through to ``LDataset``.
            imagesdir: image directory passed through to ``LDataset``.
        """
        # Input resolution and per-channel normalisation statistics.
        self.width = 800
        self.height = 800
        self.mean = [0.408, 0.447, 0.47]
        self.std = [0.289, 0.274, 0.278]

        # Optimisation hyper-parameters.
        self.batch_size = 18
        self.lr = 1e-4

        # Devices: the model is replicated over ``gpus``; the first id is the master.
        self.gpus = [2]  # e.g. [0, 1, 2, 3] for multi-GPU training
        self.gpu_master = self.gpus[0]

        net = DBFace(has_landmark=True, wide=64, has_ext=True, upmode="UCBA")
        net.init_weights()
        self.model = nn.DataParallel(net, device_ids=self.gpus)
        self.model.cuda(device=self.gpu_master)
        self.model.train()

        # Loss terms.
        self.focal_loss = losses.FocalLoss()
        self.giou_loss = losses.GIoULoss()
        self.landmark_loss = losses.WingLoss(w=2)

        # Data pipeline.
        self.train_dataset = LDataset(
            labelfile,
            imagesdir,
            mean=self.mean,
            std=self.std,
            width=self.width,
            height=self.height,
        )
        self.train_loader = DataLoader(
            dataset=self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=24,
        )

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

        # Training-loop bookkeeping.
        self.per_epoch_batchs = len(self.train_loader)
        self.iter = 0
        self.epochs = 150
Пример #3
0
    def __init__(self, num_classes, block, layers, groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 dropout1=0.25, dropout2=0.25, alpha=0.25, gamma=2.0,
                 loss_with_no_bboxes=False, no_bboxes_alpha=0.5, no_bboxes_gamma=2.0):
        """Build a ResNet/ResNeXt-backed RetinaNet whose FPN consumes stages C2-C5.

        Customised by Yu Han Huang: ResNeXt grouping/width, optional dilation in
        stages 2-4, a C2 input to the FPN, classifier dropout and extra
        focal-loss options.

        Args:
            num_classes: number of classes predicted by ``ClassificationModel``.
            block: residual block class; must be ``BasicBlock`` or ``Bottleneck``.
            layers: number of blocks in each of the four residual stages.
            groups: stored as ``self.groups`` (presumably read by ``_make_layer``).
            width_per_group: stored as ``self.base_width``.
            replace_stride_with_dilation: three flags for stages 2, 3 and 4, each
                passed as ``dilate=`` to ``_make_layer``; ``None`` means all False.
            dropout1, dropout2: forwarded to ``ClassificationModel``.
            alpha, gamma, loss_with_no_bboxes, no_bboxes_alpha, no_bboxes_gamma:
                forwarded to ``losses.FocalLoss``.

        Raises:
            ValueError: if ``replace_stride_with_dilation`` does not have exactly
                three entries, or ``block`` is neither ``BasicBlock`` nor
                ``Bottleneck``.
        """
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # Each flag says whether that stage's 2x2 stride is replaced by a
            # dilated convolution.
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])

        # FPN input widths (C2..C5) come from the final conv of the last block in
        # each stage: conv2 for BasicBlock, conv3 for Bottleneck.
        if block == BasicBlock:
            last_conv = 'conv2'
        elif block == Bottleneck:
            last_conv = 'conv3'
        else:
            raise ValueError('block must be BasicBlock or Bottleneck, got {}'.format(block))
        stages = (self.layer1, self.layer2, self.layer3, self.layer4)
        fpn_sizes = [getattr(stage[n_blocks - 1], last_conv).out_channels
                     for stage, n_blocks in zip(stages, layers)]

        self.fpn = PyramidFeatures(fpn_sizes[0], fpn_sizes[1], fpn_sizes[2], fpn_sizes[3])
        self.regressionModel = RegressionModel(256)
        self.classificationModel = ClassificationModel(256, num_classes=num_classes, dropout1=dropout1, dropout2=dropout2)
        self.anchors = Anchors()
        self.regressBoxes = BBoxTransform()
        self.clipBoxes = ClipBoxes()
        self.focalLoss = losses.FocalLoss(alpha=alpha, gamma=gamma, loss_with_no_bboxes=loss_with_no_bboxes, no_bboxes_alpha=no_bboxes_alpha, no_bboxes_gamma=no_bboxes_gamma)

        # He-style normal init for convolutions (std = sqrt(2 / fan_out)); identity BatchNorm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        # Bias the classifier so every anchor starts with foreground probability `prior`.
        prior = 0.01
        self.classificationModel.output.weight.data.fill_(0)
        self.classificationModel.output.bias.data.fill_(-math.log((1.0-prior)/prior))
        self.regressionModel.output.weight.data.fill_(0)
        self.regressionModel.output.bias.data.fill_(0)
        self.freeze_bn()
    def __init__(self, num_classes, block, pretrained=False, phi=0):
        """Build an EfficientDet-D``phi`` detector on an EfficientNet-B``phi`` backbone.

        Args:
            num_classes: number of classes predicted by ``ClassificationModel``.
            block: accepted for interface compatibility; not used here.
            pretrained: accepted for interface compatibility. NOTE(review): the
                backbone is always built with ``EfficientNet.from_pretrained``
                regardless of this flag — confirm whether that is intended.
            phi: compound-scaling coefficient; selects the backbone variant and
                indexes ``w_bifpn`` for the BiFPN/head width.
        """
        self.inplanes = w_bifpn[phi]
        super(EfficientDet, self).__init__()
        efficientnet = EfficientNet.from_pretrained(f'efficientnet-b{phi}')

        # Keep backbone blocks up to and including the fourth stride-2 block, and
        # record the output channels of every stride-2 block seen on the way.
        blocks = []
        fpn_sizes = []
        for mb_block in efficientnet._blocks:
            blocks.append(mb_block)
            # Normalise to a tuple: the stride may be stored as a list or a tuple.
            if tuple(mb_block._depthwise_conv.stride) == (2, 2):
                fpn_sizes.append(mb_block._project_conv.out_channels)
                if len(fpn_sizes) >= 4:
                    break

        self.efficientnet = nn.Sequential(efficientnet._conv_stem,
                                          efficientnet._bn0, *blocks)
        num_layers = min(phi + 2, 8)
        # The BiFPN consumes the last three recorded stride-2 feature widths.
        self.fpn = BiFPN(fpn_sizes[1:],
                         feature_size=w_bifpn[phi],
                         num_layers=num_layers)

        d_class = 3 + (phi // 3)
        self.regressionModel = RegressionModel(w_bifpn[phi],
                                               feature_size=w_bifpn[phi],
                                               d_class=d_class)
        self.classificationModel = ClassificationModel(
            w_bifpn[phi],
            feature_size=w_bifpn[phi],
            d_class=d_class,
            num_classes=num_classes)

        self.anchors = Anchors()

        self.regressBoxes = BBoxTransform()

        self.clipBoxes = ClipBoxes()

        # NOTE(review): `.cuda()` makes construction require a GPU.
        self.focalLoss = losses.FocalLoss().cuda()

        # He-style normal init for convolutions; identity BatchNorm. This also
        # re-initialises the convolutions of the pretrained backbone blocks.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        # Bias the classifier so every anchor starts with foreground probability `prior`.
        prior = 0.01

        self.classificationModel.output.weight.data.fill_(0)
        self.classificationModel.output.bias.data.fill_(-math.log(
            (1.0 - prior) / prior))

        self.regressionModel.output.weight.data.fill_(0)
        self.regressionModel.output.bias.data.fill_(0)

        self.freeze_bn()
Пример #5
0
    def __init__(self, num_classes, block, layers):
        """Build a ResNet RetinaNet with an additional re-identification head.

        Args:
            num_classes: stored as ``self.num_classes`` and forwarded to both the
                classification and re-id heads.
            block: residual block class; must be ``BasicBlock`` or ``Bottleneck``.
            layers: number of blocks in each of the four residual stages.

        Raises:
            ValueError: if ``block`` is neither ``BasicBlock`` nor ``Bottleneck``.
        """
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # FPN input widths (C3..C5) from the final conv of each stage's last block.
        if block == BasicBlock:
            fpn_sizes = [self.layer2[layers[1]-1].conv2.out_channels, self.layer3[layers[2]-1].conv2.out_channels, self.layer4[layers[3]-1].conv2.out_channels]
        elif block == Bottleneck:
            fpn_sizes = [self.layer2[layers[1]-1].conv3.out_channels, self.layer3[layers[2]-1].conv3.out_channels, self.layer4[layers[3]-1].conv3.out_channels]
        else:
            raise ValueError('block must be BasicBlock or Bottleneck, got {}'.format(block))

        self.fpn = PyramidFeatures(fpn_sizes[0], fpn_sizes[1], fpn_sizes[2])
        self.num_classes = num_classes

        # NOTE(review): heads take 512 input features — presumably this
        # PyramidFeatures variant emits 512-channel maps; confirm.
        self.regressionModel = RegressionModel(512)
        self.classificationModel = ClassificationModel(512, num_classes=num_classes)
        self.reidModel = ReidModel(512, num_classes=num_classes)

        self.anchors = Anchors()
        self.regressBoxes = BBoxTransform()
        self.clipBoxes = ClipBoxes()

        self.focalLoss = losses.FocalLoss()
        self.reidfocalLoss = losses.FocalLossReid()

        # He-style normal init for convolutions (std = sqrt(2 / fan_out)); identity BatchNorm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        # Bias the classification and re-id outputs so they start at probability `prior`.
        prior = 0.01

        self.classificationModel.output.weight.data.fill_(0)
        self.classificationModel.output.bias.data.fill_(-math.log((1.0-prior)/prior))

        self.reidModel.output.weight.data.fill_(0)
        self.reidModel.output.bias.data.fill_(-math.log((1.0-prior)/prior))

        self.regressionModel.output.weight.data.fill_(0)
        self.regressionModel.output.bias.data.fill_(0)

        self.freeze_bn()
Пример #6
0
    def __init__(self, num_classes, block, layers, normalization='batch_norm'):
        """Build a ResNet RetinaNet whose stem normalisation is selectable.

        Args:
            num_classes: number of classes predicted by ``ClassificationModel``.
            block: residual block class; must be ``BasicBlock`` or ``Bottleneck``.
            layers: number of blocks in each of the four residual stages.
            normalization: ``'batch_norm'`` gives a ``BatchNorm2d`` stem; any
                other value gives ``GroupNorm(num_groups=8)``. Stored as
                ``self.normalization``.

        Raises:
            ValueError: if ``block`` is neither ``BasicBlock`` nor ``Bottleneck``.
        """
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.normalization = normalization

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if normalization == 'batch_norm':
            self.bn1 = nn.BatchNorm2d(64)
        else:
            self.bn1 = nn.GroupNorm(num_groups=8, num_channels=64)  # Note: Does not use preloaded imagenet weights, as BatchNorm does
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # FPN input widths (C3..C5) from the final conv of each stage's last block.
        if block == BasicBlock:
            fpn_sizes = [self.layer2[layers[1]-1].conv2.out_channels, self.layer3[layers[2]-1].conv2.out_channels, self.layer4[layers[3]-1].conv2.out_channels]
        elif block == Bottleneck:
            fpn_sizes = [self.layer2[layers[1]-1].conv3.out_channels, self.layer3[layers[2]-1].conv3.out_channels, self.layer4[layers[3]-1].conv3.out_channels]
        else:
            raise ValueError('block must be BasicBlock or Bottleneck, got {}'.format(block))

        self.fpn = PyramidFeatures(fpn_sizes[0], fpn_sizes[1], fpn_sizes[2])

        self.regressionModel = RegressionModel(256)
        self.classificationModel = ClassificationModel(256, num_classes=num_classes)

        self.anchors = Anchors()

        self.regressBoxes = BBoxTransform()

        self.clipBoxes = ClipBoxes()

        self.focalLoss = losses.FocalLoss()

        # He-style normal init for convolutions; identity affine for both
        # BatchNorm and GroupNorm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        # Bias the classifier so every anchor starts with foreground probability `prior`.
        prior = 0.01

        self.classificationModel.output.weight.data.fill_(0)
        self.classificationModel.output.bias.data.fill_(-math.log((1.0-prior)/prior))

        self.regressionModel.output.weight.data.fill_(0)
        self.regressionModel.output.bias.data.fill_(0)

        self.freeze_bn()
Пример #7
0
    def __init__(self, num_classes, backbone_network, fpn_sizes):
        """RetinaNet detection heads on top of an externally supplied backbone.

        Args:
            num_classes (int): number of classes for the classification head.
            backbone_network: backbone stored as ``self.efficientnet``; presumably
                an EfficientNet ``nn.Module`` (TODO confirm — an earlier
                docstring described it as ``str``).
            fpn_sizes (list): channel counts of the three backbone feature maps
                fed to the FPN; only the first three entries are used. Values
                noted for this model: ``[112, 192, 1280]`` for B0,
                ``[160, 272, 448]`` for B4 (``[56, 160, 448]`` was also listed).
        """
        self.inplanes = 64
        super(RetinaNet, self).__init__()

        c3_channels, c4_channels, c5_channels = fpn_sizes[0], fpn_sizes[1], fpn_sizes[2]
        self.fpn = PyramidFeatures(c3_channels, c4_channels, c5_channels)

        self.regressionModel = RegressionModel(256)
        self.classificationModel = ClassificationModel(256, num_classes=num_classes)

        self.anchors = Anchors()
        self.regressBoxes = BBoxTransform()
        self.clipBoxes = ClipBoxes()
        self.focalLoss = losses.FocalLoss()

        # He-style init for convolutions, identity affine for BatchNorm.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

        # Start every anchor at foreground probability `prior`; regression starts at zero.
        prior = 0.01
        cls_bias = -math.log((1.0 - prior) / prior)
        cls_out = self.classificationModel.output
        cls_out.weight.data.fill_(0)
        cls_out.bias.data.fill_(cls_bias)

        reg_out = self.regressionModel.output
        reg_out.weight.data.fill_(0)
        reg_out.bias.data.fill_(0)

        self.freeze_bn()

        # Attached last on purpose: the init loop and freeze_bn() above then do
        # not touch the backbone's (presumably pretrained) weights.
        self.efficientnet = backbone_network
Пример #8
0
def load_loss(name: str, **kwargs) -> nn.Module:
    """Instantiate a loss module by its class name.

    Args:
        name: ``'CrossEntropyLoss'`` (from ``torch.nn``) or one of
            ``'FocalLoss'``, ``'ComboLoss'``, ``'JaccardLoss'`` (from the
            project ``losses`` module).
        **kwargs: forwarded unchanged to the loss constructor.

    Returns:
        The constructed loss module.

    Raises:
        ValueError: if ``name`` is not one of the supported names; the message
            includes the rejected name.
    """
    attributes: Tuple[str, ...] = (
        'CrossEntropyLoss', 'FocalLoss', 'ComboLoss', 'JaccardLoss'
    )
    if name not in attributes:
        raise ValueError(f'name must be in {attributes}, got {name!r}.')
    if name == 'CrossEntropyLoss':
        return nn.CrossEntropyLoss(**kwargs)
    # Every other supported name is a class defined in `losses`; resolve it
    # only when requested.
    return getattr(losses, name)(**kwargs)
Пример #9
0
    def __init__(self, num_classes, phi):
        """Build an EfficientDet-D``phi`` detector on a geffnet backbone.

        Args:
            num_classes: number of classes predicted by ``ClassificationModel``.
            phi: compound-scaling coefficient; selects the backbone from
                ``geffnets``, the width from ``feature_sizes`` and the number of
                stacked pyramid layers (``min(2 + phi, 8)``).
        """
        feature_size = feature_sizes[phi]
        super(EfficientDet, self).__init__()

        self.backbone = geffnets[phi](pretrained=True,
                                      drop_rate=0.25,
                                      drop_connect_rate=0.2)

        # Backbone feature sizes. NOTE(review): hard-coded, not derived from
        # `self.backbone` — confirm they hold for every `phi`.
        fpn_sizes = [40, 80, 192]

        # nn.ModuleList (not a plain list) registers each pyramid layer as a
        # submodule, so its weights reach parameters()/state_dict(), follow
        # .to()/.cuda() and respect train()/eval(). A plain list hides them from
        # the optimizer entirely.
        self.fpn = nn.ModuleList([
            PyramidFeatures(fpn_sizes, feature_size=feature_size,
                            index=index).cuda()
            for index in range(min(2 + phi, 8))
        ])

        self.regressionModel = RegressionModel(phi, feature_size=feature_size)
        self.classificationModel = ClassificationModel(
            phi, feature_size=feature_size, num_classes=num_classes)

        self.anchors = Anchors()

        self.regressBoxes = BBoxTransform()

        self.clipBoxes = ClipBoxes()

        self.focalLoss = losses.FocalLoss()

        # Bias the classifier so every anchor starts with foreground probability `prior`.
        prior = 0.01

        self.classificationModel.output.weight.data.fill_(0)
        self.classificationModel.output.bias.data.fill_(-math.log(
            (1.0 - prior) / prior))

        self.regressionModel.output.weight.data.fill_(0)
        self.regressionModel.output.bias.data.fill_(0)
Пример #10
0
def build_detection_loss(saved_for_loss, anno):
    '''
    Compute the focal detection loss and a log of its scalar components.

    :param saved_for_loss: ``(classifications, regressions, anchors)`` as
        produced by the detector.
    :param anno: ground-truth annotations passed to ``losses.FocalLoss``.
    :return: ``(total_loss, saved_for_log)``. ``total_loss`` is the
        batch-averaged classification loss plus the batch-averaged regression
        loss. ``saved_for_log`` is an ``OrderedDict`` of Python floats with keys
        ``'total_loss'``, ``'classification_loss'`` and ``'regression_loss'``.
    '''
    classifications, regressions, anchors = saved_for_loss

    criterion = losses.FocalLoss()
    per_image_cls, per_image_reg = criterion(classifications, regressions,
                                             anchors, anno)
    cls_loss = per_image_cls.mean()
    reg_loss = per_image_reg.mean()
    total_loss = cls_loss + reg_loss

    # Detach to plain floats for logging.
    saved_for_log = OrderedDict([
        ('total_loss', total_loss.item()),
        ('classification_loss', cls_loss.item()),
        ('regression_loss', reg_loss.item()),
    ])
    return total_loss, saved_for_log
Пример #11
0
    def __init__(self,
                 num_classes,
                 block,
                 layers,
                 max_boxes,
                 score_threshold,
                 seg_level,
                 alphabet,
                 train_htr,
                 htr_gt_box,
                 ner_branch=False,
                 binary_classifier=True):
        """Build a ResNet RetinaNet with a handwriting-recognition head and an optional NER head.

        Args:
            num_classes: classes for the classification head (and the NER head
                when enabled); stored as ``self.num_classes``.
            block: residual block class; must be ``BasicBlock`` or ``Bottleneck``.
            layers: number of blocks in each of the four residual stages.
            max_boxes: stored as ``self.max_boxes``.
            score_threshold: stored and passed to ``BoxSampler('train', ...)``.
            seg_level: forwarded to ``Anchors``.
            alphabet: recognition alphabet; ``len(alphabet)`` sizes ``RecognitionModel``.
            train_htr, htr_gt_box, binary_classifier: stored as attributes of
                the same name.
            ner_branch: when True, also build ``self.nerModel`` and ``self.nerLoss``.

        Raises:
            ValueError: if ``block`` is neither ``BasicBlock`` nor ``Bottleneck``.
        """
        self.inplanes = 64
        self.pool_h = 2
        self.pool_w = 400
        self.forward_transcription = False
        self.max_boxes = max_boxes
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3,
                               64,
                               kernel_size=7,
                               stride=2,
                               padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.downsampling_factors = [8, 16, 32, 64, 128]
        self.epochs_only_det = 1
        self.score_threshold = score_threshold
        self.alphabet = alphabet
        self.train_htr = train_htr
        self.binary_classifier = binary_classifier
        self.htr_gt_box = htr_gt_box
        self.num_classes = num_classes
        self.ner_branch = ner_branch

        # FPN input widths (C3..C5) from the final conv of each stage's last block.
        if block == BasicBlock:
            fpn_sizes = [
                self.layer2[layers[1] - 1].conv2.out_channels,
                self.layer3[layers[2] - 1].conv2.out_channels,
                self.layer4[layers[3] - 1].conv2.out_channels
            ]
        elif block == Bottleneck:
            fpn_sizes = [
                self.layer2[layers[1] - 1].conv3.out_channels,
                self.layer3[layers[2] - 1].conv3.out_channels,
                self.layer4[layers[3] - 1].conv3.out_channels
            ]
        else:
            raise ValueError(
                'block must be BasicBlock or Bottleneck, got {}'.format(block))

        self.fpn = PyramidFeatures(fpn_sizes[0], fpn_sizes[1], fpn_sizes[2])

        self.anchors = Anchors(seg_level=seg_level)
        self.regressionModel = RegressionModel(
            num_features_in=256, num_anchors=self.anchors.num_anchors)
        self.recognitionModel = RecognitionModel(feature_size=256,
                                                 pool_h=self.pool_h,
                                                 alphabet_len=len(alphabet))
        if ner_branch:
            self.nerModel = NERModel(feature_size=256,
                                     pool_h=self.pool_h,
                                     n_classes=num_classes,
                                     pool_w=self.pool_w)
        self.classificationModel = ClassificationModel(
            num_features_in=256,
            num_anchors=self.anchors.num_anchors,
            num_classes=num_classes)
        self.boxSampler = BoxSampler('train', self.score_threshold)
        self.sorter = RoISorter()
        self.regressBoxes = BBoxTransform()

        self.clipBoxes = ClipBoxes()

        self.focalLoss = losses.FocalLoss()
        if ner_branch:
            self.nerLoss = losses.NERLoss()
        self.transcriptionLoss = losses.TranscriptionLoss()

        # He-style normal init for convolutions (std = sqrt(2 / fan_out)); identity BatchNorm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        # Output layers whose bias is set to -log((1 - prior) / prior) start
        # at probability `prior`.
        prior = 0.01

        self.classificationModel.output.weight.data.fill_(0)
        self.classificationModel.output.bias.data.fill_(-math.log(
            (1.0 - prior) / prior))

        self.regressionModel.output.weight.data.fill_(0)
        self.regressionModel.output.bias.data.fill_(0)

        self.recognitionModel.output.weight.data.fill_(0)

        self.recognitionModel.output.bias.data.fill_(-math.log((1.0 - prior) /
                                                               prior))
        if ner_branch:
            self.nerModel.output.weight.data.fill_(0)

            self.nerModel.output.bias.data.fill_(-math.log((1.0 - prior) /
                                                           prior))
        self.freeze_bn()
Пример #12
0
    def __init__(self, args, image_network, decoder_network=None):
        """Attach an FPN and RetinaNet-style heads to an encoder (and optional decoder).

        Args:
            args: config object. Reads ``blobs_strategy`` and ``finetune_obj``,
                plus ``network_base`` for encoder strategies. Also passed to
                ``PyramidFeatures``.
            image_network: indexable backbone; items 0-3 form ``init_layers``
                and items 4-7 are stored as ``block1``-``block4``.
            decoder_network: indexable decoder; used by ``"decoder"`` strategies.

        Raises:
            ValueError: if ``args.blobs_strategy`` mentions neither
                ``"encoder"`` nor ``"decoder"``.
        """
        super().__init__()

        self.args = args
        self.blobs_strategy = self.args.blobs_strategy
        self.model_type = self.args.finetune_obj.split("_")[0]

        self.num_classes = 9
        self.n_blobs = 3

        self.image_network = image_network
        self.init_layers = self.image_network[0:4]
        self.block1 = self.image_network[4]
        self.block2 = self.image_network[5]
        self.block3 = self.image_network[6]
        self.block4 = self.image_network[7]

        self.decoder_network = decoder_network

        if "encoder" in self.blobs_strategy:
            # ResNet-18/34 use BasicBlock (last conv is conv2); deeper ResNets
            # use Bottleneck (last conv is conv3).
            if "resnet18" in self.args.network_base or "resnet34" in self.args.network_base:
                fpn_sizes = [
                    self.block2[-1].conv2.out_channels,
                    self.block3[-1].conv2.out_channels,
                    self.block4[-1].conv2.out_channels
                ]
            else:
                fpn_sizes = [
                    self.block2[-1].conv3.out_channels,
                    self.block3[-1].conv3.out_channels,
                    self.block4[-1].conv3.out_channels
                ]

        elif "decoder" in self.blobs_strategy:
            if "var" in self.model_type:
                fpn_sizes = [
                    self.decoder_network[3].conv.out_channels,
                    self.decoder_network[2].conv.out_channels,
                    self.decoder_network[1].conv.out_channels
                ]
            else:
                # NOTE(review): `self.synthesizer` is never assigned in this
                # __init__; unless a subclass or class attribute provides it,
                # this raises AttributeError — confirm the intended source.
                fpn_sizes = [
                    self.decoder_network[1].conv.out_channels,
                    self.decoder_network[0].conv.out_channels,
                    self.synthesizer[-1].conv.out_channels
                ]
        else:
            raise ValueError(
                "blobs_strategy must contain 'encoder' or 'decoder', got {!r}".format(
                    self.blobs_strategy))

        if "encoder" in self.blobs_strategy and "fused" in self.blobs_strategy:
            self.fpn = PyramidFeatures(args,
                                       fpn_sizes[0],
                                       fpn_sizes[1],
                                       fpn_sizes[2],
                                       fusion_strategy="concat_fuse")
        else:
            self.fpn = PyramidFeatures(args, fpn_sizes[0], fpn_sizes[1],
                                       fpn_sizes[2])

        self.dynamic_strategy = ("fused" not in self.blobs_strategy
                                 and "encoder" in self.blobs_strategy)
        self.regressionModel = RegressionModel(256, self.dynamic_strategy)
        self.classificationModel = ClassificationModel(256,
                                                       self.dynamic_strategy)

        self.anchors = Anchors()

        self.regressBoxes = BBoxTransform()

        self.clipBoxes = ClipBoxes()

        # Local import: binds `losses` only within this method.
        import losses

        self.focalLoss = losses.FocalLoss(self.dynamic_strategy)

        prior = 0.01

        # He-style normal init for convolutions (std = sqrt(2 / fan_out)); identity BatchNorm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        # Bias the classifier so every anchor starts with foreground probability `prior`.
        self.classificationModel.output.weight.data.fill_(0)
        self.classificationModel.output.bias.data.fill_(-math.log(
            (1.0 - prior) / prior))

        self.regressionModel.output.weight.data.fill_(0)
        self.regressionModel.output.bias.data.fill_(0)

        # Groups the newly added (non-backbone) modules under one container.
        self.params = nn.Sequential(
            self.fpn,
            self.regressionModel,
            self.classificationModel,
        )
Пример #13
0
 def setUp(self):
     """Two samples, two classes: sample 0 favours class 1, sample 1 favours class 0."""
     self.outputs = torch.Tensor([[0.1, 0.9],
                                  [0.8, 0.2]])
     self.targets = torch.Tensor([1, 0]).to(torch.long)
     self.criterion = losses.FocalLoss()
Пример #14
0
    def __init__(self, layers, prn_node_count=1024, prn_coeff=2):
        """Build the pose network: FPN backbone, keypoint subnet, detection subnet and PRN.

        Args:
            layers: backbone depth; 101 selects ``FPN101`` and 50 selects ``FPN50``.
            prn_node_count: forwarded to ``PRN``.
            prn_coeff: forwarded to ``PRN``.

        Raises:
            ValueError: if ``layers`` is neither 50 nor 101. Previously such a
                value silently left ``self.fpn`` undefined.
        """
        super(poseNet, self).__init__()
        if layers == 101:
            self.fpn = FPN101()
        elif layers == 50:
            self.fpn = FPN50()
        else:
            raise ValueError(f'layers must be 50 or 101, got {layers!r}')

        ##################################################################################
        # keypoints subnet
        # intermediate supervision: 1x1 convs mapping 256 channels to 19 outputs
        self.convfin_k2 = nn.Conv2d(256,
                                    19,
                                    kernel_size=1,
                                    stride=1,
                                    padding=0)
        self.convfin_k3 = nn.Conv2d(256,
                                    19,
                                    kernel_size=1,
                                    stride=1,
                                    padding=0)
        self.convfin_k4 = nn.Conv2d(256,
                                    19,
                                    kernel_size=1,
                                    stride=1,
                                    padding=0)
        self.convfin_k5 = nn.Conv2d(256,
                                    19,
                                    kernel_size=1,
                                    stride=1,
                                    padding=0)

        # 2 conv(kernel=3x3),change channels from 256 to 128
        self.convt1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.convt2 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.convt3 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.convt4 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.convs1 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.convs2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.convs3 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.convs4 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

        self.upsample1 = nn.Upsample(scale_factor=8,
                                     mode='nearest',
                                     align_corners=None)
        self.upsample2 = nn.Upsample(scale_factor=4,
                                     mode='nearest',
                                     align_corners=None)
        self.upsample3 = nn.Upsample(scale_factor=2,
                                     mode='nearest',
                                     align_corners=None)

        self.concat = Concat()
        self.conv2 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.convfin = nn.Conv2d(256, 18, kernel_size=1, stride=1, padding=0)

        ##################################################################################
        # detection subnet (single-class RetinaNet heads)
        self.regressionModel = RegressionModel(256)
        self.classificationModel = ClassificationModel(256, num_classes=1)
        self.anchors = Anchors()
        self.regressBoxes = BBoxTransform()
        self.clipBoxes = ClipBoxes()
        self.focalLoss = losses.FocalLoss()

        ##################################################################################
        # prn subnet
        self.prn = PRN(prn_node_count, prn_coeff)

        ##################################################################################
        # initialize weights, then bias the classifier so every anchor starts
        # with foreground probability `prior`
        self._initialize_weights_norm()
        prior = 0.01
        self.classificationModel.output.weight.data.fill_(0)
        self.classificationModel.output.bias.data.fill_(-math.log(
            (1.0 - prior) / prior))
        self.regressionModel.output.weight.data.fill_(0)
        self.regressionModel.output.bias.data.fill_(0)

        self.freeze_bn()  # from retinanet
Пример #15
0
    def __init__(self, num_classes, block, layers):
        """Build a ResNet RetinaNet with an attached Siamese network and box cropper.

        Args:
            num_classes: number of classes predicted by ``ClassificationModel``.
            block: residual block class; must be ``BasicBlock`` or ``Bottleneck``.
            layers: number of blocks in each of the four residual stages.

        Raises:
            ValueError: if ``block`` is neither ``BasicBlock`` nor ``Bottleneck``.
        """
        super(ResNet, self).__init__()
        self.inplanes = 64

        self.conv1 = nn.Conv2d(3,
                               64,
                               kernel_size=7,
                               stride=2,
                               padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # NOTE: this variant defines no self.relu / self.maxpool after the stem
        # (they were commented out); confirm forward() does not reference them.
        self.layer1 = self._make_layer(block,
                                       planes=64,
                                       blocks=layers[0],
                                       stride=1)
        self.layer2 = self._make_layer(block,
                                       planes=128,
                                       blocks=layers[1],
                                       stride=2)
        self.layer3 = self._make_layer(block,
                                       planes=256,
                                       blocks=layers[2],
                                       stride=2)
        self.layer4 = self._make_layer(block,
                                       planes=512,
                                       blocks=layers[3],
                                       stride=2)

        # FPN input widths (C3..C5) from the final conv of each stage's last block.
        if block == BasicBlock:
            fpn_sizes = [
                self.layer2[layers[1] - 1].conv2.out_channels,
                self.layer3[layers[2] - 1].conv2.out_channels,
                self.layer4[layers[3] - 1].conv2.out_channels
            ]
        elif block == Bottleneck:
            fpn_sizes = [
                self.layer2[layers[1] - 1].conv3.out_channels,
                self.layer3[layers[2] - 1].conv3.out_channels,
                self.layer4[layers[3] - 1].conv3.out_channels
            ]
        else:
            raise ValueError(
                'block must be BasicBlock or Bottleneck, got {}'.format(block))

        self.fpn = PyramidFeatures(fpn_sizes[0], fpn_sizes[1], fpn_sizes[2])
        self.regressionModel = RegressionModel(256)
        self.classificationModel = ClassificationModel(256,
                                                       num_classes=num_classes)
        self.siameseNetwork = SiameseNetwork()

        self.anchors = Anchors()

        self.regressBoxes = BBoxTransform()

        self.clipBoxes = ClipBoxes()

        self.focalLoss = losses.FocalLoss()

        self.cropBoxes = utils.CropBoxes()

        # He-style normal init for convolutions (std = sqrt(2 / fan_out)); identity BatchNorm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        # Bias the classifier's final conv so every anchor starts with
        # foreground probability `prior`.
        prior = 0.01

        self.classificationModel.conv5.weight.data.fill_(0)
        self.classificationModel.conv5.bias.data.fill_(-math.log(
            (1.0 - prior) / prior))

        self.regressionModel.conv5.weight.data.fill_(0)
        self.regressionModel.conv5.bias.data.fill_(0)

        self.freeze_bn()
Пример #16
0
    def __init__(self, num_class, block, layers):
        """Build the ResNet backbone, FPN, detection heads and loss.

        Args:
            num_class: number of classes predicted by the classification head.
            block: residual block class handed to ``self._make_layer``
                (``BasicBlock`` or ``Bottleneck`` in this file). The last
                block of each stage must expose its output width as
                ``.channels``.
            layers: four ints, the number of blocks in ``layer1`` .. ``layer4``.
        """
        super().__init__()
        self.in_channels = 64

        # Stem: 7x7 stride-2 conv -> BN -> ReLU -> 3x3 stride-2 max-pool.
        self.conv1 = nn.Sequential(
            OrderedDict([
                ('Conv1', nn.Conv2d(3, 64, kernel_size=7, stride=2,
                                    padding=3, bias=False)),
                ('BN', nn.BatchNorm2d(64)),
                ('Relu', nn.ReLU(inplace=True)),
                ('Maxpooling', nn.MaxPool2d(kernel_size=3, stride=2,
                                            padding=1)),
            ]))

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Output widths of the last block of stages 2-4 feed the FPN. Both
        # block types read the same ``.channels`` attribute, so no per-type
        # branch is needed (a branch per type left ``fpn_sizes`` unbound for
        # any other block class).
        fpn_sizes = [
            self.layer2[layers[1] - 1].channels,
            self.layer3[layers[2] - 1].channels,
            self.layer4[layers[3] - 1].channels,
        ]

        self.fpn = PyramidFeatures(fpn_sizes[0], fpn_sizes[1], fpn_sizes[2])

        self.regression = Regression(256)
        self.classification = Classification(256, num_classes=num_class)

        self.anchors = Anchors()
        self.regressBoxes = BBoxTransform()
        self.clipBoxes = ClipBoxes()

        self.focalLoss = losses.FocalLoss()

        # He-style normal init (fan computed from out_channels) for every
        # convolution; BatchNorm starts as the identity affine transform.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        # Prior initialisation: bias = -log((1 - p) / p), i.e.
        # sigmoid(bias) == prior (assumes the classification head applies a
        # sigmoid — confirm in Classification). The output-layer weights keep
        # the conv init above rather than being zeroed.
        prior = 0.01
        self.classification.output.bias.data.fill_(
            -math.log((1.0 - prior) / prior))
        self.regression.output.bias.data.fill_(0)

        self.freeze_bn()
Пример #17
0
     #print (xte.shape, yte.shape, te_pidx.shape)
 
     # Create a folder to save
     sspath = f'{spath}fold{cv}/'
     if not os.path.exists(sspath):
         os.makedirs(sspath, exist_ok=True)
     
     # Network
     #cnn = models.EEGNet(bias=False, F1=8, D=2)
     cnn = models.sub_EEGNet(bias=False, F1=8, D=2)
     #cnn = models.CCRNN(nSeg=segsize)
     if use_cuda:
         cnn = cnn.cuda()
     
     # Loss
     focalloss, celoss, f1loss = losses.FocalLoss(), nn.CrossEntropyLoss(), losses.F1_Loss()
 
     # Optimizer
     optimizer = optim.Adam(cnn.parameters(), lr=lr, weight_decay=5e-4)
     
     # Train the network
     trloss, tracc = [], []
     teloss, teacc, tepred, tef1, tsacc = [], [], [], [], []
     for ep in range(epochs):
         sname = f'{sspath}{ep}.tar'
         cnn.train()
         
         trloss_, tracc_ = [], []
         for i, data in enumerate(tensor_tr): #training iteration
             x, y = data
             
Пример #18
0
def main(args=None):
    """Train a RetinaNet detector on a COCO or CSV dataset.

    Parses the command line, builds the training (and optional validation)
    dataset, creates a ResNet-backed RetinaNet of the requested depth, and
    trains it for ``--epochs`` epochs.  After every epoch it evaluates on
    the validation set (if there is one), steps the plateau scheduler and
    saves ``<dataset>_retinanet_<epoch>.pt``.  After the last epoch it saves
    ``model_final.pt``.

    Args:
        args: optional list of command-line strings; ``None`` lets argparse
            read ``sys.argv``.

    Raises:
        ValueError: a required path option is missing, ``--dataset`` is not
            ``csv``/``coco``, or ``--depth`` is unsupported.
    """
    parser     = argparse.ArgumentParser(description='Simple training script for training a RetinaNet network.')

    parser.add_argument('--dataset', help='Dataset type, must be one of csv or coco.')
    parser.add_argument('--coco_path', help='Path to COCO directory')
    parser.add_argument('--csv_train', help='Path to file containing training annotations (see readme)')
    parser.add_argument('--csv_classes', help='Path to file containing class list (see readme)')
    parser.add_argument('--csv_val', help='Path to file containing validation annotations (optional, see readme)')

    parser.add_argument('--depth', help='Resnet depth, must be one of 18, 34, 50, 101, 152', type=int, default=50)
    parser.add_argument('--epochs', help='Number of epochs', type=int, default=100)
    parser.add_argument('--attention', help='use attention version', action='store_true')

    # From here on ``parser`` holds the parsed options namespace.
    parser = parser.parse_args(args)

    # Create the datasets
    if parser.dataset == 'coco':

        if parser.coco_path is None:
            raise ValueError('Must provide --coco_path when training on COCO,')

        dataset_train = CocoDataset(parser.coco_path, set_name='train2017', transform=transforms.Compose([Normalizer(), Augmenter(), Resizer()]))
        dataset_val = CocoDataset(parser.coco_path, set_name='val2017', transform=transforms.Compose([Normalizer(), Resizer()]))

    elif parser.dataset == 'csv':

        if parser.csv_train is None:
            raise ValueError('Must provide --csv_train when training on CSV.')

        if parser.csv_classes is None:
            raise ValueError('Must provide --csv_classes when training on CSV.')

        dataset_train = CSVDataset(train_file=parser.csv_train, class_list=parser.csv_classes, transform=transforms.Compose([Normalizer(), Augmenter(), Resizer()]))

        if parser.csv_val is None:
            dataset_val = None
            print('No validation annotations provided.')
        else:
            dataset_val = CSVDataset(train_file=parser.csv_val, class_list=parser.csv_classes, transform=transforms.Compose([Normalizer(), Resizer()]))

    else:
        raise ValueError('Dataset type not understood (must be csv or coco), exiting.')

    # One image per training batch. The evaluators below consume
    # ``dataset_val`` directly, so no validation DataLoader is needed.
    sampler = AspectRatioBasedSampler(dataset_train, batch_size=1, drop_last=False)
    dataloader_train = DataLoader(dataset_train, num_workers=3, collate_fn=collater, batch_sampler=sampler)

    # Create the model
    if parser.depth == 18:
        retinanet = model.resnet18(num_classes=dataset_train.num_classes(), pretrained=True)
    elif parser.depth == 34:
        retinanet = model.resnet34(num_classes=dataset_train.num_classes(), pretrained=True)
    elif parser.depth == 50:
        if parser.attention:
            retinanet = model.attention_resnet50(num_classes=dataset_train.num_classes(), pretrained=True)
        else:
            retinanet = model.resnet50(num_classes=dataset_train.num_classes(), pretrained=True)
    elif parser.depth == 101:
        retinanet = model.resnet101(num_classes=dataset_train.num_classes(), pretrained=True)
    elif parser.depth == 152:
        retinanet = model.resnet152(num_classes=dataset_train.num_classes(), pretrained=True)
    else:
        raise ValueError('Unsupported model depth, must be one of 18, 34, 50, 101, 152')

    use_gpu = True

    if use_gpu:
        retinanet = retinanet.cuda()

    retinanet = torch.nn.DataParallel(retinanet).cuda()

    retinanet.training = True

    optimizer = optim.Adam(retinanet.parameters(), lr=1e-5)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)

    # Running window of the most recent losses, for the progress line.
    loss_hist = collections.deque(maxlen=500)

    retinanet.train()
    retinanet.module.freeze_bn()

    print('Num training images: {}'.format(len(dataset_train)))

    focalLoss = losses.FocalLoss()

    for epoch_num in range(parser.epochs):

        # Evaluation may switch the model to eval mode; restore train mode
        # with BatchNorm frozen at the start of every epoch.
        retinanet.train()
        retinanet.module.freeze_bn()

        epoch_loss = []

        for iter_num, data in enumerate(dataloader_train):
            try:
                optimizer.zero_grad()

                classification, regression, anchors, annotations = retinanet([data['img'].cuda().float(), data['annot']])
                classification_loss, regression_loss = focalLoss(classification, regression, anchors, annotations)

                classification_loss = classification_loss.mean()
                regression_loss = regression_loss.mean()

                loss = classification_loss + regression_loss

                # Nothing to learn from this batch.
                if bool(loss == 0):
                    continue

                loss.backward()

                torch.nn.utils.clip_grad_norm_(retinanet.parameters(), 0.1)

                optimizer.step()

                loss_hist.append(float(loss))

                epoch_loss.append(float(loss))

                print('Epoch: {} | Iteration: {} | Classification loss: {:1.5f} | Regression loss: {:1.5f} | Running loss: {:1.5f}'.format(epoch_num, iter_num, float(classification_loss), float(regression_loss), np.mean(loss_hist)))

                del classification_loss
                del regression_loss
            except Exception as e:
                # Best-effort training: a failing batch is reported and skipped.
                print(e)
                continue

        if parser.dataset == 'coco':

            print('Evaluating dataset')

            coco_eval.evaluate_coco(dataset_val, retinanet)

        elif parser.dataset == 'csv' and parser.csv_val is not None:

            print('Evaluating dataset')

            mAP = csv_eval.evaluate(dataset_val, retinanet)

        # If every batch was skipped, np.mean([]) is NaN, which
        # ReduceLROnPlateau would count as a non-improving epoch.
        if epoch_loss:
            scheduler.step(np.mean(epoch_loss))

        torch.save(retinanet.module, '{}_retinanet_{}.pt'.format(parser.dataset, epoch_num))

    retinanet.eval()

    # Saves the DataParallel wrapper itself; the per-epoch checkpoints above
    # store the unwrapped ``.module``.
    torch.save(retinanet, 'model_final.pt')
Пример #19
0
def evaluate_coco(
    dataset,
    model,
    threshold=0.05,
    use_gpu=True,
    save=False,
    w_class=1.0,
    w_regr=1.0,
    w_sem=1.0,
    num_classes=2,
    use_n_samples=None,
    returntype='dict',
    coco_header=None,
):
    """Evaluate a joint detection / semantic-segmentation model on COCO data.

    Runs ``model`` on the first ``use_n_samples`` items of ``dataset`` (all by
    default). It keeps running means of the focal classification/regression
    losses, the semantic cross-entropy loss, the weighted total loss,
    per-class semantic IoU and the number of predicted boxes. The detections
    are then scored with ``COCOeval(..., 'bbox')``.

    The model is switched to eval mode for the run and back to train mode
    (``model.train()``) before any return.

    Args:
        dataset: indexable dataset whose items are dicts with ``'img'``,
            ``'mask'``, ``'annot'`` and optionally ``'scale'``. It must also
            provide ``image_ids``, ``label_to_coco_label``, ``coco`` and,
            when ``save`` is set, ``set_name``.
        model: called on a batch of one image; returns ``(classifications,
            regressions, anchors, semantic_logits, scores, labels, boxes)``.
        threshold: minimum detection score kept. Scores are assumed sorted
            in descending order, so the first score below it ends the image.
        use_gpu: move inputs to CUDA before the forward pass.
        save: also write ``<set_name>_bbox_results.json`` and
            ``<set_name>_semantic_results.json``.
        w_class, w_regr, w_sem: weights of the three losses in the total.
        num_classes: number of semantic IoU entries averaged and reported.
        use_n_samples: evaluate only the first N samples (``None`` = all).
        returntype: ``'dict'`` returns one ``OrderedDict`` merging the loss
            summary and COCO stats. Any other value returns
            ``(coco_eval, loss_summary_dict)``.
        coco_header: names for ``coco_eval.stats`` in the ``'dict'`` case;
            defaults to ``get_header(coco_eval)``.

    Returns:
        See ``returntype``. Returns an empty dict when no detection on any
        image reached ``threshold``.
    """
    model.eval()
    print("model.training", model.training)

    loss_func_bbox = losses.FocalLoss()
    # ``reduction='mean'`` is the non-deprecated spelling of
    # ``reduce=True, size_average=True``.
    loss_func_semantic_xe = nn.CrossEntropyLoss(reduction='mean')

    mean_loss_total = 0.0
    mean_loss_class = 0.0
    mean_loss_regr = 0.0
    mean_loss_sem = 0.0
    mean_ious = [0.0] * num_classes
    mean_nboxes = 0.0

    if use_n_samples is None:
        use_n_samples = len(dataset)

    with torch.no_grad():

        # start collecting results
        results = []
        results_semantic = []
        image_ids = []

        for index in range(use_n_samples):
            data = dataset[index]
            image_id = dataset.image_ids[index]
            scale = data['scale'] if 'scale' in data else 1.0

            # Build a batch of one: channels-last image -> channels-first
            # (presumably HWC -> CHW), plus a leading batch axis on mask/annot.
            img = torch.FloatTensor(data['img'])
            img = img.permute(2, 0, 1)
            msk = torch.LongTensor(data['mask'][np.newaxis])
            annot = np.array(data['annot'][np.newaxis])
            if annot.shape[1] > 0:
                annot = torch.FloatTensor(annot)
            else:
                # No ground-truth boxes: a single all -1 row (presumably the
                # "no object" padding FocalLoss expects — confirm).
                annot = torch.ones((1, 1, 5)) * -1

            if use_gpu:
                img = img.cuda()
                msk = msk.cuda()
                annot = annot.cuda()

            classifications, regressions, anchors, semantic_logits, scores, labels, boxes =\
                model(img.float().unsqueeze(dim=0))

            # Semantic segmentation: cross-entropy and per-class IoU of the
            # softmax probabilities.
            semantic_loss = loss_func_semantic_xe(semantic_logits, msk)
            semantic_prob = nn.Softmax2d()(semantic_logits).detach()
            iou_ = losses.sparse_iou_pt(msk, semantic_prob,
                                        reduce=False).cpu().detach().tolist()
            results_semantic.append({
                'image_id': image_id,
                'iou': iou_
            })

            # Detection losses.
            classification_loss, regression_loss =\
                loss_func_bbox(classifications, regressions,
                               anchors, annot)
            classification_loss = float(classification_loss.cpu().detach())
            regression_loss = float(regression_loss.cpu().detach())
            semantic_loss = float(semantic_loss.cpu().detach())

            loss = w_class * classification_loss + \
                   w_regr * regression_loss + \
                   w_sem * semantic_loss

            mean_loss_total = upd_mean(mean_loss_total, loss, index)
            mean_loss_class = upd_mean(mean_loss_class, classification_loss,
                                       index)
            mean_loss_regr = upd_mean(mean_loss_regr, regression_loss, index)
            mean_loss_sem = upd_mean(mean_loss_sem, semantic_loss, index)
            mean_ious = [
                upd_mean(mu, float(iou__), index)
                for mu, iou__ in zip(mean_ious, iou_)
            ]

            mean_nboxes = upd_mean(mean_nboxes, int(boxes.shape[0]), index)

            # Register every evaluated image with COCOeval, including images
            # with no predictions. COCOeval only scores the images listed in
            # ``params.imgIds``, so leaving these out would drop their ground
            # truth and overstate recall/AP.
            image_ids.append(image_id)

            if len(boxes.shape) == 1:
                print("no boxes predicted for the instance %d\tid = %s" %
                      (index, image_id))
                print(data.keys())
                print("skipping")
                continue
            scores = scores.cpu()
            labels = labels.cpu()
            boxes = boxes.cpu()

            # correct boxes for image scale
            boxes /= scale

            if boxes.shape[0] > 0:
                # change to (x, y, w, h) (MS COCO standard)
                boxes[:, 2] -= boxes[:, 0]
                boxes[:, 3] -= boxes[:, 1]

                for box_id in range(boxes.shape[0]):
                    score = float(scores[box_id])
                    label = int(labels[box_id])

                    # scores are sorted, so we can break
                    if score < threshold:
                        break

                    results.append({
                        'image_id': image_id,
                        'category_id': dataset.label_to_coco_label(label),
                        'score': float(score),
                        'bbox': boxes[box_id, :].tolist(),
                    })

            # print progress
            print('{}/{}'.format(index, use_n_samples), end='\r')

        if not results:
            # Nothing to score; still hand the model back in train mode, as
            # the normal exit path does.
            model.train()
            return {}

        loss_summary_dict = OrderedDict([
            ("loss_total", float(mean_loss_total)),
            ("loss_class", float(mean_loss_class)),
            ("loss_regr", float(mean_loss_regr)),
            ('mean_nboxes', float(mean_nboxes)),
            ("loss_sem", float(mean_loss_sem)),
        ])
        loss_summary_dict.update({("iou_%d" % (ii + 1)): vv
                                  for ii, vv in enumerate(mean_ious)})

        logstr = [ "Loss:\tTotal: {:.4f}\tClass: {:.4f}\tRegr: {:.4f}\tSemantic: {:.4f}" ] +\
                ["\tIOU#{:d}: {{:.3f}}".format(n+1) for n in range(num_classes)]
        logstr = "".join(logstr)
        print(
            logstr.format(mean_loss_total, mean_loss_class, mean_loss_regr,
                          mean_loss_sem, *mean_ious))

        if save:
            # Context managers guarantee the files are flushed and closed.
            with open('{}_bbox_results.json'.format(dataset.set_name),
                      'w') as fh:
                json.dump(results, fh, indent=4)
            with open('{}_semantic_results.json'.format(dataset.set_name),
                      'w') as fh:
                json.dump(results_semantic, fh, indent=4)

        # load results in COCO evaluation tool
        coco_true = dataset.coco
        coco_pred = coco_true.loadRes(results)

        # run COCO evaluation
        coco_eval = COCOeval(coco_true, coco_pred, 'bbox')
        coco_eval.params.imgIds = image_ids
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()

        model.train()

        if returntype == 'dict':
            if coco_header is None:
                coco_header = get_header(coco_eval)
            apar_summary_dict = OrderedDict(zip(coco_header, coco_eval.stats))
            loss_summary_dict.update(apar_summary_dict)
            return loss_summary_dict
        else:
            return coco_eval, loss_summary_dict
Пример #20
0
    else:
        retinanet.no_rpn = False
        logstr = '''Ep#{} | Iter#{:%d}/{:%d} || Losses | Class: {:1.4f} | Regr: {:1.4f} | Sem: {:1.5f} | Running: {:1.4f}''' % (
            ndigits, ndigits)

    retinanet.training = True

    optimizer = optim.Adam(retinanet.parameters(),
                           lr=parser.lr,
                           weight_decay=parser.weight_decay)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     patience=3,
                                                     verbose=True)

    loss_func_bbox = losses.FocalLoss()
    loss_func_semantic_xe = nn.CrossEntropyLoss(reduce=True, size_average=True)

    loss_hist = collections.deque(maxlen=500)

    retinanet.train()
    retinanet.freeze_bn()

    print('Num training images: {}'.format(len(dataset_train)))
    logdir = "checkpoints/{}".format(arghash)
    if (not parser.overwrite) and os.path.exists(logdir) and \
            sum((1 for x in os.scandir(logdir) if x.name.endswith('.pt'))):
        raise RuntimeError("directory exists and non empty:\t%s" % logdir)
    os.makedirs(logdir, exist_ok=True)
    parser.to_yaml(os.path.join(logdir, 'checkpoint.info'))
Пример #21
0
# %%
# Resume from the best saved weights for this class count, when present.
import os.path

checkpointname = f'./best_model{num_classes}CD.pkl'

if os.path.exists(checkpointname):
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpointname)
    change_net.load_state_dict(checkpoint)
    print(f'Checkpoint {checkpointname} is loaded.')

# #### Initialize Loss Functions and Optimizers

# %%
# Focal loss is used in place of plain nn.CrossEntropyLoss().
# With more than two classes, ``alpha`` must be given as a per-class list.
criterion = losses.FocalLoss(gamma=2.0, alpha=0.25)
optimizer = optim.Adam(change_net.parameters(), lr=base_lr)
sc_plt = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, patience=3, verbose=True)

# #### Train Model

# %%
best_model, _ = utils_train.train_model(change_net,
                                        dataloaders_dict,
                                        criterion,
                                        optimizer,
                                        sc_plt,
                                        writer,
                                        device,
Пример #22
0
def main():
    file_name = "./flood_graph/150_250/128/500/ji_sort/1_conf/sample-wised/default/{}/".format(
        args.b)
    start = time.time()
    # set GPU ID
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    cudnn.benchmark = True

    # check save path
    save_path = file_name
    # save_path = args.save_path
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # make dataloader
    if args.valid == True:
        train_loader, valid_loader, test_loader, test_onehot, test_label = dataset.get_valid_loader(
            args.data, args.data_path, args.batch_size)

    else:
        train_loader, train_onehot, train_label, test_loader, test_onehot, test_label = dataset.get_loader(
            args.data, args.data_path, args.batch_size)

    # set num_class
    if args.data == 'cifar100':
        num_class = 100
    else:
        num_class = 10

    # set num_classes
    model_dict = {
        "num_classes": num_class,
    }

    # set model
    if args.model == 'res':
        model = resnet.resnet110(**model_dict).cuda()
    elif args.model == 'dense':
        model = densenet_BC.DenseNet3(depth=100,
                                      num_classes=num_class,
                                      growth_rate=12,
                                      reduction=0.5,
                                      bottleneck=True,
                                      dropRate=0.0).cuda()
    elif args.model == 'vgg':
        model = vgg.vgg16(**model_dict).cuda()

    # set criterion
    if args.loss == 'MS':
        cls_criterion = losses.MultiSimilarityLoss().cuda()
    elif args.loss == 'Contrastive':
        cls_criterion = losses.ContrastiveLoss().cuda()
    elif args.loss == 'Triplet':
        cls_criterion = losses.TripletLoss().cuda()
    elif args.loss == 'NPair':
        cls_criterion = losses.NPairLoss().cuda()
    elif args.loss == 'Focal':
        cls_criterion = losses.FocalLoss(gamma=3.0).cuda()
    else:
        if args.mode == 0:
            cls_criterion = nn.CrossEntropyLoss().cuda()
        else:
            cls_criterion = nn.CrossEntropyLoss(reduction="none").cuda()

    ranking_criterion = nn.MarginRankingLoss(margin=0.0).cuda()

    # set optimizer (default:sgd)
    optimizer = optim.SGD(
        model.parameters(),
        lr=0.1,
        momentum=0.9,
        weight_decay=5e-4,
        # weight_decay=0.0001,
        nesterov=False)

    # optimizer = optim.SGD(model.parameters(),
    #                       lr=float(args.lr),
    #                       momentum=0.9,
    #                       weight_decay=args.weight_decay,
    #                       nesterov=False)

    # set scheduler
    # scheduler = MultiStepLR(optimizer,
    #                         milestones=[500, 750],
    #                         gamma=0.1)

    scheduler = MultiStepLR(optimizer, milestones=[150, 250], gamma=0.1)

    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_decay_step, gamma=args.lr_decay_gamma)

    # make logger
    train_logger = utils.Logger(os.path.join(save_path, 'train.log'))
    result_logger = utils.Logger(os.path.join(save_path, 'result.log'))

    # make History Class
    correctness_history = crl_utils.History(len(train_loader.dataset))

    ## define matrix
    if args.data == 'cifar':
        matrix_idx_confidence = [[_] for _ in range(50000)]
        matrix_idx_iscorrect = [[_] for _ in range(50000)]
    else:
        matrix_idx_confidence = [[_] for _ in range(73257)]
        matrix_idx_iscorrect = [[_] for _ in range(73257)]

    # write csv
    #'''
    import csv
    f = open('{}/logs_{}_{}.txt'.format(file_name, args.b, args.epochs),
             'w',
             newline='')
    f.write("location = {}\n\n".format(file_name) + str(args))

    f0 = open('{}/Test_confidence_{}_{}.csv'.format(file_name, args.b,
                                                    args.epochs),
              'w',
              newline='')
    # f0 = open('./baseline_graph/150_250/128/500/Test_confidence_{}_{}.csv'.format(args.b, args.epochs), 'w', newline='')
    # f0 = open('./CRL_graph/150_250/Test_confidence_{}_{}.csv'.format(args.b, args.epochs), 'w', newline='')

    wr_conf_test = csv.writer(f0)
    header = [_ for _ in range(args.epochs + 1)]
    header[0] = 'Epoch'
    wr_conf_test.writerows([header])

    f1 = open('{}/Train_confidence_{}_{}.csv'.format(file_name, args.b,
                                                     args.epochs),
              'w',
              newline='')
    # f1 = open('./baseline_graph/150_250/128/500/Train_confidence_{}_{}.csv'.format(args.b, args.epochs), 'w', newline='')
    # f1 = open('./CRL_graph/150_250/Train_confidence_{}_{}.csv'.format(args.b, args.epochs), 'w', newline='')

    wr = csv.writer(f1)
    header = [_ for _ in range(args.epochs + 1)]
    header[0] = 'Epoch'
    wr.writerows([header])

    f2 = open('{}/Train_Flood_{}_{}_{}.csv'.format(file_name, args.data,
                                                   args.b, args.epochs),
              'w',
              newline='')
    # f2 = open('./baseline_graph/150_250/128/500/Train_Base_{}_{}_{}.csv'.format(args.data, args.b, args.epochs), 'w', newline='')
    # f2 = open('./CRL_graph/150_250/Train_Flood_{}_{}_{}.csv'.format(args.data, args.b, args.epochs), 'w', newline='')

    wr_train = csv.writer(f2)
    header = [_ for _ in range(args.epochs + 1)]
    header[0] = 'Epoch'
    wr_train.writerows([header])

    f3 = open('{}/Test_Flood_{}_{}_{}.csv'.format(file_name, args.data, args.b,
                                                  args.epochs),
              'w',
              newline='')
    # f3 = open('./baseline_graph/150_250/128/500/Test_Base_{}_{}_{}.csv'.format(args.data, args.b, args.epochs), 'w', newline='')
    # f3 = open('./CRL_graph/150_250/Test_Flood_{}_{}_{}.csv'.format(args.data, args.b, args.epochs), 'w', newline='')

    wr_test = csv.writer(f3)
    header = [_ for _ in range(args.epochs + 1)]
    header[0] = 'Epoch'
    wr_test.writerows([header])
    #'''

    # start Train
    best_valid_acc = 0
    test_ece_report = []
    test_acc_report = []
    test_nll_report = []
    test_over_con99_report = []
    test_e99_report = []
    test_cls_loss_report = []

    train_ece_report = []
    train_acc_report = []
    train_nll_report = []
    train_over_con99_report = []
    train_e99_report = []
    train_cls_loss_report = []
    train_rank_loss_report = []
    train_total_loss_report = []

    for epoch in range(1, args.epochs + 1):
        scheduler.step()

        matrix_idx_confidence, matrix_idx_iscorrect, idx, iscorrect, confidence, target, cls_loss_tr, rank_loss_tr, batch_correctness, total_confidence, total_correctness = \
            train.train(matrix_idx_confidence, matrix_idx_iscorrect, train_loader,
                    model,
                    wr,
                    cls_criterion,
                    ranking_criterion,
                    optimizer,
                    epoch,
                    correctness_history,
                    train_logger,
                    args)

        if args.rank_weight != 0.0:
            print("RANK ", rank_loss_tr)
            total_loss_tr = cls_loss_tr + rank_loss_tr

        if args.valid == True:
            idx, iscorrect, confidence, target, cls_loss_val, acc = train.valid(
                valid_loader, model, cls_criterion, ranking_criterion,
                optimizer, epoch, correctness_history, train_logger, args)
            if acc > best_valid_acc:
                best_valid_acc = acc
                print("*** Update Best Acc ***")

        # save model
        if epoch == args.epochs:
            torch.save(model.state_dict(),
                       os.path.join(save_path, 'model.pth'))

        print("########### Train ###########")
        acc_tr, aurc_tr, eaurc_tr, aupr_tr, fpr_tr, ece_tr, nll_tr, brier_tr, E99_tr, over_99_tr, cls_loss_tr = metrics.calc_metrics(
            train_loader, train_label, train_onehot, model, cls_criterion,
            args)

        if args.sort == True and epoch == 260:
            #if args.sort == True:
            train_loader = dataset.sort_get_loader(
                args.data, args.data_path, args.batch_size, idx,
                np.array(target), iscorrect,
                batch_correctness, total_confidence, total_correctness,
                np.array(confidence), epoch, args)

        train_acc_report.append(acc_tr)
        train_nll_report.append(nll_tr * 10)
        train_ece_report.append(ece_tr)
        train_over_con99_report.append(over_99_tr)
        train_e99_report.append(E99_tr)
        train_cls_loss_report.append(cls_loss_tr)

        if args.rank_weight != 0.0:
            train_total_loss_report.append(total_loss_tr)
            train_rank_loss_report.append(rank_loss_tr)
        print("CLS ", cls_loss_tr)

        # finish train
        print("########### Test ###########")
        # calc measure
        acc_te, aurc_te, eaurc_te, aupr_te, fpr_te, ece_te, nll_te, brier_te, E99_te, over_99_te, cls_loss_te = metrics.calc_metrics(
            test_loader, test_label, test_onehot, model, cls_criterion, args)
        test_ece_report.append(ece_te)
        test_acc_report.append(acc_te)
        test_nll_report.append(nll_te * 10)
        test_over_con99_report.append(over_99_te)
        test_e99_report.append(E99_te)
        test_cls_loss_report.append(cls_loss_te)

        print("CLS ", cls_loss_te)
        print("############################")

    # for idx in matrix_idx_confidence:
    #     wr.writerow(idx)

    #'''
    # draw graph
    df = pd.DataFrame()
    df['epoch'] = [i for i in range(1, args.epochs + 1)]
    df['test_ece'] = test_ece_report
    df['train_ece'] = train_ece_report
    fig_loss = plt.figure(figsize=(35, 35))
    fig_loss.set_facecolor('white')
    ax = fig_loss.add_subplot()

    ax.plot(df['epoch'],
            df['test_ece'],
            df['epoch'],
            df['train_ece'],
            linewidth=10)
    ax.legend(['Test', 'Train'], loc=2, prop={'size': 60})
    plt.title('[FL] ECE per epoch', fontsize=80)
    # plt.title('[BASE] ECE per epoch', fontsize=80)
    # plt.title('[CRL] ECE per epoch', fontsize=80)
    plt.xlabel('Epoch', fontsize=70)
    plt.ylabel('ECE', fontsize=70)
    plt.ylim([0, 1])
    plt.setp(ax.get_xticklabels(), fontsize=30)
    plt.setp(ax.get_yticklabels(), fontsize=30)
    plt.savefig('{}/{}_{}_ECE_lr_{}.png'.format(file_name, args.model, args.b,
                                                args.epochs))
    # plt.savefig('./baseline_graph/150_250/128/500/{}_{}_ECE_lr_{}.png'.format(args.model, args.b, args.epochs))
    # plt.savefig('./CRL_graph/150_250/{}_{}_ECE_lr_{}.png'.format(args.model, args.b, args.epochs))

    df2 = pd.DataFrame()
    df2['epoch'] = [i for i in range(1, args.epochs + 1)]
    df2['test_acc'] = test_acc_report
    df2['train_acc'] = train_acc_report
    fig_acc = plt.figure(figsize=(35, 35))
    fig_acc.set_facecolor('white')
    ax = fig_acc.add_subplot()

    ax.plot(df2['epoch'],
            df2['test_acc'],
            df2['epoch'],
            df2['train_acc'],
            linewidth=10)
    ax.legend(['Test', 'Train'], loc=2, prop={'size': 60})
    plt.title('[FL] Accuracy per epoch', fontsize=80)
    # plt.title('[BASE] Accuracy per epoch', fontsize=80)
    # plt.title('[CRL] Accuracy per epoch', fontsize=80)
    plt.xlabel('Epoch', fontsize=70)
    plt.ylabel('Accuracy', fontsize=70)
    plt.ylim([0, 100])
    plt.setp(ax.get_xticklabels(), fontsize=30)
    plt.setp(ax.get_yticklabels(), fontsize=30)
    plt.savefig('{}/{}_{}_acc_lr_{}.png'.format(file_name, args.model, args.b,
                                                args.epochs))
    # plt.savefig('./baseline_graph/150_250/128/500/{}_{}_acc_lr_{}.png'.format(args.model, args.b, args.epochs))
    # plt.savefig('./CRL_graph/150_250/{}_{}_acc_lr_{}.png'.format(args.model, args.b, args.epochs))

    df3 = pd.DataFrame()
    df3['epoch'] = [i for i in range(1, args.epochs + 1)]
    df3['test_nll'] = test_nll_report
    df3['train_nll'] = train_nll_report
    fig_acc = plt.figure(figsize=(35, 35))
    fig_acc.set_facecolor('white')
    ax = fig_acc.add_subplot()

    ax.plot(df3['epoch'],
            df3['test_nll'],
            df3['epoch'],
            df3['train_nll'],
            linewidth=10)
    ax.legend(['Test', 'Train'], loc=2, prop={'size': 60})
    plt.title('[FL] NLL per epoch', fontsize=80)
    # plt.title('[BASE] NLL per epoch', fontsize=80)
    # plt.title('[CRL] NLL per epoch', fontsize=80)
    plt.xlabel('Epoch', fontsize=70)
    plt.ylabel('NLL', fontsize=70)
    plt.ylim([0, 45])
    plt.setp(ax.get_xticklabels(), fontsize=30)
    plt.setp(ax.get_yticklabels(), fontsize=30)
    plt.savefig('{}/{}_{}_nll_lr_{}.png'.format(file_name, args.model, args.b,
                                                args.epochs))
    # plt.savefig('./baseline_graph/150_250/128/500/{}_{}_nll_lr_{}.png'.format(args.model, args.b, args.epochs))
    # plt.savefig('./CRL_graph/150_250/{}_{}_nll_lr_{}.png'.format(args.model, args.b, args.epochs))

    df4 = pd.DataFrame()
    df4['epoch'] = [i for i in range(1, args.epochs + 1)]
    df4['test_over_con99'] = test_over_con99_report
    df4['train_over_con99'] = train_over_con99_report
    fig_acc = plt.figure(figsize=(35, 35))
    fig_acc.set_facecolor('white')
    ax = fig_acc.add_subplot()

    ax.plot(df4['epoch'],
            df4['test_over_con99'],
            df4['epoch'],
            df4['train_over_con99'],
            linewidth=10)
    ax.legend(['Test', 'Train'], loc=2, prop={'size': 60})
    plt.title('[FL] Over conf99 per epoch', fontsize=80)
    # plt.title('[BASE] Over conf99 per epoch', fontsize=80)
    # plt.title('[CRL] Over conf99 per epoch', fontsize=80)
    plt.xlabel('Epoch', fontsize=70)
    plt.ylabel('Over con99', fontsize=70)
    if args.data == 'cifar10' or args.data == 'cifar100':
        plt.ylim([0, 50000])
    else:
        plt.ylim([0, 73257])

    plt.setp(ax.get_xticklabels(), fontsize=30)
    plt.setp(ax.get_yticklabels(), fontsize=30)
    plt.savefig('{}/{}_{}_over_conf99_lr_{}.png'.format(
        file_name, args.model, args.b, args.epochs))
    # plt.savefig('./baseline_graph/150_250/128/500/{}_{}_over_conf99_lr_{}.png'.format(args.model, args.b, args.epochs))
    # plt.savefig('./CRL_graph/150_250/{}_{}_over_conf99_lr_{}.png'.format(args.model, args.b, args.epochs))

    df5 = pd.DataFrame()
    df5['epoch'] = [i for i in range(1, args.epochs + 1)]
    df5['test_e99'] = test_e99_report
    df5['train_e99'] = train_e99_report
    fig_acc = plt.figure(figsize=(35, 35))
    fig_acc.set_facecolor('white')
    ax = fig_acc.add_subplot()

    ax.plot(df5['epoch'],
            df5['test_e99'],
            df5['epoch'],
            df5['train_e99'],
            linewidth=10)
    ax.legend(['Test', 'Train'], loc=2, prop={'size': 60})
    plt.title('[FL] E99 per epoch', fontsize=80)
    # plt.title('[BASE] E99 per epoch', fontsize=80)
    # plt.title('[CRL] E99 per epoch', fontsize=80)
    plt.xlabel('Epoch', fontsize=70)
    plt.ylabel('E99', fontsize=70)
    plt.ylim([0, 0.2])
    plt.setp(ax.get_xticklabels(), fontsize=30)
    plt.setp(ax.get_yticklabels(), fontsize=30)
    plt.savefig('{}/{}_{}_E99_flood_lr_{}.png'.format(file_name, args.model,
                                                      args.b, args.epochs))
    # plt.savefig('./baseline_graph/150_250/128/500/{}_{}_E99_flood_lr_{}.png'.format(args.model, args.b, args.epochs))
    # plt.savefig('./CRL_graph/150_250/{}_{}_E99_flood_lr_{}.png'.format(args.model, args.b, args.epochs))

    df5 = pd.DataFrame()
    df5['epoch'] = [i for i in range(1, args.epochs + 1)]
    df5['test_cls_loss'] = test_cls_loss_report
    df5['train_cls_loss'] = train_cls_loss_report
    fig_acc = plt.figure(figsize=(35, 35))
    fig_acc.set_facecolor('white')
    ax = fig_acc.add_subplot()

    ax.plot(df5['epoch'],
            df5['test_cls_loss'],
            df5['epoch'],
            df5['train_cls_loss'],
            linewidth=10)
    ax.legend(['Test', 'Train'], loc=2, prop={'size': 60})
    plt.title('[FL] CLS_loss per epoch', fontsize=80)
    # plt.title('[BASE] CLS_loss per epoch', fontsize=80)
    # plt.title('[CRL] CLS_loss per epoch', fontsize=80)
    plt.xlabel('Epoch', fontsize=70)
    plt.ylabel('Loss', fontsize=70)
    plt.ylim([0, 5])
    plt.setp(ax.get_xticklabels(), fontsize=30)
    plt.setp(ax.get_yticklabels(), fontsize=30)
    plt.savefig('{}/{}_{}_cls_loss_flood_lr_{}.png'.format(
        file_name, args.model, args.b, args.epochs))
    # plt.savefig('./baseline_graph/150_250/128/500/{}_{}_cls_loss_flood_lr_{}.png'.format(args.model, args.b, args.epochs))
    # plt.savefig('./CRL_graph/150_250/{}_{}_cls_loss_flood_lr_{}.png'.format(args.model, args.b, args.epochs))

    if args.rank_weight != 0.0:
        df6 = pd.DataFrame()
        df6['epoch'] = [i for i in range(1, args.epochs + 1)]
        df6['train_cls_loss'] = train_cls_loss_report
        df6['train_rank_loss'] = train_rank_loss_report
        df6['train_total_loss'] = train_total_loss_report
        fig_acc = plt.figure(figsize=(35, 35))
        fig_acc.set_facecolor('white')
        ax = fig_acc.add_subplot()

        ax.plot(df6['epoch'],
                df6['train_cls_loss'],
                df6['epoch'],
                df6['train_rank_loss'],
                df6['epoch'],
                df6['train_total_loss'],
                linewidth=10)
        ax.legend(['CLS', 'Rank', 'Total'], loc=2, prop={'size': 60})
        plt.title('[FL] CLS_loss per epoch', fontsize=80)
        plt.xlabel('Epoch', fontsize=70)
        plt.ylabel('Loss', fontsize=70)
        # plt.ylim([0, 5])
        plt.setp(ax.get_xticklabels(), fontsize=30)
        plt.setp(ax.get_yticklabels(), fontsize=30)
        plt.savefig(
            './CRL_graph/150_250/{}_{}_cls_loss_flood_lr_{}.png'.format(
                args.model, args.b, args.epochs))

    test_acc_report.insert(0, 'ACC')
    test_ece_report.insert(0, 'ECE')
    test_nll_report.insert(0, 'NLL')
    test_over_con99_report.insert(0, 'Over_conf99')
    test_e99_report.insert(0, 'E99')
    test_cls_loss_report.insert(0, 'CLS')
    wr_test.writerow(test_acc_report)
    wr_test.writerow(test_ece_report)
    wr_test.writerow(test_nll_report)
    wr_test.writerow(test_over_con99_report)
    wr_test.writerow(test_e99_report)
    wr_test.writerow(test_cls_loss_report)

    train_acc_report.insert(0, 'ACC')
    train_ece_report.insert(0, 'ECE')
    train_nll_report.insert(0, 'NLL')
    train_over_con99_report.insert(0, 'Over_conf99')
    train_e99_report.insert(0, 'E99')
    train_cls_loss_report.insert(0, 'CLS')

    wr_train.writerow(train_acc_report)
    wr_train.writerow(train_ece_report)
    wr_train.writerow(train_nll_report)
    wr_train.writerow(train_over_con99_report)
    wr_train.writerow(train_e99_report)
    wr_train.writerow(train_cls_loss_report)

    if args.rank_weight != 0.0:
        train_rank_loss_report.insert(0, 'Rank')
        train_total_loss_report.insert(0, 'Total')
        wr_train.writerow(train_rank_loss_report)
        wr_train.writerow(train_total_loss_report)

    #'''

    # result write
    result_logger.write([
        acc_te, aurc_te * 1000, eaurc_te * 1000, aupr_te * 100, fpr_te * 100,
        ece_te * 100, nll_te * 10, brier_te * 100, E99_te * 100
    ])
    if args.valid == True:
        print("Best Valid Acc : {}".format(acc))
    print("Flood Level: {}".format(args.b))
    print("Sort : {}".format(args.sort))
    print("Sort Mode : {}".format(args.sort_mode))
    print("TIME : ", time.time() - start)