def main():
    # Args
    parser = argparse.ArgumentParser()
    parser.add_argument('--testsets', type=str, help='Testing datasets', nargs='+',
                        choices=split.available_datasets, required=True)
    parser.add_argument('--testsplits', type=str, help='Test split', nargs='+',
                        default=['val', 'test'], choices=['train', 'val', 'test'])

    # Alternative 1: Specify training params
    parser.add_argument('--net', type=str, help='Net model class')
    parser.add_argument('--traindb', type=str, action='store', help='Dataset used for training')
    parser.add_argument('--face', type=str, help='Face crop or scale', default='scale',
                        choices=['scale', 'tight'])
    parser.add_argument('--size', type=int, help='Train patch size')

    weights_group = parser.add_mutually_exclusive_group(required=True)
    weights_group.add_argument('--weights', type=Path, help='Weight filename', default='bestval.pth')
    # Alternative 2: Specify trained model path
    weights_group.add_argument('--model_path', type=Path, help='Full path of the trained model')

    # Common params
    parser.add_argument('--batch', type=int, help='Batch size to fit in GPU memory', default=128)
    parser.add_argument('--workers', type=int, help='Num workers for data loaders', default=6)
    parser.add_argument('--device', type=int, help='GPU id', default=0)
    parser.add_argument('--seed', type=int, help='Random seed used for training', default=0)
    parser.add_argument('--debug', action='store_true', help='Debug flag')
    parser.add_argument('--suffix', type=str, help='Suffix to default tag')
    parser.add_argument('--models_dir', type=Path, help='Folder with trained models', default='weights/')
    parser.add_argument('--num_video', type=int, help='Number of real-fake videos to test')
    parser.add_argument('--results_dir', type=Path, help='Output folder', default='results/')
    parser.add_argument('--override', action='store_true', help='Override existing results')

    args = parser.parse_args()

    device = torch.device('cuda:{}'.format(args.device)) if torch.cuda.is_available() else torch.device('cpu')
    patch_size: int = args.size
    num_workers: int = args.workers
    batch_size: int = args.batch
    net_name: str = args.net
    weights: Path = args.weights
    suffix: str = args.suffix
    face_policy: str = args.face
    models_dir: Path = args.models_dir
    max_num_videos_per_label: int = args.num_video  # number of real-fake videos to test
    model_path: Path = args.model_path
    results_dir: Path = args.results_dir
    debug: bool = args.debug
    override: bool = args.override
    train_datasets = args.traindb
    seed: int = args.seed
    test_sets = args.testsets
    test_splits = args.testsplits

    if model_path is None:
        if net_name is None:
            raise RuntimeError('Net name is required if "model_path" is not provided')
        model_name = utils.make_train_tag(net_class=getattr(fornet, net_name),
                                          traindb=train_datasets,
                                          face_policy=face_policy,
                                          patch_size=patch_size,
                                          seed=seed,
                                          suffix=suffix,
                                          debug=debug,
                                          )
        model_path = models_dir.joinpath(model_name, weights)
    else:
        # get arguments from the model path
        face_policy = str(model_path).split('face-')[1].split('_')[0]
        patch_size = int(str(model_path).split('size-')[1].split('_')[0])
        net_name = str(model_path).split('net-')[1].split('_')[0]
        model_name = '_'.join(model_path.with_suffix('').parts[-2:])

    # Load net
    net_class = getattr(fornet, net_name)

    # load model
    print('Loading model...')
    state_tmp = torch.load(model_path, map_location='cpu')
    if 'net' not in state_tmp.keys():
        state = OrderedDict({'net': OrderedDict()})
        for k, v in state_tmp.items():
            state['net']['model.{}'.format(k)] = v
    else:
        state = state_tmp

    net: FeatureExtractor = net_class().eval().to(device)

    incomp_keys = net.load_state_dict(state['net'], strict=True)
    print(incomp_keys)
    print('Model loaded!')

    # val loss per-frame
    criterion = nn.BCEWithLogitsLoss(reduction='none')

    # Define data transformers
    test_transformer = utils.get_transformer(face_policy, patch_size, net.get_normalizer(), train=False)

    # datasets and dataloaders (from train_binclass.py)
    print('Loading data...')
    splits = split.make_splits(dbs={'train': test_sets, 'val': test_sets, 'test': test_sets})
    train_dfs = [splits['train'][db][0] for db in splits['train']]
    train_roots = [splits['train'][db][1] for db in splits['train']]
    val_roots = [splits['val'][db][1] for db in splits['val']]
    val_dfs = [splits['val'][db][0] for db in splits['val']]
    test_dfs = [splits['test'][db][0] for db in splits['test']]
    test_roots = [splits['test'][db][1] for db in splits['test']]

    # Output paths
    out_folder = results_dir.joinpath(model_name)
    out_folder.mkdir(mode=0o775, parents=True, exist_ok=True)

    # Samples selection
    if max_num_videos_per_label is not None:
        dfs_out_train = [select_videos(df, max_num_videos_per_label) for df in train_dfs]
        dfs_out_val = [select_videos(df, max_num_videos_per_label) for df in val_dfs]
        dfs_out_test = [select_videos(df, max_num_videos_per_label) for df in test_dfs]
    else:
        dfs_out_train = train_dfs
        dfs_out_val = val_dfs
        dfs_out_test = test_dfs

    # Extractions list
    extr_list = []
    # Append train and validation set first
    if 'train' in test_splits:
        for idx, dataset in enumerate(test_sets):
            extr_list.append(
                (dfs_out_train[idx], out_folder.joinpath(dataset + '_train.pkl'), train_roots[idx],
                 dataset + ' TRAIN')
            )
    if 'val' in test_splits:
        for idx, dataset in enumerate(test_sets):
            extr_list.append(
                (dfs_out_val[idx], out_folder.joinpath(dataset + '_val.pkl'), val_roots[idx],
                 dataset + ' VAL')
            )
    if 'test' in test_splits:
        for idx, dataset in enumerate(test_sets):
            extr_list.append(
                (dfs_out_test[idx], out_folder.joinpath(dataset + '_test.pkl'), test_roots[idx],
                 dataset + ' TEST')
            )

    for df, df_path, df_root, tag in extr_list:
        if override or not df_path.exists():
            print('\n##### PREDICT VIDEOS FROM {} #####'.format(tag))
            print('Real frames: {}'.format(sum(df['label'] == False)))
            print('Fake frames: {}'.format(sum(df['label'] == True)))
            print('Real videos: {}'.format(df[df['label'] == False]['video'].nunique()))
            print('Fake videos: {}'.format(df[df['label'] == True]['video'].nunique()))
            dataset_out = process_dataset(root=df_root, df=df, net=net, criterion=criterion,
                                          patch_size=patch_size, face_policy=face_policy,
                                          transformer=test_transformer, batch_size=batch_size,
                                          num_workers=num_workers, device=device,
                                          )
            df['score'] = dataset_out['score'].astype(np.float32)
            df['loss'] = dataset_out['loss'].astype(np.float32)
            print('Saving results to: {}'.format(df_path))
            df.to_pickle(str(df_path))

            if debug:
                plt.figure()
                plt.title(tag)
                plt.hist(df[df.label == True].score, bins=100, alpha=0.6, label='FAKE frames')
                plt.hist(df[df.label == False].score, bins=100, alpha=0.6, label='REAL frames')
                plt.legend()

            del dataset_out
            del df
            gc.collect()

    if debug:
        plt.show()

    print('Completed!')
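# A hedged usage sketch for the evaluation entry point above. The script name,
# dataset tags and network class below are illustrative placeholders, not values
# defined in this file; only the command-line flags come from the parser in main():
#
#   python test_model.py \
#       --net EfficientNetB4 --traindb <train-dataset-tag> \
#       --testsets <test-dataset-tag> --testsplits val test \
#       --face scale --size 224 --weights bestval.pth
#
# Alternatively, passing --model_path <path/to/model.pth> (mutually exclusive with
# --weights) lets main() recover net, face policy and patch size from the path itself.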
def main():
    # Args
    parser = argparse.ArgumentParser()
    parser.add_argument('--net', type=str, help='Net model class', required=True)
    parser.add_argument('--traindb', type=str, help='Training datasets', nargs='+',
                        choices=split.available_datasets, required=True)
    parser.add_argument('--valdb', type=str, help='Validation datasets', nargs='+',
                        choices=split.available_datasets, required=True)
    parser.add_argument('--face', type=str, help='Face crop or scale', required=True,
                        choices=['scale', 'tight'])
    parser.add_argument('--size', type=int, help='Train patch size', required=True)

    parser.add_argument('--batch', type=int, help='Batch size to fit in GPU memory', default=12)
    parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate')
    parser.add_argument('--valint', type=int, help='Validation interval (iterations)', default=500)
    parser.add_argument('--patience', type=int, help='Patience before dropping the LR [validation intervals]',
                        default=10)
    parser.add_argument('--maxiter', type=int, help='Maximum number of iterations', default=20000)
    parser.add_argument('--init', type=str, help='Weight initialization file')
    parser.add_argument('--scratch', action='store_true', help='Train from scratch')

    parser.add_argument('--traintriplets', type=int, help='Limit the number of train triplets per epoch',
                        default=-1)
    parser.add_argument('--valtriplets', type=int, help='Limit the number of validation triplets per epoch',
                        default=2000)

    parser.add_argument('--logint', type=int, help='Training log interval (iterations)', default=100)
    parser.add_argument('--workers', type=int, help='Num workers for data loaders', default=6)
    parser.add_argument('--device', type=int, help='GPU device id', default=0)
    parser.add_argument('--seed', type=int, help='Random seed', default=0)

    parser.add_argument('--debug', action='store_true', help='Activate debug')
    parser.add_argument('--suffix', type=str, help='Suffix to default tag')
    parser.add_argument('--attention', action='store_true', help='Enable Tensorboard log of attention masks')
    parser.add_argument('--embedding', action='store_true', help='Activate embedding visualization in TensorBoard')
    parser.add_argument('--embeddingint', type=int, help='Embedding visualization interval in TensorBoard',
                        default=5000)

    parser.add_argument('--log_dir', type=str, help='Directory for saving the training logs',
                        default='runs/triplet/')
    parser.add_argument('--models_dir', type=str, help='Directory for saving the models weights',
                        default='weights/triplet/')

    args = parser.parse_args()

    # Parse arguments
    net_class = getattr(tripletnet, args.net)
    train_datasets = args.traindb
    val_datasets = args.valdb
    face_policy = args.face
    face_size = args.size
    batch_size = args.batch
    initial_lr = args.lr
    validation_interval = args.valint
    patience = args.patience
    max_num_iterations = args.maxiter
    initial_model = args.init
    train_from_scratch = args.scratch
    max_train_triplets = args.traintriplets
    max_val_triplets = args.valtriplets
    log_interval = args.logint
    num_workers = args.workers
    device = torch.device('cuda:{:d}'.format(args.device)) if torch.cuda.is_available() else torch.device('cpu')
    seed = args.seed
    debug = args.debug
    suffix = args.suffix
    enable_attention = args.attention
    enable_embedding = args.embedding
    embedding_interval = args.embeddingint
    weights_folder = args.models_dir
    logs_folder = args.log_dir

    # Random initialization
    np.random.seed(seed)
    torch.random.manual_seed(seed)

    # Load net
    net: nn.Module = net_class().to(device)

    # Loss and optimizers
    criterion = nn.TripletMarginLoss()
    min_lr = initial_lr * 1e-5
    optimizer = optim.Adam(net.get_trainable_parameters(), lr=initial_lr)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer=optimizer,
        mode='min',
        factor=0.1,
        patience=patience,
        cooldown=2 * patience,
        min_lr=min_lr,
    )

    tag = utils.make_train_tag(
        net_class=net_class,
        traindb=train_datasets,
        face_policy=face_policy,
        patch_size=face_size,
        seed=seed,
        suffix=suffix,
        debug=debug,
    )

    # Model checkpoint paths
    bestval_path = os.path.join(weights_folder, tag, 'bestval.pth')
    last_path = os.path.join(weights_folder, tag, 'last.pth')
    periodic_path = os.path.join(weights_folder, tag, 'it{:06d}.pth')

    os.makedirs(os.path.join(weights_folder, tag), exist_ok=True)

    # Load model
    val_loss = min_val_loss = 20
    epoch = iteration = 0
    net_state = None
    opt_state = None
    if initial_model is not None:
        # If given, load the initial model
        print('Loading model from: {}'.format(initial_model))
        state = torch.load(initial_model, map_location='cpu')
        net_state = state['net']
    elif not train_from_scratch and os.path.exists(last_path):
        print('Loading model from: {}'.format(last_path))
        state = torch.load(last_path, map_location='cpu')
        net_state = state['net']
        opt_state = state['opt']
        iteration = state['iteration'] + 1
        epoch = state['epoch']
    if not train_from_scratch and os.path.exists(bestval_path):
        state = torch.load(bestval_path, map_location='cpu')
        min_val_loss = state['val_loss']
    if net_state is not None:
        adapt_binclass_model(net_state)
        incomp_keys = net.load_state_dict(net_state, strict=False)
        print(incomp_keys)
    if opt_state is not None:
        for param_group in opt_state['param_groups']:
            param_group['lr'] = initial_lr
        optimizer.load_state_dict(opt_state)

    # Initialize Tensorboard
    logdir = os.path.join(logs_folder, tag)
    if iteration == 0:
        # If training from scratch or from an initialization, remove history if it exists
        shutil.rmtree(logdir, ignore_errors=True)

    # TensorboardX instance
    tb = SummaryWriter(logdir=logdir)
    if iteration == 0:
        dummy = torch.randn((1, 3, face_size, face_size), device=device)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            tb.add_graph(net, [dummy, dummy, dummy], verbose=False)

    transformer = utils.get_transformer(face_policy=face_policy, patch_size=face_size,
                                        net_normalizer=net.get_normalizer(), train=True)

    # Datasets and data loaders
    print('Loading data')
    splits = split.make_splits(dbs={'train': train_datasets, 'val': val_datasets})
    train_dfs = [splits['train'][db][0] for db in splits['train']]
    train_roots = [splits['train'][db][1] for db in splits['train']]
    val_roots = [splits['val'][db][1] for db in splits['val']]
    val_dfs = [splits['val'][db][0] for db in splits['val']]

    train_dataset = FrameFaceTripletIterableDataset(
        roots=train_roots,
        dfs=train_dfs,
        scale=face_policy,
        num_triplets=max_train_triplets,
        transformer=transformer,
        size=face_size,
    )
    val_dataset = FrameFaceTripletIterableDataset(
        roots=val_roots,
        dfs=val_dfs,
        scale=face_policy,
        num_triplets=max_val_triplets,
        transformer=transformer,
        size=face_size,
    )

    train_loader = DataLoader(
        train_dataset,
        num_workers=num_workers,
        batch_size=batch_size,
    )
    val_loader = DataLoader(
        val_dataset,
        num_workers=num_workers,
        batch_size=batch_size,
    )

    print('Training triplets: {}'.format(len(train_dataset)))
    print('Validation triplets: {}'.format(len(val_dataset)))

    if len(train_dataset) == 0:
        print('No training triplets. Halt.')
        return
    if len(val_dataset) == 0:
        print('No validation triplets. Halt.')
        return

    # Embedding visualization
    if enable_embedding:
        train_dataset_embedding = FrameFaceIterableDataset(
            roots=train_roots,
            dfs=train_dfs,
            scale=face_policy,
            num_samples=64,
            transformer=transformer,
            size=face_size,
        )
        train_loader_embedding = DataLoader(
            train_dataset_embedding,
            num_workers=num_workers,
            batch_size=batch_size,
        )
        val_dataset_embedding = FrameFaceIterableDataset(
            roots=val_roots,
            dfs=val_dfs,
            scale=face_policy,
            num_samples=64,
            transformer=transformer,
            size=face_size,
        )
        val_loader_embedding = DataLoader(
            val_dataset_embedding,
            num_workers=num_workers,
            batch_size=batch_size,
        )
    else:
        train_loader_embedding = None
        val_loader_embedding = None

    stop = False
    while not stop:
        # Training
        optimizer.zero_grad()

        train_loss = train_num = 0
        for train_batch in tqdm(train_loader, desc='Epoch {:03d}'.format(epoch), leave=False,
                                total=len(train_loader) // train_loader.batch_size):
            net.train()

            train_batch_num = len(train_batch[0])
            train_num += train_batch_num

            train_batch_loss = batch_forward(net, device, criterion, train_batch)

            if torch.isnan(train_batch_loss):
                raise ValueError('NaN loss')

            train_loss += train_batch_loss.item() * train_batch_num

            # Optimization
            train_batch_loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Logging
            if iteration > 0 and (iteration % log_interval == 0):
                train_loss /= train_num
                tb.add_scalar('train/loss', train_loss, iteration)
                tb.add_scalar('lr', optimizer.param_groups[0]['lr'], iteration)
                tb.add_scalar('epoch', epoch, iteration)

                # Checkpoint
                save_model(net, optimizer, train_loss, val_loss, iteration, batch_size, epoch, last_path)
                train_loss = train_num = 0

            # Validation
            if iteration > 0 and (iteration % validation_interval == 0):
                val_loss = validation_routine(net, device, val_loader, criterion, tb, iteration, tag='val')
                tb.flush()

                # LR Scheduler
                lr_scheduler.step(val_loss)

                # Model checkpoint
                save_model(net, optimizer, train_loss, val_loss, iteration, batch_size, epoch,
                           periodic_path.format(iteration))

                if val_loss < min_val_loss:
                    min_val_loss = val_loss
                    shutil.copy(periodic_path.format(iteration), bestval_path)

                # Attention
                if enable_attention and hasattr(net, 'feat_ext') and hasattr(net.feat_ext, 'get_attention'):
                    net.eval()
                    # For each dataframe show the attention for a real/fake couple of frames.
                    # A separate name (att_tag) is used to avoid shadowing the run tag used below.
                    for df, root, sample_idx, att_tag in [
                        (train_dfs[0], train_roots[0],
                         train_dfs[0][train_dfs[0]['label'] == False].index[0], 'train/att/real'),
                        (train_dfs[0], train_roots[0],
                         train_dfs[0][train_dfs[0]['label'] == True].index[0], 'train/att/fake'),
                    ]:
                        record = df.loc[sample_idx]
                        tb_attention(tb, att_tag, iteration, net.feat_ext, device, face_size, face_policy,
                                     transformer, root, record)

                if optimizer.param_groups[0]['lr'] <= min_lr:
                    print('Reached minimum learning rate. Stopping.')
                    stop = True
                    break

            # Embedding visualization
            if enable_embedding:
                if iteration > 0 and (iteration % embedding_interval == 0):
                    embedding_routine(net=net, device=device, loader=train_loader_embedding,
                                      iteration=iteration, tb=tb, tag=tag + '/train')
                    embedding_routine(net=net, device=device, loader=val_loader_embedding,
                                      iteration=iteration, tb=tb, tag=tag + '/val')

            iteration += 1

            if iteration > max_num_iterations:
                print('Maximum number of iterations reached')
                stop = True
                break

            # End of iteration

        epoch += 1

    # Needed to flush out last events
    tb.close()

    print('Completed')
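# A hedged usage sketch for the triplet training entry point above. The script
# name, dataset tags and network class are illustrative placeholders; the flags
# and defaults come from the parser defined in main():
#
#   python train_triplets.py \
#       --net <TripletNetClass> --traindb <train-dataset-tag> --valdb <val-dataset-tag> \
#       --face scale --size 224 --batch 12 --lr 1e-5 --maxiter 20000
#
# Add --attention and/or --embedding to enable the optional TensorBoard
# visualizations, and --init <weights.pth> to start from an existing checkpoint.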