def create_dataset_loader(input_folder):
    """Index the images under input_folder into a DOTA-style image set,
    then wrap them in a DataLoader for inference."""

    def create_dataset_from_folder(dir_dataset):
        pairs = []
        for filename in os.listdir(dir_dataset):
            if filename.endswith(('.jpg', '.png')):
                img = os.path.join(dir_dataset, filename)
                anno = None  # inference only: no annotation file to pair with
                pairs.append([img, anno])

        dataset_path = os.path.join(FLAGS.output_folder, 'image-sets',
                                    'dataset.json')
        os.makedirs(Path(dataset_path).parent, exist_ok=True)
        with open(dataset_path, 'wt') as f:
            json.dump(pairs, f, indent=2)

    create_dataset_from_folder(input_folder)
    torch.cuda.set_device(int(FLAGS.GPUS.split(',')[0]))

    aug = Compose([ops.PadSquare(), ops.Resize(FLAGS.image_size)])
    dataset = DOTA(FLAGS.output_folder, 'dataset', aug)

    loader = DataLoader(dataset,
                        FLAGS.batch_size,
                        num_workers=FLAGS.num_workers,
                        pin_memory=True,
                        collate_fn=dataset.collate)
    print(f'created dataset from {len(dataset)} files')
    return loader, len(dataset.names)
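

# Train RDD on DOTA: multi-step LR schedule with warm-up and rolling checkpoints.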
def main():
    dir_weight = os.path.join(dir_save, 'weight')
    dir_log = os.path.join(dir_save, 'log')
    os.makedirs(dir_weight, exist_ok=True)
    writer = SummaryWriter(dir_log)

    # resume from the most recent checkpoint saved in dir_weight, if any
    indexes = [
        int(os.path.splitext(path)[0]) for path in os.listdir(dir_weight)
    ]
    current_step = max(indexes) if indexes else 0

    image_size = 768
    lr = 1e-3
    batch_size = 12
    num_workers = 4

    max_step = 250000
    # piecewise-constant schedule: [milestone_step, lr] pairs consumed by
    # adjust_lr_multi_step; warm_up appears to be [num_steps, lr_from, lr_to]
    lr_cfg = [[100000, lr], [200000, lr / 10], [max_step, lr / 50]]
    warm_up = [1000, lr / 50, lr]
    save_interval = 1000

    aug = Compose([
        ops.ToFloat(),
        ops.PhotometricDistort(),
        ops.RandomHFlip(),
        ops.RandomVFlip(),
        ops.RandomRotate90(),
        ops.ResizeJitter([0.8, 1.2]),
        ops.PadSquare(),
        ops.Resize(image_size),
        ops.BBoxFilter(24 * 24 * 0.4)  # drop boxes smaller than 40% of a 24x24 patch
    ])
    dataset = DOTA(dir_dataset, ['train', 'val'], aug)
    loader = DataLoader(dataset,
                        batch_size,
                        shuffle=True,
                        num_workers=num_workers,
                        pin_memory=True,
                        drop_last=True,
                        collate_fn=dataset.collate)
    num_classes = len(dataset.names)

    # anchor configuration shared by the five detection levels (strides 8-128)
    prior_box = {
        'strides': [8, 16, 32, 64, 128],
        'sizes': [3] * 5,
        'aspects': [[1, 2, 4, 8]] * 5,
        'scales': [[2**0, 2**(1 / 3), 2**(2 / 3)]] * 5,
    }

    cfg = {
        'prior_box': prior_box,
        'num_classes': num_classes,
        'extra': 2,
    }

    model = RDD(backbone(fetch_feature=True), cfg)
    model.build_pipe(shape=[2, 3, image_size, image_size])
    if current_step:
        model.restore(os.path.join(dir_weight, '%d.pth' % current_step))
    else:
        model.init()
    if len(device_ids) > 1:
        model = convert_model(model)
        model = CustomDetDataParallel(model, device_ids)
    model.cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=5e-4)
    training = True
    while training and current_step < max_step:
        tqdm_loader = tqdm.tqdm(loader)
        for images, targets, infos in tqdm_loader:
            current_step += 1
            adjust_lr_multi_step(optimizer, current_step, lr_cfg, warm_up)

            images = images.cuda() / 255  # scale uint8 pixels to [0, 1]
            losses = model(images, targets)
            loss = sum(losses.values())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            for key, val in list(losses.items()):
                losses[key] = val.item()  # plain floats for logging and postfix
                writer.add_scalar(key, losses[key], global_step=current_step)
            writer.flush()
            tqdm_loader.set_postfix(losses)
            tqdm_loader.set_description(f'<{current_step}/{max_step}>')

            if current_step % save_interval == 0:
                save_path = os.path.join(dir_weight, '%d.pth' % current_step)
                state_dict = model.state_dict() if len(
                    device_ids) == 1 else model.module.state_dict()
                torch.save(state_dict, save_path)
                # keep only the latest checkpoint: remove the previous one
                cache_file = os.path.join(
                    dir_weight, '%d.pth' % (current_step - save_interval))
                if os.path.exists(cache_file):
                    os.remove(cache_file)

            if current_step >= max_step:
                training = False
                writer.close()
                break
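

# Evaluate on DOTA: run inference on image patches, merge them back into full
# images with rotated NMS, and write Task1 submission files.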
def main():
    global checkpoint
    if checkpoint is None:
        # fall back to the most recent checkpoint under dir_save/weight
        dir_weight = os.path.join(dir_save, 'weight')
        indexes = [
            int(os.path.splitext(path)[0]) for path in os.listdir(dir_weight)
        ]
        current_step = max(indexes)
        checkpoint = os.path.join(dir_weight, '%d.pth' % current_step)

    batch_size = 32
    num_workers = 4

    image_size = 768
    aug = Compose([ops.PadSquare(), ops.Resize(image_size)])
    dataset = DOTA(dir_dataset, image_set, aug)
    loader = DataLoader(dataset,
                        batch_size,
                        num_workers=num_workers,
                        pin_memory=True,
                        collate_fn=dataset.collate)
    num_classes = len(dataset.names)

    prior_box = {
        'strides': [8, 16, 32, 64, 128],
        'sizes': [3] * 5,
        'aspects': [[1, 2, 4, 8]] * 5,
        'scales': [[2**0, 2**(1 / 3), 2**(2 / 3)]] * 5,
        'old_version': old_version
    }
    # a low conf_thresh keeps recall high for scoring; rotated NMS prunes overlaps
    conf_thresh = 0.01
    nms_thresh = 0.45
    cfg = {
        'prior_box': prior_box,
        'num_classes': num_classes,
        'extra': 2,
        'conf_thresh': conf_thresh,
        'nms_thresh': nms_thresh,
    }

    model = RDD(backbone(fetch_feature=True), cfg)
    model.build_pipe(shape=[2, 3, image_size, image_size])
    model.restore(checkpoint)
    if len(device_ids) > 1:
        model = CustomDetDataParallel(model, device_ids)
    model.cuda()
    model.eval()

    ret_raw = defaultdict(list)
    for images, targets, infos in tqdm.tqdm(loader):
        images = images.cuda() / 255
        dets = model(images)
        for (det, info) in zip(dets, infos):
            if det:
                bboxes, scores, labels = det
                bboxes = bboxes.cpu().numpy()
                scores = scores.cpu().numpy()
                labels = labels.cpu().numpy()
                # patch filenames encode the crop as "<fname>-<x>-<y>-<w>-<h>"
                fname, x, y, w, h = os.path.splitext(
                    os.path.basename(info['img_path']))[0].split('-')[:5]
                x, y, w, h = int(x), int(y), int(w), int(h)
                long_edge = max(w, h)
                pad_x, pad_y = (long_edge - w) // 2, (long_edge - h) // 2
                # undo Resize and PadSquare, then shift into full-image coordinates
                bboxes = np.stack([xywha2xy4(bbox) for bbox in bboxes])
                bboxes *= long_edge / image_size
                bboxes -= [pad_x, pad_y]
                bboxes += [x, y]
                bboxes = np.stack([xy42xywha(bbox) for bbox in bboxes])
                ret_raw[fname].append([bboxes, scores, labels])

    print('merging results...')
    ret = []

    # rotated NMS across all patches belonging to the same full image
    for fname, dets in ret_raw.items():
        bboxes, scores, labels = zip(*dets)
        bboxes = np.concatenate(list(bboxes))
        scores = np.concatenate(list(scores))
        labels = np.concatenate(list(labels))
        keeps = rbbox_batched_nms(bboxes, scores, labels, nms_thresh)
        ret.append([fname, [bboxes[keeps], scores[keeps], labels[keeps]]])

    print('converting to submission format...')
    ret_save = defaultdict(list)
    for fname, (bboxes, scores, labels) in ret:
        for bbox, score, label in zip(bboxes, scores, labels):
            bbox = xywha2xy4(bbox).ravel()
            # DOTA Task1 line format: "<image> <score> x1 y1 x2 y2 x3 y3 x4 y4"
            line = '%s %.12f %.1f %.1f %.1f %.1f %.1f %.1f %.1f %.1f' % (
                fname, score, *bbox)
            ret_save[dataset.label2name[label]].append(line)

    print('saving...')
    os.makedirs(os.path.join(dir_save, 'submission'), exist_ok=True)
    for name, dets in ret_save.items():
        with open(
                os.path.join(dir_save, 'submission',
                             'Task1_%s.txt' % name), 'wt') as f:
            f.write('\n'.join(dets))

    print('finished')
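

# Evaluate on HRSC2016: compare detections against ground truth and report
# per-class AP and mAP.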
def main():
    global checkpoint
    if checkpoint is None:
        dir_weight = os.path.join(dir_save, 'weight')
        indexes = [int(os.path.splitext(path)[0]) for path in os.listdir(dir_weight)]
        current_step = max(indexes)
        checkpoint = os.path.join(dir_weight, '%d.pth' % current_step)

    image_size = 768
    batch_size = 32
    num_workers = 4

    aug = ops.Resize(image_size)
    dataset = HRSC2016(dir_dataset, 'test', aug)
    loader = DataLoader(dataset, batch_size, num_workers=num_workers, pin_memory=True, collate_fn=dataset.collate)
    num_classes = len(dataset.names)

    prior_box = {
        'strides': [8, 16, 32, 64, 128],
        'sizes': [3] * 5,
        'aspects': [[1.5, 3, 5, 8]] * 5,  # more elongated than the DOTA config, presumably for ship shapes
        'scales': [[2 ** 0, 2 ** (1 / 3), 2 ** (2 / 3)]] * 5,
        'old_version': old_version
    }
    conf_thresh = 0.01
    nms_thresh = 0.45
    cfg = {
        'prior_box': prior_box,
        'num_classes': num_classes,
        'extra': 2,
        'conf_thresh': conf_thresh,
        'nms_thresh': nms_thresh,
    }

    model = RDD(backbone(fetch_feature=True), cfg)
    model.build_pipe(shape=[2, 3, image_size, image_size])
    model.restore(checkpoint)
    if len(device_ids) > 1:
        model = CustomDetDataParallel(model, device_ids)
    model.cuda()
    model.eval()

    count = 0
    gt_list, det_list = [], []
    for images, targets, infos in tqdm.tqdm(loader):
        images = images.cuda() / 255
        dets = model(images)
        for target, det, info in zip(targets, dets, infos):
            if target:
                bboxes = np.stack([xy42xywha(bbox) for bbox in info['objs']['bboxes']])
                labels = info['objs']['labels']
                # same tuple layout as det_list, with a constant 1 in the score slot
                gt_list.extend([count, bbox, 1, label] for bbox, label in zip(bboxes, labels))
            if det:
                ih, iw = info['shape'][:2]
                bboxes, scores, labels = [x.cpu().numpy() for x in det]
                bboxes = np.stack([xywha2xy4(bbox) for bbox in bboxes])
                bboxes_ = bboxes * [iw / image_size, ih / image_size]
                # anisotropic rescaling can leave the corners slightly off a true
                # rotated rectangle, so refit each box with cv.minAreaRect instead
                # of converting directly via xy42xywha
                bboxes = []
                for bbox in bboxes_.astype(np.float32):
                    (x, y), (w, h), a = cv.minAreaRect(bbox)
                    bboxes.append([x, y, w, h, a])
                bboxes = np.array(bboxes)
                det_list.extend([count, bbox, score, label] for bbox, score, label in zip(bboxes, scores, labels))
            count += 1
    APs = get_det_aps(det_list, gt_list, num_classes, use_07_metric=use_07_metric)
    mAP = sum(APs) / len(APs)
    print('AP')
    for label in range(num_classes):
        print(f'{dataset.label2name[label]}: {APs[label]}')
    print(f'mAP: {mAP}')
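

# Distributed training on HRSC2016: one process per GPU with
# DistributedDataParallel and SyncBatchNorm.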
def main(batch_size, rank, world_size):
    # imports are kept local, presumably so this function can serve directly
    # as a torch.multiprocessing.spawn worker entry point
    import os
    import tqdm
    import torch
    import tempfile

    from torch import optim
    from torch import distributed as dist
    from torch.nn import SyncBatchNorm
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter

    from data.aug.compose import Compose
    from data.aug import ops
    from data.dataset import HRSC2016

    from model.rdd import RDD
    from model.backbone import resnet

    from utils.adjust_lr import adjust_lr_multi_step

    torch.manual_seed(0)
    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl",
                            init_method='env://',
                            rank=rank,
                            world_size=world_size)
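    # init_method='env://' reads MASTER_ADDR and MASTER_PORT from the environment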

    backbone = resnet.resnet101

    dir_dataset = '<replace with your local path>'
    dir_save = '<replace with your local path>'

    dir_weight = os.path.join(dir_save, 'weight')
    dir_log = os.path.join(dir_save, 'log')
    os.makedirs(dir_weight, exist_ok=True)
    if rank == 0:
        writer = SummaryWriter(dir_log)

    indexes = [
        int(os.path.splitext(path)[0]) for path in os.listdir(dir_weight)
    ]
    current_step = max(indexes) if indexes else 0

    image_size = 768
    lr = 1e-3
    batch_size //= world_size  # split the global batch size evenly across ranks
    num_workers = 4

    max_step = 12000
    lr_cfg = [[7500, lr], [max_step, lr / 10]]
    warm_up = [500, lr / 50, lr]
    save_interval = 1000

    aug = Compose([
        ops.ToFloat(),
        ops.PhotometricDistort(),
        ops.RandomHFlip(),
        ops.RandomVFlip(),
        ops.RandomRotate90(),
        ops.ResizeJitter([0.8, 1.2]),
        ops.PadSquare(),
        ops.Resize(image_size),
    ])
    dataset = HRSC2016(dir_dataset, ['trainval'], aug)
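    # DistributedSampler shards the data across ranks; BatchSampler then groups
    # each rank's shard into per-rank batches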
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, world_size, rank)
    batch_sampler = torch.utils.data.BatchSampler(train_sampler,
                                                  batch_size,
                                                  drop_last=True)
    loader = DataLoader(dataset,
                        batch_sampler=batch_sampler,
                        num_workers=num_workers,
                        collate_fn=dataset.collate)
    num_classes = len(dataset.names)

    prior_box = {
        'strides': [8, 16, 32, 64, 128],
        'sizes': [3] * 5,
        'aspects': [[1.5, 3, 5, 8]] * 5,
        'scales': [[2**0, 2**(1 / 3), 2**(2 / 3)]] * 5,
    }

    cfg = {
        'prior_box': prior_box,
        'num_classes': num_classes,
        'extra': 2,
    }
    device = torch.device(f'cuda:{rank}')
    model = RDD(backbone(fetch_feature=True), cfg)
    model.build_pipe(shape=[2, 3, image_size, image_size])
    model = SyncBatchNorm.convert_sync_batchnorm(model)
    model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
    if current_step:
        model.module.load_state_dict(
            torch.load(os.path.join(dir_weight, '%d.pth' % current_step),
                       map_location=device))
    else:
        # rank 0 initializes the weights and publishes them through a temp file
        # so that every rank starts from identical parameters
        checkpoint = os.path.join(tempfile.gettempdir(), "initial-weights.pth")
        if rank == 0:
            model.module.init()
            torch.save(model.module.state_dict(), checkpoint)
        dist.barrier()
        if rank > 0:
            model.module.load_state_dict(
                torch.load(checkpoint, map_location=device))
        dist.barrier()
        if rank == 0:
            os.remove(checkpoint)

    optimizer = optim.SGD(model.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=5e-4)
    training = True
    epoch = 0
    while training and current_step < max_step:
        train_sampler.set_epoch(epoch)  # reshuffle the shards differently each epoch
        epoch += 1
        tqdm_loader = tqdm.tqdm(loader) if rank == 0 else loader
        for images, targets, infos in tqdm_loader:
            current_step += 1
            adjust_lr_multi_step(optimizer, current_step, lr_cfg, warm_up)

            images = images.cuda() / 255
            losses = model(images, targets)
            loss = sum(losses.values())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if rank == 0:
                for key, val in list(losses.items()):
                    losses[key] = val.item()
                    writer.add_scalar(key, losses[key], global_step=current_step)
                writer.flush()
                tqdm_loader.set_postfix(losses)
                tqdm_loader.set_description(f'<{current_step}/{max_step}>')

                if current_step % save_interval == 0:
                    save_path = os.path.join(dir_weight,
                                             '%d.pth' % current_step)
                    state_dict = model.module.state_dict()
                    torch.save(state_dict, save_path)
                    cache_file = os.path.join(
                        dir_weight, '%d.pth' % (current_step - save_interval))
                    if os.path.exists(cache_file):
                        os.remove(cache_file)

            if current_step >= max_step:
                training = False
                if rank == 0:
                    writer.close()
                break
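

# A minimal launcher sketch (an assumption, not part of the original script):
# mp.spawn passes the rank as the first positional argument, so batch_size is
# bound with functools.partial and the distributed main above is invoked as
# main(batch_size, rank, world_size), rendezvousing over env:// on localhost.
# The batch_size of 32 is illustrative.
if __name__ == '__main__':
    import os
    from functools import partial

    import torch
    import torch.multiprocessing as mp

    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    world_size = torch.cuda.device_count()
    mp.spawn(partial(main, 32), args=(world_size,), nprocs=world_size)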