Code Example #1
def train(cfg, args):
    logger = logging.getLogger('SSD.trainer')
    # -----------------------------------------------------------------------------
    # Model
    # -----------------------------------------------------------------------------
    model = build_mobilev1_ssd_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    if args.resume:
        logger.info("Resume from the model {}".format(args.resume))
        model.load(args.resume)
    else:
        logger.info("Init from base net {}".format(args.vgg))
        model.init_from_base_net(args.vgg)
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)
    # -----------------------------------------------------------------------------
    # Optimizer
    # -----------------------------------------------------------------------------
    lr = cfg.SOLVER.LR * args.num_gpus  # scale by num gpus
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    # -----------------------------------------------------------------------------
    # Criterion
    # -----------------------------------------------------------------------------
    criterion = MultiBoxLoss(neg_pos_ratio=cfg.MODEL.NEG_POS_RATIO)

    # -----------------------------------------------------------------------------
    # Scheduler
    # -----------------------------------------------------------------------------
    milestones = [step // args.num_gpus for step in cfg.SOLVER.LR_STEPS]
    scheduler = WarmupMultiStepLR(optimizer=optimizer,
                                  milestones=milestones,
                                  gamma=cfg.SOLVER.GAMMA,
                                  warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
                                  warmup_iters=cfg.SOLVER.WARMUP_ITERS)

    # -----------------------------------------------------------------------------
    # Dataset
    # -----------------------------------------------------------------------------
    train_transform = TrainAugmentation(cfg.INPUT.IMAGE_SIZE, cfg.INPUT.PIXEL_MEAN)
    target_transform = MatchPrior(PriorBox(cfg)(), cfg.MODEL.CENTER_VARIANCE, cfg.MODEL.SIZE_VARIANCE, cfg.MODEL.THRESHOLD)
    train_dataset = build_dataset(dataset_list=cfg.DATASETS.TRAIN, transform=train_transform, target_transform=target_transform)
    logger.info("Train dataset size: {}".format(len(train_dataset)))
    if args.distributed:
        sampler = torch.utils.data.DistributedSampler(train_dataset)
    else:
        sampler = torch.utils.data.RandomSampler(train_dataset)
    batch_sampler = torch.utils.data.sampler.BatchSampler(sampler=sampler, batch_size=cfg.SOLVER.BATCH_SIZE, drop_last=False)
    batch_sampler = samplers.IterationBasedBatchSampler(batch_sampler, num_iterations=cfg.SOLVER.MAX_ITER // args.num_gpus)
    train_loader = DataLoader(train_dataset, num_workers=4, batch_sampler=batch_sampler)

    return do_train(cfg, model, train_loader, optimizer, scheduler, criterion, device, args)
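For context, a minimal driver for this entry point might look like the sketch below. The argparse flags mirror the attributes the function reads from args; the config loader (get_default_cfg) and the yacs-style merge_from_file call are assumptions for illustration, not part of the original project.

import argparse
import logging

def main():
    # Hypothetical CLI wrapper; flag names match the args attributes used above.
    parser = argparse.ArgumentParser(description="SSD training (sketch)")
    parser.add_argument("--config-file", default="", type=str)
    parser.add_argument("--resume", default=None, type=str)
    parser.add_argument("--vgg", default=None, type=str)
    parser.add_argument("--num_gpus", default=1, type=int)
    parser.add_argument("--distributed", action="store_true")
    parser.add_argument("--local_rank", default=0, type=int)
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    cfg = get_default_cfg()  # assumption: project-specific config loader
    if args.config_file:
        cfg.merge_from_file(args.config_file)  # yacs-style config merge
    train(cfg, args)

if __name__ == "__main__":
    main()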
Code Example #2
File: ssd.py Project: deadpoppy/SSD1024
    def forward(self, x, targets=None):
        sources = []
        confidences = []
        locations = []
        # apply vgg up to conv4_3 relu (index 23 in the standard SSD VGG)
        for i in range(23):
            x = self.vgg[i](x)
        s = self.l2_norm(x)  # Conv4_3 L2 normalization
        sources.append(s)

        # apply vgg up to fc7
        for i in range(23, len(self.vgg)):
            x = self.vgg[i](x)
        sources.append(x)

        # for k, v in enumerate(self.extras):
        #     x = F.relu(v(x), inplace=True)
        #     if k % 2 == 1:
        #         sources.append(x)


        for (x, l, c) in zip(sources, self.regression_headers, self.classification_headers):
            locations.append(l(x).permute(0, 2, 3, 1).contiguous())
            confidences.append(c(x).permute(0, 2, 3, 1).contiguous())

        confidences = torch.cat([o.view(o.size(0), -1) for o in confidences], 1)
        locations = torch.cat([o.view(o.size(0), -1) for o in locations], 1)

        confidences = confidences.view(confidences.size(0), -1, self.num_classes)
        locations = locations.view(locations.size(0), -1, 4)

        if not self.training:
            # when evaluating, decode predictions
            if self.priors is None:
                self.priors = PriorBox(self.cfg)().to(locations.device)
            confidences = F.softmax(confidences, dim=2)
            boxes = box_utils.convert_locations_to_boxes(
                locations, self.priors, self.cfg.MODEL.CENTER_VARIANCE, self.cfg.MODEL.SIZE_VARIANCE
            )
            boxes = box_utils.center_form_to_corner_form(boxes)
            return confidences, boxes
        else:
            # when training, compute losses
            gt_boxes, gt_labels = targets
            regression_loss, classification_loss = self.criterion(confidences, locations, gt_labels, gt_boxes)
            loss_dict = dict(
                regression_loss=regression_loss,
                classification_loss=classification_loss,
            )
            return loss_dict
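A quick smoke test of the eval branch above might look like the sketch below; the builder name is borrowed from Code Example #1 and the square input size is an assumption.

import torch

model = build_mobilev1_ssd_model(cfg)  # builder name taken from Code Example #1
model.eval()  # not self.training, so forward() decodes boxes
with torch.no_grad():
    image = torch.randn(1, 3, cfg.INPUT.IMAGE_SIZE, cfg.INPUT.IMAGE_SIZE)
    confidences, boxes = model(image)
# confidences: (batch, num_priors, num_classes) softmax scores
# boxes:       (batch, num_priors, 4) corner-form (x1, y1, x2, y2)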
Code Example #3
def setup_self_ade(cfg, args):
    logger = logging.getLogger("self_ade.setup")
    logger.info("Starting self_ade setup")

    # build model from config
    model = build_ssd_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    train_transform = TrainAugmentation(cfg.INPUT.IMAGE_SIZE,
                                        cfg.INPUT.PIXEL_MEAN,
                                        cfg.INPUT.PIXEL_STD)

    target_transform = MatchPrior(
        PriorBox(cfg)(), cfg.MODEL.CENTER_VARIANCE, cfg.MODEL.SIZE_VARIANCE,
        cfg.MODEL.THRESHOLD)

    test_dataset = build_dataset(dataset_list=cfg.DATASETS.TEST,
                                 is_test=True)[0]
    self_ade_dataset = build_dataset(dataset_list=cfg.DATASETS.TEST,
                                     transform=train_transform,
                                     target_transform=target_transform)
    ss_dataset = SelfSupervisedDataset(self_ade_dataset, cfg)

    test_sampler = SequentialSampler(test_dataset)
    os_sampler = OneSampleBatchSampler(test_sampler, cfg.SOLVER.BATCH_SIZE,
                                       args.self_ade_iterations)

    self_ade_dataloader = DataLoader(ss_dataset,
                                     batch_sampler=os_sampler,
                                     num_workers=args.num_workers)

    effective_lr = args.learning_rate * args.self_ade_weight

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=effective_lr,
                                momentum=cfg.SOLVER.MOMENTUM,
                                weight_decay=cfg.SOLVER.WEIGHT_DECAY)

    # Initialize mixed-precision training
    use_mixed_precision = cfg.USE_AMP
    amp_opt_level = 'O1' if use_mixed_precision else 'O0'
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level=amp_opt_level)

    execute_self_ade(cfg, args, test_dataset, self_ade_dataloader, model,
                     optimizer)
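Note that apex.amp has since been deprecated in favor of PyTorch's built-in AMP. A rough equivalent of one mixed-precision step with torch.cuda.amp is sketched below as a modern substitute; it is not the code this project actually runs.

import torch

def amp_train_step(model, optimizer, scaler, images, targets):
    # One mixed-precision step; scaler is a torch.cuda.amp.GradScaler
    # constructed once outside the training loop.
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss_dict = model(images, targets=targets)
        loss = sum(loss_dict.values())
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)         # unscales gradients, then steps
    scaler.update()
    return loss.item()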
Code Example #4
File: ssd.py Project: tomzhang/SSD
    def forward(self, x):
        sources = []
        confidences = []
        locations = []
        for i in range(23):
            x = self.vgg[i](x)
        s = self.l2_norm(x)  # Conv4_3 L2 normalization
        sources.append(s)

        # apply vgg up to fc7
        for i in range(23, len(self.vgg)):
            x = self.vgg[i](x)
        sources.append(x)

        for k, v in enumerate(self.extras):
            x = F.relu(v(x), inplace=True)
            if k % 2 == 1:
                sources.append(x)

        for (x, l, c) in zip(sources, self.regression_headers,
                             self.classification_headers):
            locations.append(l(x).permute(0, 2, 3, 1).contiguous())
            confidences.append(c(x).permute(0, 2, 3, 1).contiguous())

        confidences = torch.cat([o.view(o.size(0), -1) for o in confidences],
                                1)
        locations = torch.cat([o.view(o.size(0), -1) for o in locations], 1)

        confidences = confidences.view(confidences.size(0), -1,
                                       self.num_classes)
        locations = locations.view(locations.size(0), -1, 4)

        if self.is_test:
            if self.priors is None:
                self.priors = PriorBox(self.cfg)().to(locations.device)
            confidences = F.softmax(confidences, dim=2)
            boxes = box_utils.convert_locations_to_boxes(
                locations, self.priors, self.cfg.MODEL.CENTER_VARIANCE,
                self.cfg.MODEL.SIZE_VARIANCE)
            boxes = box_utils.center_form_to_corner_form(boxes)

            return confidences, boxes
        else:
            return confidences, locations
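The permute/view/cat sequence above is the standard SSD head flattening. The standalone snippet below reproduces it with dummy tensors so the shapes are explicit; the feature-map sizes and anchor count are illustrative values, not taken from this project.

import torch

num_classes, batch, anchors_per_cell = 21, 2, 4
# One fake classification-head output per feature map (NCHW).
head_outputs = [torch.randn(batch, anchors_per_cell * num_classes, s, s)
                for s in (38, 19, 10)]
# NCHW -> NHWC, flatten per image, concatenate across feature maps.
flat = torch.cat([o.permute(0, 2, 3, 1).contiguous().view(batch, -1)
                  for o in head_outputs], 1)
confidences = flat.view(batch, -1, num_classes)
num_priors = anchors_per_cell * (38 * 38 + 19 * 19 + 10 * 10)
assert confidences.shape == (batch, num_priors, num_classes)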
Code Example #5
def _create_val_datasets(args, cfg, logger):
    dslist = {}
    val_set_dict = {}

    train_transform = TrainAugmentation(cfg.INPUT.IMAGE_SIZE, cfg.INPUT.PIXEL_MEAN, cfg.INPUT.PIXEL_STD)
    target_transform = MatchPrior(PriorBox(cfg)(), cfg.MODEL.CENTER_VARIANCE, cfg.MODEL.SIZE_VARIANCE,
                                  cfg.MODEL.THRESHOLD)

    default_domain_dataset, default_domain_val_set = build_dataset(
        dataset_list=cfg.DATASETS.DG_SETTINGS.DEFAULT_DOMAIN,
        transform=train_transform, target_transform=target_transform, split=True)
    val_set_dict["Default domain"] = default_domain_val_set
    logger.info("Default domain: train split has {} elements, test split has {} elements".format(
        len(default_domain_dataset), len(default_domain_val_set)))

    dslist["Default domain"] = default_domain_dataset
    for element in cfg.DATASETS.DG_SETTINGS.SOURCE_DOMAINS:
        if not isinstance(element, tuple):
            sets = (element,)
        else:
            sets = element

        if args.eval_mode == "val":
            ds, val_set = build_dataset(dataset_list=sets, transform=train_transform,
                                        target_transform=target_transform, split=True)
            val_set_dict[element] = val_set
            logger.info("Domain {}: train split has {} elements, test split has {} elements".format(
                str(element), len(ds), len(val_set)))
        else:
            ds = build_dataset(dataset_list=sets, transform=train_transform,
                               target_transform=target_transform)

        dslist[element] = ds

    return val_set_dict
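A hypothetical consumer of the returned dict could evaluate each validation split separately, as in the sketch below; evaluate_dataset is a placeholder name, not a function from this project.

def log_val_performance(model, args, cfg, logger):
    # Evaluate every per-domain validation split and log its score.
    val_set_dict = _create_val_datasets(args, cfg, logger)
    for domain, val_set in val_set_dict.items():
        ap = evaluate_dataset(model, val_set)  # placeholder evaluation routine
        logger.info("{}: mAP {:.3f}".format(domain, ap))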
Code Example #6
File: ssd.py Project: Anonymous4604/Self-ADE_SSD
    def forward(self, x, targets=None, auxiliary_task=False):

        ss_criterion = nn.CrossEntropyLoss()
        sources = []
        confidences = []
        locations = []
        for i in range(23):
            x = self.vgg[i](x)
        s = self.l2_norm(x)  # Conv4_3 L2 normalization
        sources.append(s)

        # apply vgg up to fc7
        for i in range(23, len(self.vgg)):
            x = self.vgg[i](x)
        sources.append(x)

        for k, v in enumerate(self.extras):
            x = F.relu(v(x), inplace=True)
            if k % 2 == 1:
                sources.append(x)

        # if the auxiliary task is the rotation task we can apply it after the last layer
        if auxiliary_task and self.cfg.MODEL.SELF_SUPERVISOR.TYPE == "rotation":
            jx = x.view(x.size(0), -1)
            j_output = self.ss_classifier(self.ss_dropout(jx))

        if not auxiliary_task:
            for (x, l, c) in zip(sources, self.regression_headers,
                                 self.classification_headers):
                locations.append(l(x).permute(0, 2, 3, 1).contiguous())
                confidences.append(c(x).permute(0, 2, 3, 1).contiguous())

            confidences = torch.cat(
                [o.view(o.size(0), -1) for o in confidences], 1)
            locations = torch.cat([o.view(o.size(0), -1) for o in locations],
                                  1)

            confidences = confidences.view(confidences.size(0), -1,
                                           self.num_classes)
            locations = locations.view(locations.size(0), -1, 4)

        if not self.training:
            # when evaluating, decode predictions
            if self.priors is None:
                self.priors = PriorBox(self.cfg)().to(locations.device)
            confidences = F.softmax(confidences, dim=2)
            boxes = box_utils.convert_locations_to_boxes(
                locations, self.priors, self.cfg.MODEL.CENTER_VARIANCE,
                self.cfg.MODEL.SIZE_VARIANCE)
            boxes = box_utils.center_form_to_corner_form(boxes)
            return confidences, boxes
        else:
            # when training, compute losses
            if auxiliary_task:
                j_index = targets
                j_loss = ss_criterion(j_output, j_index)
                loss_dict = dict(aux_loss=j_loss)
            else:
                gt_boxes, gt_labels = targets
                regression_loss, classification_loss = self.criterion(
                    confidences, locations, gt_labels, gt_boxes)
                loss_dict = dict(
                    regression_loss=regression_loss,
                    classification_loss=classification_loss,
                )
            return loss_dict
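Assuming a model built with the rotation self-supervisor enabled, the auxiliary branch above would be driven roughly as below; the batch size, image size, and the four rotation classes (0/90/180/270 degrees) are assumptions for illustration.

import torch

model = build_ssd_model(cfg)  # config assumed to set SELF_SUPERVISOR.TYPE = "rotation"
model.train()
images = torch.randn(8, 3, 300, 300)    # rotated crops (input size assumed)
rot_labels = torch.randint(0, 4, (8,))  # rotation class index per image
loss_dict = model(images, targets=rot_labels, auxiliary_task=True)
loss_dict["aux_loss"].backward()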
Code Example #7
    def forward(self, x, targets, score_map=None):
        downsample_feature_map = []
        sources = []
        confidences = []
        locations = []
        for i in range(23):
            x = self.vgg[i](x)
            if i in self.downsample_layers_index:
                downsample_feature_map.append(x)
        s = self.l2_norm(x)  # Conv4_3 L2 normalization
        sources.append(s)
        # apply vgg up to fc7
        for i in range(23, len(self.vgg)):
            x = self.vgg[i](x)
            if i in self.downsample_layers_index:
                downsample_feature_map.append(x)
        sources.append(x)  # Conv_7

        # FCN part
        h = downsample_feature_map[2]  # bs x 2048 x w/32 x h/32 (the deepest downsampled map)
        g = self.unpool1(h)  # bs x 2048 x w/16 x h/16
        g = self.unpool1_conv2d(g)
        c = self.conv1(g.add_(downsample_feature_map[1]))
        c = self.bn1(c)
        c = self.relu1(c)

        g = self.unpool2(c)  # bs 128 w/8 h/8
        g = self.unpool2_conv2d(g)
        c = self.conv2(g.add_(downsample_feature_map[0]))
        c = self.bn2(c)
        c = self.relu2(c)
        F_score = self.conv3(c)  # bs 1 w/4 h/4
        F_score = self.sigmoid(F_score)
        F_score = torch.squeeze(F_score)

        for k, v in enumerate(self.extras):
            x = F.relu(v(x), inplace=True)
            if k % 2 == 1:
                sources.append(x)  # Conv_8_2, Conv_9_2, Conv_10_2, Conv_11_2

        for (x, l, c) in zip(sources, self.regression_headers, self.classification_headers):
            # the raw feature maps are NCHW; after permute they are NHWC
            a = l(x).permute(0, 2, 3, 1).contiguous()
            locations.append(a)
            b = c(x).permute(0, 2, 3, 1).contiguous()
            confidences.append(b)

        confidences = torch.cat([o.view(o.size(0), -1) for o in confidences], 1)
        locations = torch.cat([o.view(o.size(0), -1) for o in locations], 1)
        confidences = confidences.view(confidences.size(0), -1, self.num_classes)  # [batch_size, 24564, 2]
        locations = locations.view(locations.size(0), -1, 8)  # [batch_size, 24564, 8]: quads, four (x, y) corner points

        if not self.training:
            # when evaluating, decode predictions
            if self.priors is None:
                self.priors = PriorBox(self.cfg)()
            confidences = F.softmax(confidences, dim=2)
            quad = box_utils.convert_locations_to_boxes(
                locations, self.priors, self.cfg.MODEL.CENTER_VARIANCE, self.cfg.MODEL.SIZE_VARIANCE
            )
            score_map = F_score.cpu()
            return confidences, quad, score_map
        else:
            # when training, compute losses
            gt_boxes, gt_labels = targets
            # regress the predicted confidences and locations against
            # default boxes that were pre-matched to the ground truth
            regression_loss, classification_loss = self.criterion(confidences, locations, gt_labels, gt_boxes)
            seg_loss = self.dice_coefficient(score_map, F_score)
            # seg_loss = self.balanced_cross_entropy(score_map, F_score)
            # seg_loss = self.balanced_cross_entropy_1(score_map, F_score)
            loss_dict = dict(
                regression_loss=regression_loss,
                classification_loss=classification_loss,
                fcn_loss=seg_loss
            )
            return loss_dict
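For reference, a common dice-style segmentation loss of the kind self.dice_coefficient computes is sketched below; the project's actual smoothing constant and reduction may differ.

import torch

def dice_loss(gt_score, pred_score, eps=1e-5):
    # 1 - dice coefficient between ground-truth and predicted score maps.
    inter = torch.sum(gt_score * pred_score)
    union = torch.sum(gt_score) + torch.sum(pred_score) + eps
    return 1.0 - (2.0 * inter / union)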
Code Example #8
def train(cfg, args):
    logger = logging.getLogger('SSD.trainer')
    # -----------------------------------------------------------------------------
    # Model
    # -----------------------------------------------------------------------------
    model = build_ssd_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    if args.resume:
        logger.info("Resume from the model {}".format(args.resume))
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        iteration = checkpoint['iteration']
        logger.info("Resuming at iteration {}".format(iteration))
    elif args.vgg:
        iteration = 0
        logger.info("Init from backbone net {}".format(args.vgg))
        model.init_from_base_net(args.vgg)
    else:
        iteration = 0
        logger.info("all init from kaiming init")
    # -----------------------------------------------------------------------------
    # Optimizer
    # -----------------------------------------------------------------------------
    lr = cfg.SOLVER.LR * args.num_gpus  # scale by num gpus
    #optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    logger.info("Weight decay: {}".format(cfg.SOLVER.WEIGHT_DECAY))
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=cfg.SOLVER.WEIGHT_DECAY)

    # -----------------------------------------------------------------------------
    # Scheduler
    # -----------------------------------------------------------------------------
    milestones = [step // args.num_gpus for step in cfg.SOLVER.LR_STEPS]
    scheduler = WarmupMultiStepLR(optimizer=optimizer,
                                  milestones=milestones,
                                  gamma=cfg.SOLVER.GAMMA,
                                  warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
                                  warmup_iters=cfg.SOLVER.WARMUP_ITERS)

    # -----------------------------------------------------------------------------
    # Dataset
    # -----------------------------------------------------------------------------
    # data augmentation on the raw images
    train_transform = TrainAugmentation(cfg.INPUT.IMAGE_SIZE,
                                        cfg.INPUT.PIXEL_MEAN)
    target_transform = MatchPrior(
        PriorBox(cfg)(), cfg.MODEL.CENTER_VARIANCE, cfg.MODEL.SIZE_VARIANCE,
        cfg.MODEL.IOU_THRESHOLD, cfg.MODEL.PRIORS.DISTANCE_THRESHOLD)
    train_dataset = build_dataset(dataset_list=cfg.DATASETS.TRAIN,
                                  transform=train_transform,
                                  target_transform=target_transform,
                                  args=args)
    logger.info("Train dataset size: {}".format(len(train_dataset)))
    sampler = torch.utils.data.RandomSampler(train_dataset)
    # sampler = torch.utils.data.SequentialSampler(train_dataset)
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler=sampler, batch_size=cfg.SOLVER.BATCH_SIZE, drop_last=False)
    batch_sampler = samplers.IterationBasedBatchSampler(
        batch_sampler, num_iterations=cfg.SOLVER.MAX_ITER // args.num_gpus)
    train_loader = DataLoader(train_dataset,
                              num_workers=4,
                              batch_sampler=batch_sampler,
                              pin_memory=True)

    return do_train(cfg, model, train_loader, optimizer, scheduler, device,
                    args, iteration)
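The resume branch above reads the 'state_dict' and 'iteration' keys, so a matching checkpoint writer, presumably living inside do_train, could be as simple as the sketch below (the real one may also save optimizer and scheduler state):

import torch

def save_checkpoint(model, iteration, path):
    # Keys match what the resume branch above expects to find.
    torch.save({"state_dict": model.state_dict(),
                "iteration": iteration}, path)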
Code Example #9
def train(cfg, args):
    logger = logging.getLogger('SSD.trainer')
    # -----------------------------------------------------------------------------
    # Model
    # -----------------------------------------------------------------------------
    model = build_ssd_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    # -----------------------------------------------------------------------------
    # Optimizer
    # -----------------------------------------------------------------------------
    lr = cfg.SOLVER.LR * args.num_gpus  # scale by num gpus
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=cfg.SOLVER.MOMENTUM,
                                weight_decay=cfg.SOLVER.WEIGHT_DECAY)

    # -----------------------------------------------------------------------------
    # Scheduler
    # -----------------------------------------------------------------------------
    milestones = [step // args.num_gpus for step in cfg.SOLVER.LR_STEPS]
    scheduler = WarmupMultiStepLR(optimizer=optimizer,
                                  milestones=milestones,
                                  gamma=cfg.SOLVER.GAMMA,
                                  warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
                                  warmup_iters=cfg.SOLVER.WARMUP_ITERS)

    # -----------------------------------------------------------------------------
    # Load weights or restore checkpoint
    # -----------------------------------------------------------------------------
    if args.resume:
        logger.info("Resume from the model {}".format(args.resume))
        restore_training_checkpoint(logger,
                                    model,
                                    args.resume,
                                    optimizer=optimizer,
                                    scheduler=scheduler)
    else:
        logger.info("Init from base net {}".format(args.vgg))
        model.init_from_base_net(args.vgg)

    # Initialize mixed-precision training
    use_mixed_precision = cfg.USE_AMP
    amp_opt_level = 'O1' if use_mixed_precision else 'O0'
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level=amp_opt_level)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # -----------------------------------------------------------------------------
    # Dataset
    # -----------------------------------------------------------------------------
    train_transform = TrainAugmentation(cfg.INPUT.IMAGE_SIZE,
                                        cfg.INPUT.PIXEL_MEAN,
                                        cfg.INPUT.PIXEL_STD)
    target_transform = MatchPrior(
        PriorBox(cfg)(), cfg.MODEL.CENTER_VARIANCE, cfg.MODEL.SIZE_VARIANCE,
        cfg.MODEL.THRESHOLD)

    if cfg.DATASETS.DG:
        if args.eval_mode == "val":
            dslist, val_set_dict = _create_dg_datasets(args, cfg, logger,
                                                       target_transform,
                                                       train_transform)
        else:
            dslist = _create_dg_datasets(args, cfg, logger, target_transform,
                                         train_transform)

        logger.info("Sizes of sources datasets:")
        for k, v in dslist.items():
            logger.info("{} size: {}".format(k, len(v)))

        dataloaders = []
        for name, train_dataset in dslist.items():
            sampler = torch.utils.data.RandomSampler(train_dataset)
            batch_sampler = torch.utils.data.sampler.BatchSampler(
                sampler=sampler,
                batch_size=cfg.SOLVER.BATCH_SIZE,
                drop_last=True)

            batch_sampler = samplers.IterationBasedBatchSampler(
                batch_sampler, num_iterations=cfg.SOLVER.MAX_ITER)

            if cfg.MODEL.SELF_SUPERVISED:
                ss_dataset = SelfSupervisedDataset(train_dataset, cfg)
                train_loader = DataLoader(ss_dataset,
                                          num_workers=args.num_workers,
                                          batch_sampler=batch_sampler,
                                          pin_memory=True)
            else:
                train_loader = DataLoader(train_dataset,
                                          num_workers=args.num_workers,
                                          batch_sampler=batch_sampler,
                                          pin_memory=True)
            dataloaders.append(train_loader)

        if args.eval_mode == "val":
            if args.return_best:
                return do_train(cfg, model, dataloaders, optimizer, scheduler,
                                device, args, val_set_dict)
            else:
                return do_train(cfg, model, dataloaders, optimizer, scheduler,
                                device, args)
        else:
            return do_train(cfg, model, dataloaders, optimizer, scheduler,
                            device, args)

    # No DG:
    if args.eval_mode == "val":
        train_dataset, val_dataset = build_dataset(
            dataset_list=cfg.DATASETS.TRAIN,
            transform=train_transform,
            target_transform=target_transform,
            split=True)
    else:
        train_dataset = build_dataset(dataset_list=cfg.DATASETS.TRAIN,
                                      transform=train_transform,
                                      target_transform=target_transform)
    logger.info("Train dataset size: {}".format(len(train_dataset)))
    if args.distributed:
        sampler = torch.utils.data.DistributedSampler(train_dataset)
    else:
        sampler = torch.utils.data.RandomSampler(train_dataset)
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler=sampler, batch_size=cfg.SOLVER.BATCH_SIZE, drop_last=False)
    batch_sampler = samplers.IterationBasedBatchSampler(
        batch_sampler, num_iterations=cfg.SOLVER.MAX_ITER // args.num_gpus)

    if cfg.MODEL.SELF_SUPERVISED:
        ss_dataset = SelfSupervisedDataset(train_dataset, cfg)
        train_loader = DataLoader(ss_dataset,
                                  num_workers=args.num_workers,
                                  batch_sampler=batch_sampler,
                                  pin_memory=True)
    else:
        train_loader = DataLoader(train_dataset,
                                  num_workers=args.num_workers,
                                  batch_sampler=batch_sampler,
                                  pin_memory=True)

    if args.eval_mode == "val":
        return do_train(cfg, model, train_loader, optimizer, scheduler, device,
                        args, {"validation_split": val_dataset})
    else:
        return do_train(cfg, model, train_loader, optimizer, scheduler, device,
                        args)
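In the DG branch, do_train receives a list of per-domain loaders. One plausible way it could interleave them is a round-robin over zipped batches, sketched below; the batch structure (images, boxes, labels) and the update scheme are assumptions, not code from this project.

for step, batches in enumerate(zip(*dataloaders)):
    optimizer.zero_grad()
    for images, boxes, labels in batches:  # one batch per source domain
        loss_dict = model(images.to(device),
                          targets=(boxes.to(device), labels.to(device)))
        sum(loss_dict.values()).backward()  # accumulate grads across domains
    optimizer.step()
    scheduler.step()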