def main():
    """Train a classifier as described by a configuration file.

    Parses the command-line options, builds the model, the weighted loss,
    the optimizer/scheduler and the train/val datasets from the config,
    optionally restores weights from ``--resume``, and then runs the epoch
    loop: train one epoch, write snapshots on rank 0 and evaluate every
    ``--test_every`` epochs.
    """
    parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
    arg = parser.add_argument
    arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
    arg('--workers', type=int, default=6, help='number of cpu threads to use')
    arg('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
    arg('--output-dir', type=str, default='weights/')
    arg('--resume', type=str, default='')
    arg('--fold', type=int, default=0)
    arg('--prefix', type=str, default='classifier_')
    arg('--data-dir', type=str, default="/mnt/sota/datasets/deepfake")
    arg('--folds-csv', type=str, default='folds.csv')
    arg('--crops-dir', type=str, default='crops')
    arg('--label-smoothing', type=float, default=0.01)
    arg('--logdir', type=str, default='logs')
    arg('--zero-score', action='store_true', default=False)
    arg('--from-zero', action='store_true', default=False)
    arg('--distributed', action='store_true', default=False)
    arg('--freeze-epochs', type=int, default=0)
    arg("--local_rank", default=0, type=int)
    arg("--seed", default=777, type=int)
    arg("--padding-part", default=3, type=int)
    arg("--opt-level", default='O1', type=str)
    arg("--test_every", type=int, default=1)
    arg("--no-oversample", action="store_true")
    arg("--no-hardcore", action="store_true")
    arg("--only-changed-frames", action="store_true")

    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        # Initialize the default distributed process group.
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    else:
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.benchmark = True

    conf = load_config(args.config)
    model = classifiers.__dict__[conf['network']](encoder=conf['encoder'])
    model = model.cuda()
    if args.distributed:
        # Replace batch-norm layers with synchronized batch norm so that
        # statistics are computed over the whole mini-batch across GPUs
        # instead of per device.
        model = convert_syncbn_model(model)

    # With OHEM enabled the losses are left unreduced ("none") instead of
    # being averaged.
    ohem = conf.get("ohem_samples", None)
    reduction = "none" if ohem else "mean"
    loss_fn = []
    weights = []
    for loss_name, weight in conf["losses"].items():
        loss_fn.append(losses.__dict__[loss_name](reduction=reduction).cuda())
        weights.append(weight)
    # Combine the individual losses into one weighted loss.
    loss = WeightedLosses(loss_fn, weights)
    loss_functions = {"classifier_loss": loss}

    # Create optimizer and schedule from the configuration.
    optimizer, scheduler = create_optimizer(conf['optimizer'], model)
    bce_best = 100
    start_epoch = 0
    batch_size = conf['optimizer']['batch_size']

    # Dataset wrappers that apply the train / validation transforms.
    data_train = DeepFakeClassifierDataset(mode="train",
                                           oversample_real=not args.no_oversample,
                                           fold=args.fold,
                                           padding_part=args.padding_part,
                                           hardcore=not args.no_hardcore,
                                           crops_dir=args.crops_dir,
                                           data_path=args.data_dir,
                                           label_smoothing=args.label_smoothing,
                                           folds_csv=args.folds_csv,
                                           transforms=create_train_transforms(conf["size"]),
                                           normalize=conf.get("normalize", None))
    data_val = DeepFakeClassifierDataset(mode="val",
                                         fold=args.fold,
                                         padding_part=args.padding_part,
                                         crops_dir=args.crops_dir,
                                         data_path=args.data_dir,
                                         folds_csv=args.folds_csv,
                                         transforms=create_val_transforms(conf["size"]),
                                         normalize=conf.get("normalize", None))
    val_data_loader = DataLoader(data_val, batch_size=batch_size * 2, num_workers=args.workers,
                                 shuffle=False, pin_memory=False)

    os.makedirs(args.logdir, exist_ok=True)
    # Event-file writer for the training / validation summaries.
    summary_writer = SummaryWriter(
        args.logdir + '/' + conf.get("prefix", args.prefix) + conf['encoder'] + "_" + str(args.fold))

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # NOTE(review): torch.load unpickles the file; only resume from
            # checkpoints you trust.
            checkpoint = torch.load(args.resume, map_location='cpu')
            state_dict = checkpoint['state_dict']
            # Snapshots written below come from a wrapped model, so their keys
            # carry a "module." prefix. Strip it only where present so that
            # un-prefixed checkpoints are not mangled (strict=False would
            # otherwise hide the mismatch).
            state_dict = {(k[7:] if k.startswith('module.') else k): w
                          for k, w in state_dict.items()}
            model.load_state_dict(state_dict, strict=False)
            if not args.from_zero:
                start_epoch = checkpoint['epoch']
                if not args.zero_score:
                    # NOTE(review): the fallback of 0 looks like it would block
                    # any later "best" update if lower BCE is better - confirm
                    # against evaluate_val.
                    bce_best = checkpoint.get('bce_best', 0)
            # Use .get so a checkpoint without these keys does not crash the
            # log line after the weights were loaded successfully.
            print("=> loaded checkpoint '{}' (epoch {}, bce_best {})".format(
                args.resume, checkpoint.get('epoch'), checkpoint.get('bce_best')))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.from_zero:
        start_epoch = 0
    current_epoch = start_epoch

    if conf['fp16']:
        # Let Amp perform the casts required by the chosen opt_level; each
        # opt_level selects a preset of pure / mixed precision properties.
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale='dynamic')

    snapshot_name = "{}{}_{}_{}".format(conf.get("prefix", args.prefix), conf['network'],
                                        conf['encoder'], args.fold)

    # DataParallel is single-process, multi-thread and single-machine only;
    # DistributedDataParallel is multi-process and also works across machines.
    if args.distributed:
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = DataParallel(model).cuda()

    data_val.reset(1, args.seed)
    max_epochs = conf['optimizer']['schedule']['epochs']
    for epoch in range(start_epoch, max_epochs):
        data_train.reset(epoch, args.seed)
        train_sampler = None
        if args.distributed:
            # Restrict data loading in each process to a subset of the dataset.
            train_sampler = torch.utils.data.distributed.DistributedSampler(data_train)
            train_sampler.set_epoch(epoch)

        if epoch < args.freeze_epochs:
            print("Freezing encoder!!!")
            model.module.encoder.eval()
            for p in model.module.encoder.parameters():
                p.requires_grad = False
        else:
            model.module.encoder.train()
            for p in model.module.encoder.parameters():
                p.requires_grad = True

        train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=args.workers,
                                       shuffle=train_sampler is None, sampler=train_sampler,
                                       pin_memory=False, drop_last=True)

        train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader,
                    summary_writer, conf, args.local_rank, args.only_changed_frames)
        model = model.eval()

        # Only rank 0 writes snapshots and runs validation.
        if args.local_rank == 0:
            snapshot = {
                'epoch': current_epoch + 1,
                'state_dict': model.state_dict(),
                'bce_best': bce_best,
            }
            # Join paths explicitly: plain concatenation dropped the separator
            # for the per-epoch file when --output-dir had no trailing slash.
            torch.save(snapshot, os.path.join(args.output_dir, snapshot_name + "_last"))
            torch.save(snapshot,
                       os.path.join(args.output_dir,
                                    snapshot_name + "_{}".format(current_epoch)))
            if (epoch + 1) % args.test_every == 0:
                bce_best = evaluate_val(args, val_data_loader, bce_best, model,
                                        snapshot_name=snapshot_name,
                                        current_epoch=current_epoch,
                                        summary_writer=summary_writer)
        current_epoch += 1
def main():
    """Train a classifier as described by a configuration file.

    Parses the command-line options, builds the model, the weighted loss,
    the optimizer/scheduler and the train/val datasets from the config,
    optionally restores weights from ``--resume``, and then runs the epoch
    loop: train one epoch, write snapshots on rank 0 and evaluate every
    ``--test_every`` epochs. Progress messages are printed as the config and
    datasets are loaded.
    """
    parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
    arg = parser.add_argument
    arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
    arg('--workers', type=int, default=6, help='number of cpu threads to use')
    arg('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
    arg('--output-dir', type=str, default='weights/')
    arg('--resume', type=str, default='')
    arg('--fold', type=int, default=0)
    arg('--prefix', type=str, default='classifier_')
    arg('--data-dir', type=str, default="/mnt/sota/datasets/deepfake")
    # NOTE(review): --val-dir and --val-folds-csv are parsed but not read
    # directly in this function; possibly consumed through `args` by
    # evaluate_val - confirm.
    arg('--val-dir', type=str, default="../dfdc_train_all/dfdc_test")
    arg('--folds-csv', type=str, default='folds.csv')
    arg('--val-folds-csv', type=str)
    arg('--crops-dir', type=str, default='crops')
    arg('--label-smoothing', type=float, default=0.01)
    arg('--logdir', type=str, default='logs')
    arg('--zero-score', action='store_true', default=False)
    arg('--from-zero', action='store_true', default=False)
    arg('--distributed', action='store_true', default=False)
    arg('--freeze-epochs', type=int, default=0)
    arg("--local_rank", default=0, type=int)
    arg("--seed", default=777, type=int)
    arg("--padding-part", default=3, type=int)
    arg("--opt-level", default='O1', type=str)
    arg("--test_every", type=int, default=1)
    arg("--no-oversample", action="store_true")
    arg("--no-hardcore", action="store_true")
    arg("--only-changed-frames", action="store_true")

    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    else:
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.benchmark = True

    conf = load_config(args.config)
    model = classifiers.__dict__[conf['network']](encoder=conf['encoder'])
    model = model.cuda()
    if args.distributed:
        model = convert_syncbn_model(model)

    # With OHEM enabled the losses are left unreduced ("none") instead of
    # being averaged.
    ohem = conf.get("ohem_samples", None)
    reduction = "none" if ohem else "mean"
    loss_fn = []
    weights = []
    for loss_name, weight in conf["losses"].items():
        loss_fn.append(losses.__dict__[loss_name](reduction=reduction).cuda())
        weights.append(weight)
    loss = WeightedLosses(loss_fn, weights)
    loss_functions = {"classifier_loss": loss}

    optimizer, scheduler = create_optimizer(conf['optimizer'], model)
    bce_best = 100
    start_epoch = 0
    batch_size = conf['optimizer']['batch_size']
    print("Config Loaded")

    data_train = DeepFakeClassifierDataset(mode="train",
                                           oversample_real=not args.no_oversample,
                                           fold=args.fold,
                                           padding_part=args.padding_part,
                                           hardcore=not args.no_hardcore,
                                           crops_dir=args.crops_dir,
                                           data_path=args.data_dir,
                                           label_smoothing=args.label_smoothing,
                                           folds_csv=args.folds_csv,
                                           transforms=create_train_transforms(conf["size"]),
                                           normalize=conf.get("normalize", None))
    print("train data Loaded")
    data_val = DeepFakeClassifierDataset(mode="val",
                                         fold=args.fold,
                                         padding_part=args.padding_part,
                                         crops_dir=args.crops_dir,
                                         data_path=args.data_dir,
                                         folds_csv=args.folds_csv,
                                         transforms=create_val_transforms(conf["size"]),
                                         normalize=conf.get("normalize", None))
    print("val data Loaded")
    val_data_loader = DataLoader(data_val, batch_size=batch_size * 2, num_workers=args.workers,
                                 shuffle=False, pin_memory=False)

    os.makedirs(args.logdir, exist_ok=True)
    summary_writer = SummaryWriter(
        args.logdir + '/' + conf.get("prefix", args.prefix) + conf['encoder'] + "_" + str(args.fold))

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # NOTE(review): torch.load unpickles the file; only resume from
            # checkpoints you trust.
            checkpoint = torch.load(args.resume, map_location='cpu')
            state_dict = checkpoint['state_dict']
            # Snapshots written below come from a wrapped model, so their keys
            # carry a "module." prefix. Strip it only where present so that
            # un-prefixed checkpoints are not mangled (strict=False would
            # otherwise hide the mismatch).
            state_dict = {(k[7:] if k.startswith('module.') else k): w
                          for k, w in state_dict.items()}
            model.load_state_dict(state_dict, strict=False)
            if not args.from_zero:
                start_epoch = checkpoint['epoch']
                if not args.zero_score:
                    # NOTE(review): the fallback of 0 looks like it would block
                    # any later "best" update if lower BCE is better - confirm
                    # against evaluate_val.
                    bce_best = checkpoint.get('bce_best', 0)
            # Use .get so a checkpoint without these keys does not crash the
            # log line after the weights were loaded successfully.
            print("=> loaded checkpoint '{}' (epoch {}, bce_best {})".format(
                args.resume, checkpoint.get('epoch'), checkpoint.get('bce_best')))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.from_zero:
        start_epoch = 0
    current_epoch = start_epoch

    if conf['fp16']:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale='dynamic')

    snapshot_name = "{}{}_{}_{}".format(conf.get("prefix", args.prefix), conf['network'],
                                        conf['encoder'], args.fold)

    if args.distributed:
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = DataParallel(model).cuda()

    data_val.reset(1, args.seed)
    max_epochs = conf['optimizer']['schedule']['epochs']
    for epoch in range(start_epoch, max_epochs):
        data_train.reset(epoch, args.seed)
        train_sampler = None
        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(data_train)
            train_sampler.set_epoch(epoch)

        # Keep the encoder frozen for the first --freeze-epochs epochs.
        if epoch < args.freeze_epochs:
            print("Freezing encoder!!!")
            model.module.encoder.eval()
            for p in model.module.encoder.parameters():
                p.requires_grad = False
        else:
            model.module.encoder.train()
            for p in model.module.encoder.parameters():
                p.requires_grad = True

        train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=args.workers,
                                       shuffle=train_sampler is None, sampler=train_sampler,
                                       pin_memory=False, drop_last=True)

        train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader,
                    summary_writer, conf, args.local_rank, args.only_changed_frames)
        model = model.eval()

        # Only rank 0 writes snapshots and runs validation.
        if args.local_rank == 0:
            snapshot = {
                'epoch': current_epoch + 1,
                'state_dict': model.state_dict(),
                'bce_best': bce_best,
            }
            # Join paths explicitly: plain concatenation dropped the separator
            # for the per-epoch file when --output-dir had no trailing slash.
            torch.save(snapshot, os.path.join(args.output_dir, snapshot_name + "_last"))
            torch.save(snapshot,
                       os.path.join(args.output_dir,
                                    snapshot_name + "_{}".format(current_epoch)))
            if (epoch + 1) % args.test_every == 0:
                bce_best = evaluate_val(args, val_data_loader, bce_best, model,
                                        snapshot_name=snapshot_name,
                                        current_epoch=current_epoch,
                                        summary_writer=summary_writer)
        current_epoch += 1
# --- Model -----------------------------------------------------------------
logger.info('Create model and optimisers')
# One output per image-level target plus one per exam-level target.
nclasses = len(conf["image_target_cols"]) + len(conf['exam_target_cols'])
logger.info(f'Nclasses : {nclasses}')
model = classifiers.__dict__[conf['network']](
    encoder=conf['encoder'],
    nclasses=nclasses,
).to(args.device)

# --- Loss ------------------------------------------------------------------
# The image weight defaults to 1. when the config does not provide one; it is
# placed in front of the exam weights to form the BCE weight vector.
image_weight = conf.get('image_weight', 1.)
logger.info(f'Image BCE weight :{image_weight}')
bce_wts = torch.tensor([image_weight] + conf['exam_weights']).to(args.device)
logger.info(f"All BCE weights :{[image_weight] + conf['exam_weights']}")
criterion = torch.nn.BCEWithLogitsLoss(reduction='mean', weight=bce_wts)

# --- Optimiser and bookkeeping -----------------------------------------------
optimizer, scheduler = create_optimizer(conf['optimizer'], model)
bce_best, start_epoch = 100, 0
batch_size = conf['optimizer']['batch_size']

# --- Logging -----------------------------------------------------------------
os.makedirs(args.logdir, exist_ok=True)
summary_writer = SummaryWriter(
    args.logdir
    + '/'
    + conf.get("prefix", args.prefix)
    + conf['encoder']
    + "_"
    + str(args.fold)
)

if args.from_zero:
    start_epoch = 0
current_epoch = start_epoch

# A gradient scaler is only created for fp16 runs on a non-CPU device.
if conf['fp16'] and args.device != 'cpu':
    scaler = torch.cuda.amp.GradScaler()