def main():
    parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
    arg = parser.add_argument
    arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
    arg('--workers', type=int, default=6, help='number of cpu threads to use')
    arg('--gpu',
        type=str,
        default='0',
        help='List of GPUs for parallel training, e.g. 0,1,2,3')
    arg('--output-dir', type=str, default='weights/')
    arg('--resume', type=str, default='')
    arg('--fold', type=int, default=0)
    arg('--prefix', type=str, default='classifier_')
    arg('--data-dir', type=str, default="/mnt/sota/datasets/deepfake")
    arg('--folds-csv', type=str, default='folds.csv')
    arg('--crops-dir', type=str, default='crops')
    arg('--label-smoothing', type=float, default=0.01)
    arg('--logdir', type=str, default='logs')
    arg('--zero-score', action='store_true', default=False)
    arg('--from-zero', action='store_true', default=False)
    arg('--distributed', action='store_true', default=False)
    arg('--freeze-epochs', type=int, default=0)
    arg("--local_rank", default=0, type=int)
    arg("--seed", default=777, type=int)
    arg("--padding-part", default=3, type=int)
    arg("--opt-level", default='O1', type=str)
    arg("--test_every", type=int, default=1)
    arg("--no-oversample", action="store_true")
    arg("--no-hardcore", action="store_true")
    arg("--only-changed-frames", action="store_true")

    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        # Initializes the default distributed process group
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
    else:
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.benchmark = True

    conf = load_config(args.config)
    model = classifiers.__dict__[conf['network']](encoder=conf['encoder'])

    model = model.cuda()
    if args.distributed:
        # Recursively traverse the module and its children, replacing every
        # batch norm layer with a synchronized batch norm (SyncBN).
        # SyncBN is batch normalization for multi-GPU training: standard batch
        # norm normalizes only the data on each device (GPU), while SyncBN
        # normalizes over the whole mini-batch across all devices.
        model = convert_syncbn_model(model)
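        # convert_syncbn_model comes from NVIDIA Apex; on recent PyTorch
        # versions the built-in helper below is a rough drop-in alternative
        # (noted here as a suggestion, not what this pipeline ships with):
        #   model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)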
    # OHEM (online hard example mining) needs per-sample losses, so the
    # reduction is disabled when it is configured
    ohem = conf.get("ohem_samples", None)
    reduction = "mean"
    if ohem:
        reduction = "none"
    loss_fn = []
    weights = []
    for loss_name, weight in conf["losses"].items():
        loss_fn.append(losses.__dict__[loss_name](reduction=reduction).cuda())
        weights.append(weight)
    # combine the individual losses into a single weighted sum
    loss = WeightedLosses(loss_fn, weights)
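    # WeightedLosses presumably reduces the list to a single scalar, roughly:
    #   combined = sum(w * fn(outputs, targets) for fn, w in zip(loss_fn, weights))
    # (a sketch of the assumed behaviour; see the project's losses module for
    # the actual implementation)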
    loss_functions = {"classifier_loss": loss}
    # create the optimizer and LR scheduler from the configuration
    optimizer, scheduler = create_optimizer(conf['optimizer'], model)
    bce_best = 100
    start_epoch = 0
    batch_size = conf['optimizer']['batch_size']
    # the dataset augments the existing train and val data to get the most out
    # of the available images and to keep the network from simply memorizing
    # the frames it is fed
    data_train = DeepFakeClassifierDataset(
        mode="train",
        oversample_real=not args.no_oversample,
        fold=args.fold,
        padding_part=args.padding_part,
        hardcore=not args.no_hardcore,
        crops_dir=args.crops_dir,
        data_path=args.data_dir,
        label_smoothing=args.label_smoothing,
        folds_csv=args.folds_csv,
        transforms=create_train_transforms(conf["size"]),
        normalize=conf.get("normalize", None))
    data_val = DeepFakeClassifierDataset(mode="val",
                                         fold=args.fold,
                                         padding_part=args.padding_part,
                                         crops_dir=args.crops_dir,
                                         data_path=args.data_dir,
                                         folds_csv=args.folds_csv,
                                         transforms=create_val_transforms(
                                             conf["size"]),
                                         normalize=conf.get("normalize", None))
    val_data_loader = DataLoader(data_val,
                                 batch_size=batch_size * 2,
                                 num_workers=args.workers,
                                 shuffle=False,
                                 pin_memory=False)
    os.makedirs(args.logdir, exist_ok=True)
    # SummaryWriter creates an event file in the given directory and adds
    # summaries and events to it; the file contents are updated asynchronously.
    summary_writer = SummaryWriter(args.logdir + '/' +
                                   conf.get("prefix", args.prefix) +
                                   conf['encoder'] + "_" + str(args.fold))
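    # train_epoch/evaluate_val presumably log through the standard TensorBoard
    # API, e.g. (illustrative only):
    #   summary_writer.add_scalar('train/loss', loss_value, global_step)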
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            state_dict = checkpoint['state_dict']
            # strip the 'module.' prefix added by (Distributed)DataParallel
            state_dict = {k[7:]: w for k, w in state_dict.items()}
            model.load_state_dict(state_dict, strict=False)
            if not args.from_zero:
                start_epoch = checkpoint['epoch']
                if not args.zero_score:
                    bce_best = checkpoint.get('bce_best', 0)
            print("=> loaded checkpoint '{}' (epoch {}, bce_best {})".format(
                args.resume, checkpoint['epoch'], checkpoint['bce_best']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.from_zero:
        start_epoch = 0
    current_epoch = start_epoch

    if conf['fp16']:
        # Allow Amp to perform casts as required by the opt_level
        # Commonly-used default modes are chosen by selecting an “optimization level” or opt_level;
        # each opt_level establishes a set of properties that govern Amp’s implementation of pure or mixed precision training.
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale='dynamic')
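        # Roughly speaking, 'O1' patches Torch functions so that ops that
        # benefit from half precision (convolutions, matrix multiplies) run in
        # float16 while precision-sensitive ops stay in float32;
        # loss_scale='dynamic' raises/lowers the scale automatically when
        # overflows are detected.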

    snapshot_name = "{}{}_{}_{}".format(conf.get("prefix",
                                                 args.prefix), conf['network'],
                                        conf['encoder'], args.fold)
    # The difference between DistributedDataParallel and DataParallel:
    # DataParallel is single-process, multi-threaded, and only works on a
    # single machine, while DistributedDataParallel is multi-process and works
    # for both single- and multi-machine training. DataParallel is usually
    # slower than DistributedDataParallel even on a single machine.
    if args.distributed:
        # this enables multi-process distributed data-parallel training
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = DataParallel(model).cuda()
    data_val.reset(1, args.seed)
    max_epochs = conf['optimizer']['schedule']['epochs']
    for epoch in range(start_epoch, max_epochs):
        data_train.reset(epoch, args.seed)
        train_sampler = None
        if args.distributed:
            # the sampler restricts each process to a distinct subset of the dataset
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                data_train)
            train_sampler.set_epoch(epoch)
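            # set_epoch reseeds the sampler so each epoch sees a different
            # shuffle; without it the same ordering would repeat every epoch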
        if epoch < args.freeze_epochs:
            print("Freezing encoder!!!")
            model.module.encoder.eval()
            for p in model.module.encoder.parameters():
                p.requires_grad = False
        else:
            model.module.encoder.train()
            for p in model.module.encoder.parameters():
                p.requires_grad = True
        # build the DataLoader for the training data
        train_data_loader = DataLoader(data_train,
                                       batch_size=batch_size,
                                       num_workers=args.workers,
                                       shuffle=train_sampler is None,
                                       sampler=train_sampler,
                                       pin_memory=False,
                                       drop_last=True)
        # run one training epoch over the training data
        train_epoch(current_epoch, loss_functions, model, optimizer, scheduler,
                    train_data_loader, summary_writer, conf, args.local_rank,
                    args.only_changed_frames)
        model = model.eval()
        # save this epoch's checkpoints (rank 0 only)
        if args.local_rank == 0:
            torch.save(
                {
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'bce_best': bce_best,
                }, os.path.join(args.output_dir, snapshot_name + "_last"))
            torch.save(
                {
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'bce_best': bce_best,
                },
                os.path.join(args.output_dir,
                             snapshot_name + "_{}".format(current_epoch)))
            if (epoch + 1) % args.test_every == 0:
                bce_best = evaluate_val(args,
                                        val_data_loader,
                                        bce_best,
                                        model,
                                        snapshot_name=snapshot_name,
                                        current_epoch=current_epoch,
                                        summary_writer=summary_writer)
        current_epoch += 1
def main():
    parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
    arg = parser.add_argument
    arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
    arg('--workers', type=int, default=6, help='number of cpu threads to use')
    arg('--gpu',
        type=str,
        default='0',
        help='List of GPUs for parallel training, e.g. 0,1,2,3')
    arg('--output-dir', type=str, default='weights/')
    arg('--resume', type=str, default='')
    arg('--fold', type=int, default=0)
    arg('--prefix', type=str, default='classifier_')
    arg('--data-dir', type=str, default="/mnt/sota/datasets/deepfake")
    arg('--val-dir', type=str, default="../dfdc_train_all/dfdc_test")
    arg('--folds-csv', type=str, default='folds.csv')
    arg('--val-folds-csv', type=str)
    arg('--crops-dir', type=str, default='crops')
    arg('--label-smoothing', type=float, default=0.01)
    arg('--logdir', type=str, default='logs')
    arg('--zero-score', action='store_true', default=False)
    arg('--from-zero', action='store_true', default=False)
    arg('--distributed', action='store_true', default=False)
    arg('--freeze-epochs', type=int, default=0)
    arg("--local_rank", default=0, type=int)
    arg("--seed", default=777, type=int)
    arg("--padding-part", default=3, type=int)
    arg("--opt-level", default='O1', type=str)
    arg("--test_every", type=int, default=1)
    arg("--no-oversample", action="store_true")
    arg("--no-hardcore", action="store_true")
    arg("--only-changed-frames", action="store_true")

    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
    else:
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.benchmark = True

    conf = load_config(args.config)
    model = classifiers.__dict__[conf['network']](encoder=conf['encoder'])

    model = model.cuda()
    if args.distributed:
        model = convert_syncbn_model(model)
    ohem = conf.get("ohem_samples", None)
    reduction = "mean"
    if ohem:
        reduction = "none"
    loss_fn = []
    weights = []
    for loss_name, weight in conf["losses"].items():
        loss_fn.append(losses.__dict__[loss_name](reduction=reduction).cuda())
        weights.append(weight)
    loss = WeightedLosses(loss_fn, weights)
    loss_functions = {"classifier_loss": loss}
    optimizer, scheduler = create_optimizer(conf['optimizer'], model)
    bce_best = 100
    start_epoch = 0
    batch_size = conf['optimizer']['batch_size']
    print("Config Loaded")
    data_train = DeepFakeClassifierDataset(
        mode="train",
        oversample_real=not args.no_oversample,
        fold=args.fold,
        padding_part=args.padding_part,
        hardcore=not args.no_hardcore,
        crops_dir=args.crops_dir,
        data_path=args.data_dir,
        label_smoothing=args.label_smoothing,
        folds_csv=args.folds_csv,
        transforms=create_train_transforms(conf["size"]),
        normalize=conf.get("normalize", None))
    print("train data Loaded")
    data_val = DeepFakeClassifierDataset(mode="val",
                                         fold=args.fold,
                                         padding_part=args.padding_part,
                                         crops_dir=args.crops_dir,
                                         data_path=args.data_dir,
                                         folds_csv=args.folds_csv,
                                         transforms=create_val_transforms(
                                             conf["size"]),
                                         normalize=conf.get("normalize", None))
    print("val data Loaded")
    val_data_loader = DataLoader(data_val,
                                 batch_size=batch_size * 2,
                                 num_workers=args.workers,
                                 shuffle=False,
                                 pin_memory=False)
    os.makedirs(args.logdir, exist_ok=True)
    summary_writer = SummaryWriter(args.logdir + '/' +
                                   conf.get("prefix", args.prefix) +
                                   conf['encoder'] + "_" + str(args.fold))
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            state_dict = checkpoint['state_dict']
            state_dict = {k[7:]: w for k, w in state_dict.items()}
            model.load_state_dict(state_dict, strict=False)
            if not args.from_zero:
                start_epoch = checkpoint['epoch']
                if not args.zero_score:
                    bce_best = checkpoint.get('bce_best', 0)
            print("=> loaded checkpoint '{}' (epoch {}, bce_best {})".format(
                args.resume, checkpoint['epoch'], checkpoint['bce_best']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.from_zero:
        start_epoch = 0
    current_epoch = start_epoch

    if conf['fp16']:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale='dynamic')

    snapshot_name = "{}{}_{}_{}".format(conf.get("prefix",
                                                 args.prefix), conf['network'],
                                        conf['encoder'], args.fold)

    if args.distributed:
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = DataParallel(model).cuda()
    data_val.reset(1, args.seed)
    max_epochs = conf['optimizer']['schedule']['epochs']
    for epoch in range(start_epoch, max_epochs):
        data_train.reset(epoch, args.seed)
        train_sampler = None
        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                data_train)
            train_sampler.set_epoch(epoch)
        if epoch < args.freeze_epochs:
            print("Freezing encoder!!!")
            model.module.encoder.eval()
            for p in model.module.encoder.parameters():
                p.requires_grad = False
        else:
            model.module.encoder.train()
            for p in model.module.encoder.parameters():
                p.requires_grad = True

        train_data_loader = DataLoader(data_train,
                                       batch_size=batch_size,
                                       num_workers=args.workers,
                                       shuffle=train_sampler is None,
                                       sampler=train_sampler,
                                       pin_memory=False,
                                       drop_last=True)

        train_epoch(current_epoch, loss_functions, model, optimizer, scheduler,
                    train_data_loader, summary_writer, conf, args.local_rank,
                    args.only_changed_frames)
        model = model.eval()

        if args.local_rank == 0:
            torch.save(
                {
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'bce_best': bce_best,
                }, os.path.join(args.output_dir, snapshot_name + "_last"))
            torch.save(
                {
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'bce_best': bce_best,
                },
                os.path.join(args.output_dir,
                             snapshot_name + "_{}".format(current_epoch)))
            if (epoch + 1) % args.test_every == 0:
                bce_best = evaluate_val(args,
                                        val_data_loader,
                                        bce_best,
                                        model,
                                        snapshot_name=snapshot_name,
                                        current_epoch=current_epoch,
                                        summary_writer=summary_writer)
        current_epoch += 1
logger.info('Create model and optimisers')
nclasses = len(conf["image_target_cols"]) + len(conf['exam_target_cols'])
logger.info(f'Nclasses : {nclasses}')
model = classifiers.__dict__[conf['network']](encoder=conf['encoder'],
                                              nclasses=nclasses)
model = model.to(args.device)

image_weight = conf.get('image_weight', 1.)
logger.info(f'Image BCE weight: {image_weight}')
bce_wts = torch.tensor([image_weight] + conf['exam_weights']).to(args.device)
logger.info(f"All BCE weights: {[image_weight] + conf['exam_weights']}")

criterion = torch.nn.BCEWithLogitsLoss(reduction='mean', weight=bce_wts)
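# Assuming model outputs of shape [batch, nclasses], the per-column weight
# tensor broadcasts against the element-wise BCE loss, so each target column
# is rescaled by its configured weight before the 'mean' reduction.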

optimizer, scheduler = create_optimizer(conf['optimizer'], model)
bce_best = 100
start_epoch = 0
batch_size = conf['optimizer']['batch_size']

os.makedirs(args.logdir, exist_ok=True)
summary_writer = SummaryWriter(args.logdir + '/' +
                               conf.get("prefix", args.prefix) +
                               conf['encoder'] + "_" + str(args.fold))

if args.from_zero:
    start_epoch = 0
current_epoch = start_epoch

if conf['fp16'] and args.device != 'cpu':
    scaler = torch.cuda.amp.GradScaler()
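    # The scaler is then used in the training loop with the standard
    # torch.cuda.amp recipe (an illustrative sketch, not necessarily this
    # pipeline's exact code):
    #   with torch.cuda.amp.autocast():
    #       loss = criterion(model(inputs), targets)
    #   scaler.scale(loss).backward()
    #   scaler.step(optimizer)
    #   scaler.update()
    #   optimizer.zero_grad()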