def main():
    dir_weight = os.path.join(dir_save, 'weight')
    dir_log = os.path.join(dir_save, 'log')
    os.makedirs(dir_weight, exist_ok=True)
    writer = SummaryWriter(dir_log)

    # Resume from the most recent checkpoint: weights are saved as '<step>.pth',
    # so the largest filename stem is the last completed step.
    indexes = [
        int(os.path.splitext(path)[0]) for path in os.listdir(dir_weight)
    ]
    current_step = max(indexes) if indexes else 0

    image_size = 768
    lr = 1e-3
    batch_size = 12
    num_workers = 4

    max_step = 250000
    # Piecewise-constant schedule: lr up to step 100k, lr/10 up to 200k and
    # lr/50 up to 250k, preceded by a 1000-step warm-up from lr/50 to lr.
    lr_cfg = [[100000, lr], [200000, lr / 10], [max_step, lr / 50]]
    warm_up = [1000, lr / 50, lr]
    save_interval = 1000

    # Training-time augmentation; BBoxFilter drops boxes whose area falls
    # below the 24 * 24 * 0.4 threshold after resizing.
    aug = Compose([
        ops.ToFloat(),
        ops.PhotometricDistort(),
        ops.RandomHFlip(),
        ops.RandomVFlip(),
        ops.RandomRotate90(),
        ops.ResizeJitter([0.8, 1.2]),
        ops.PadSquare(),
        ops.Resize(image_size),
        ops.BBoxFilter(24 * 24 * 0.4)
    ])
    dataset = DOTA(dir_dataset, ['train', 'val'], aug)
    loader = DataLoader(dataset,
                        batch_size,
                        shuffle=True,
                        num_workers=num_workers,
                        pin_memory=True,
                        drop_last=True,
                        collate_fn=dataset.collate)
    num_classes = len(dataset.names)

    # Prior-box (anchor) configuration for the five detection levels
    # (strides 8 to 128).
    prior_box = {
        'strides': [8, 16, 32, 64, 128],
        'sizes': [3] * 5,
        'aspects': [[1, 2, 4, 8]] * 5,
        'scales': [[2**0, 2**(1 / 3), 2**(2 / 3)]] * 5,
    }

    cfg = {
        'prior_box': prior_box,
        'num_classes': num_classes,
        'extra': 2,
    }

    model = RDD(backbone(fetch_feature=True), cfg)
    model.build_pipe(shape=[2, 3, image_size, image_size])
    # Restore the latest checkpoint if one exists, otherwise initialize weights.
    if current_step:
        model.restore(os.path.join(dir_weight, '%d.pth' % current_step))
    else:
        model.init()
    # Multi-GPU: convert BatchNorm for synchronized data-parallel training and
    # wrap the model in the detector's DataParallel wrapper.
    if len(device_ids) > 1:
        model = convert_model(model)
        model = CustomDetDataParallel(model, device_ids)
    model.cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=5e-4)
    training = True
    while training and current_step < max_step:
        tqdm_loader = tqdm.tqdm(loader)
        for images, targets, infos in tqdm_loader:
            current_step += 1
            adjust_lr_multi_step(optimizer, current_step, lr_cfg, warm_up)

            images = images.cuda() / 255  # scale pixel values to [0, 1]
            losses = model(images, targets)
            loss = sum(losses.values())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Convert each loss term to a float for display and log it.
            for key, val in list(losses.items()):
                losses[key] = val.item()
                writer.add_scalar(key, val, global_step=current_step)
            writer.flush()
            tqdm_loader.set_postfix(losses)
            tqdm_loader.set_description(f'<{current_step}/{max_step}>')

            if current_step % save_interval == 0:
                # Save the current weights and delete the previous checkpoint,
                # so only the most recent one is kept on disk.
                save_path = os.path.join(dir_weight, '%d.pth' % current_step)
                state_dict = (model.state_dict() if len(device_ids) == 1
                              else model.module.state_dict())
                torch.save(state_dict, save_path)
                cache_file = os.path.join(
                    dir_weight, '%d.pth' % (current_step - save_interval))
                if os.path.exists(cache_file):
                    os.remove(cache_file)

            if current_step >= max_step:
                training = False
                writer.close()
                break
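

# Distributed variant of the training script: the same RDD pipeline, trained
# on HRSC2016 with one process per GPU (NCCL backend), SyncBatchNorm and
# DistributedDataParallel. `rank` is this process's index, `world_size` the
# total number of processes, and `batch_size` the global batch size, which is
# split evenly across processes below.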
def main(batch_size, rank, world_size):

    import os
    import tqdm
    import torch
    import tempfile

    from torch import optim
    from torch import distributed as dist
    from torch.nn import SyncBatchNorm
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter

    from data.aug.compose import Compose
    from data.aug import ops
    from data.dataset import HRSC2016

    from model.rdd import RDD
    from model.backbone import resnet

    from utils.adjust_lr import adjust_lr_multi_step

    torch.manual_seed(0)
    torch.backends.cudnn.benchmark = True
    # One process per GPU: bind this process to its GPU and join the NCCL
    # process group (MASTER_ADDR/MASTER_PORT are read from the environment).
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl",
                            init_method='env://',
                            rank=rank,
                            world_size=world_size)

    backbone = resnet.resnet101

    dir_dataset = '<replace with your local path>'
    dir_save = '<replace with your local path>'

    dir_weight = os.path.join(dir_save, 'weight')
    dir_log = os.path.join(dir_save, 'log')
    os.makedirs(dir_weight, exist_ok=True)
    if rank == 0:  # log to TensorBoard only from the main process
        writer = SummaryWriter(dir_log)

    # Resume from the most recent checkpoint saved as '<step>.pth'.
    indexes = [
        int(os.path.splitext(path)[0]) for path in os.listdir(dir_weight)
    ]
    current_step = max(indexes) if indexes else 0

    image_size = 768
    lr = 1e-3
    # `batch_size` is the global batch size; each process takes an equal share.
    batch_size //= world_size
    num_workers = 4

    max_step = 12000
    # lr up to step 7500, then lr/10 up to 12000, preceded by a 500-step
    # warm-up from lr/50 to lr.
    lr_cfg = [[7500, lr], [max_step, lr / 10]]
    warm_up = [500, lr / 50, lr]
    save_interval = 1000

    aug = Compose([
        ops.ToFloat(),
        ops.PhotometricDistort(),
        ops.RandomHFlip(),
        ops.RandomVFlip(),
        ops.RandomRotate90(),
        ops.ResizeJitter([0.8, 1.2]),
        ops.PadSquare(),
        ops.Resize(image_size),
    ])
    dataset = HRSC2016(dir_dataset, ['trainval'], aug)
    # DistributedSampler shards the dataset across processes; BatchSampler
    # then groups each shard into per-process batches.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, world_size, rank)
    batch_sampler = torch.utils.data.BatchSampler(train_sampler,
                                                  batch_size,
                                                  drop_last=True)
    loader = DataLoader(dataset,
                        batch_sampler=batch_sampler,
                        num_workers=num_workers,
                        collate_fn=dataset.collate)
    num_classes = len(dataset.names)

    prior_box = {
        'strides': [8, 16, 32, 64, 128],
        'sizes': [3] * 5,
        'aspects': [[1.5, 3, 5, 8]] * 5,
        'scales': [[2**0, 2**(1 / 3), 2**(2 / 3)]] * 5,
    }

    cfg = {
        'prior_box': prior_box,
        'num_classes': num_classes,
        'extra': 2,
    }
    device = torch.device(f'cuda:{rank}')
    model = RDD(backbone(fetch_feature=True), cfg)
    model.build_pipe(shape=[2, 3, image_size, image_size])
    # Synchronize BatchNorm statistics across processes, then wrap in DDP.
    model = SyncBatchNorm.convert_sync_batchnorm(model)
    model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
    if current_step:
        model.module.load_state_dict(
            torch.load(os.path.join(dir_weight, '%d.pth' % current_step),
                       map_location=device))
    else:
        # Make every process start from identical weights: rank 0 initializes
        # and saves them to a temporary file, the other ranks load that file.
        checkpoint = os.path.join(tempfile.gettempdir(), "initial-weights.pth")
        if rank == 0:
            model.module.init()
            torch.save(model.module.state_dict(), checkpoint)
        dist.barrier()
        if rank > 0:
            model.module.load_state_dict(
                torch.load(checkpoint, map_location=device))
        dist.barrier()
        if rank == 0:
            os.remove(checkpoint)

    optimizer = optim.SGD(model.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=5e-4)
    training = True
    while training and current_step < max_step:
        tqdm_loader = tqdm.tqdm(loader) if rank == 0 else loader
        for images, targets, infos in tqdm_loader:
            current_step += 1
            adjust_lr_multi_step(optimizer, current_step, lr_cfg, warm_up)

            images = images.cuda() / 255  # scale pixel values to [0, 1]
            losses = model(images, targets)
            loss = sum(losses.values())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Logging, the progress bar and checkpointing live on rank 0 only.
            if rank == 0:
                for key, val in list(losses.items()):
                    losses[key] = val.item()
                    writer.add_scalar(key, val, global_step=current_step)
                writer.flush()
                tqdm_loader.set_postfix(losses)
                tqdm_loader.set_description(f'<{current_step}/{max_step}>')

                if current_step % save_interval == 0:
                    save_path = os.path.join(dir_weight,
                                             '%d.pth' % current_step)
                    state_dict = model.module.state_dict()
                    torch.save(state_dict, save_path)
                    cache_file = os.path.join(
                        dir_weight, '%d.pth' % (current_step - save_interval))
                    if os.path.exists(cache_file):
                        os.remove(cache_file)

            if current_step >= max_step:
                training = False
                if rank == 0:
                    writer.close()
                break
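

# A minimal single-node launcher sketch for the distributed main() above (an
# assumption, not part of the original script). torch.multiprocessing.spawn
# passes the process index as the first argument, so a small wrapper reorders
# the arguments to match main's signature. MASTER_ADDR/MASTER_PORT are needed
# by init_method='env://'; the port and the global batch size of 12 are
# illustrative values only.
def _spawn_worker(rank, batch_size, world_size):
    main(batch_size, rank, world_size)


if __name__ == '__main__':
    import os
    import torch
    import torch.multiprocessing as mp

    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    world_size = torch.cuda.device_count()
    mp.spawn(_spawn_worker, args=(12, world_size), nprocs=world_size)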