Example 1
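This first example is the testing/prediction entry point: it parses the command-line arguments, resolves a trained model either from its training parameters or from an explicit --model_path, runs per-frame inference on the requested dataset splits, and pickles the resulting score/loss dataframes. The excerpt starts at main(), so the sketch below lists the imports it appears to rely on; the standard-library and third-party ones follow directly from the code, while the project-specific names (split, utils, fornet, FeatureExtractor, select_videos, process_dataset) are assumptions about where they live.

import argparse
import gc
from collections import OrderedDict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn

# Project-specific names used below; their exact locations are assumed, not verified:
# from architectures import fornet
# from architectures.fornet import FeatureExtractor
# from isplutils import split, utils
# select_videos() and process_dataset() are expected to be defined elsewhere in this script.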
def main():
    # Args
    parser = argparse.ArgumentParser()

    parser.add_argument('--testsets', type=str, help='Testing datasets', nargs='+', choices=split.available_datasets,
                        required=True)
    parser.add_argument('--testsplits', type=str, help='Test split', nargs='+', default=['val', 'test'],
                        choices=['train', 'val', 'test'])

    # Alternative 1: Specify training params
    parser.add_argument('--net', type=str, help='Net model class')
    parser.add_argument('--traindb', type=str, action='store', help='Dataset used for training')
    parser.add_argument('--face', type=str, help='Face crop or scale', default='scale',
                        choices=['scale', 'tight'])
    parser.add_argument('--size', type=int, help='Train patch size')

    weights_group = parser.add_mutually_exclusive_group(required=True)
    weights_group.add_argument('--weights', type=Path, help='Weight filename', default='bestval.pth')

    # Alternative 2: Specify trained model path
    weights_group.add_argument('--model_path', type=Path, help='Full path of the trained model')

    # Common params
    parser.add_argument('--batch', type=int, help='Batch size to fit in GPU memory', default=128)

    parser.add_argument('--workers', type=int, help='Num workers for data loaders', default=6)
    parser.add_argument('--device', type=int, help='GPU id', default=0)
    parser.add_argument('--seed', type=int, help='Random seed used for training', default=0)

    parser.add_argument('--debug', action='store_true', help='Debug flag')
    parser.add_argument('--suffix', type=str, help='Suffix to default tag')

    parser.add_argument('--models_dir', type=Path, help='Folder with trained models',
                        default='weights/')

    parser.add_argument('--num_video', type=int, help='Number of real-fake videos to test')
    parser.add_argument('--results_dir', type=Path, help='Output folder',
                        default='results/')

    parser.add_argument('--override', action='store_true', help='Override existing results')

    args = parser.parse_args()

    device = torch.device('cuda:{}'.format(args.device)) if torch.cuda.is_available() else torch.device('cpu')
    patch_size: int = args.size
    num_workers: int = args.workers
    batch_size: int = args.batch
    net_name: str = args.net
    weights: Path = args.weights
    suffix: str = args.suffix
    face_policy: str = args.face
    models_dir: Path = args.models_dir
    max_num_videos_per_label: int = args.num_video  # number of real-fake videos to test
    model_path: Path = args.model_path
    results_dir: Path = args.results_dir
    debug: bool = args.debug
    override: bool = args.override
    train_datasets = args.traindb
    seed: int = args.seed
    test_sets = args.testsets
    test_splits = args.testsplits

    if model_path is None:
        if net_name is None:
            raise RuntimeError('Net name is required if "model_path" is not provided')

        model_name = utils.make_train_tag(net_class=getattr(fornet, net_name),
                                          traindb=train_datasets,
                                          face_policy=face_policy,
                                          patch_size=patch_size,
                                          seed=seed,
                                          suffix=suffix,
                                          debug=debug,
                                          )
        model_path = models_dir.joinpath(model_name, weights)

    else:
        # get arguments from the model path
        face_policy = str(model_path).split('face-')[1].split('_')[0]
        patch_size = int(str(model_path).split('size-')[1].split('_')[0])
        net_name = str(model_path).split('net-')[1].split('_')[0]
        model_name = '_'.join(model_path.with_suffix('').parts[-2:])

    # Load net
    net_class = getattr(fornet, net_name)

    # load model
    print('Loading model...')
    state_tmp = torch.load(model_path, map_location='cpu')
    if 'net' not in state_tmp:
        # Bare state_dict: wrap it under 'net' and add the 'model.' prefix expected by the network
        state = OrderedDict({'net': OrderedDict()})
        for k, v in state_tmp.items():
            state['net']['model.{}'.format(k)] = v
    else:
        state = state_tmp
    net: FeatureExtractor = net_class().eval().to(device)

    incomp_keys = net.load_state_dict(state['net'], strict=True)
    print(incomp_keys)
    print('Model loaded!')

    # val loss per-frame
    criterion = nn.BCEWithLogitsLoss(reduction='none')

    # Define data transformers
    test_transformer = utils.get_transformer(face_policy, patch_size, net.get_normalizer(), train=False)

    # datasets and dataloaders (from train_binclass.py)
    print('Loading data...')
    splits = split.make_splits(dbs={'train': test_sets, 'val': test_sets, 'test': test_sets})
    train_dfs = [splits['train'][db][0] for db in splits['train']]
    train_roots = [splits['train'][db][1] for db in splits['train']]
    val_roots = [splits['val'][db][1] for db in splits['val']]
    val_dfs = [splits['val'][db][0] for db in splits['val']]
    test_dfs = [splits['test'][db][0] for db in splits['test']]
    test_roots = [splits['test'][db][1] for db in splits['test']]

    # Output paths
    out_folder = results_dir.joinpath(model_name)
    out_folder.mkdir(mode=0o775, parents=True, exist_ok=True)

    # Samples selection
    if max_num_videos_per_label is not None:
        dfs_out_train = [select_videos(df, max_num_videos_per_label) for df in train_dfs]
        dfs_out_val = [select_videos(df, max_num_videos_per_label) for df in val_dfs]
        dfs_out_test = [select_videos(df, max_num_videos_per_label) for df in test_dfs]
    else:
        dfs_out_train = train_dfs
        dfs_out_val = val_dfs
        dfs_out_test = test_dfs

    # Extractions list
    extr_list = []
    # Append train and validation set first
    if 'train' in test_splits:
        for idx, dataset in enumerate(test_sets):
            extr_list.append(
                (dfs_out_train[idx], out_folder.joinpath(dataset + '_train.pkl'), train_roots[idx], dataset + ' TRAIN')
            )
    if 'val' in test_splits:
        for idx, dataset in enumerate(test_sets):
            extr_list.append(
                (dfs_out_val[idx], out_folder.joinpath(dataset + '_val.pkl'), val_roots[idx], dataset + ' VAL')
            )
    if 'test' in test_splits:
        for idx, dataset in enumerate(test_sets):
            extr_list.append(
                (dfs_out_test[idx], out_folder.joinpath(dataset + '_test.pkl'), test_roots[idx], dataset + ' TEST')
            )

    for df, df_path, df_root, tag in extr_list:
        if override or not df_path.exists():
            print('\n##### PREDICT VIDEOS FROM {} #####'.format(tag))
            print('Real frames: {}'.format(sum(df['label'] == False)))
            print('Fake frames: {}'.format(sum(df['label'] == True)))
            print('Real videos: {}'.format(df[df['label'] == False]['video'].nunique()))
            print('Fake videos: {}'.format(df[df['label'] == True]['video'].nunique()))
            dataset_out = process_dataset(root=df_root, df=df, net=net, criterion=criterion,
                                          patch_size=patch_size,
                                          face_policy=face_policy, transformer=test_transformer,
                                          batch_size=batch_size,
                                          num_workers=num_workers, device=device)
            df['score'] = dataset_out['score'].astype(np.float32)
            df['loss'] = dataset_out['loss'].astype(np.float32)
            print('Saving results to: {}'.format(df_path))
            df.to_pickle(str(df_path))

            if debug:
                plt.figure()
                plt.title(tag)
                plt.hist(df[df.label == True].score, bins=100, alpha=0.6, label='FAKE frames')
                plt.hist(df[df.label == False].score, bins=100, alpha=0.6, label='REAL frames')
                plt.legend()

            del dataset_out
            del df
            gc.collect()

    if debug:
        plt.show()

    print('Completed!')
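
The excerpt ends with main(); the usual entry-point guard and a sample invocation are sketched below. The script filename and the dataset/net values are illustrative placeholders, not values taken from the source.

if __name__ == '__main__':
    main()

# Illustrative invocation (placeholder values):
#   python test_model.py --testsets <dataset> --testsplits val test \
#       --net <NetClass> --traindb <dataset> --face scale --size 224 --weights bestval.pth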
Example 2
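This second example is the triplet-training entry point: it builds the argument parser, instantiates the network, the TripletMarginLoss and an Adam optimizer with a ReduceLROnPlateau scheduler, restores any previous checkpoint, and then runs the training loop with periodic logging, validation, model checkpointing and optional attention/embedding visualization in TensorBoard. As before, the sketch below lists the imports the excerpt appears to need; the project-specific modules, dataset classes and helper routines are assumptions about where those names come from.

import argparse
import os
import shutil
import warnings

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm

# SummaryWriter inferred from the 'TensorboardX instance' comment and the logdir= keyword below:
from tensorboardX import SummaryWriter

# Project-specific names used below; their exact locations are assumed, not verified:
# from architectures import tripletnet
# from isplutils import split, utils
# from isplutils.data import FrameFaceTripletIterableDataset, FrameFaceIterableDataset
# adapt_binclass_model, batch_forward, save_model, validation_routine,
# tb_attention and embedding_routine are expected to be defined elsewhere in this script.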
def main():
    # Args
    parser = argparse.ArgumentParser()
    parser.add_argument('--net',
                        type=str,
                        help='Net model class',
                        required=True)
    parser.add_argument('--traindb',
                        type=str,
                        help='Training datasets',
                        nargs='+',
                        choices=split.available_datasets,
                        required=True)
    parser.add_argument('--valdb',
                        type=str,
                        help='Validation datasets',
                        nargs='+',
                        choices=split.available_datasets,
                        required=True)
    parser.add_argument('--face',
                        type=str,
                        help='Face crop or scale',
                        required=True,
                        choices=['scale', 'tight'])
    parser.add_argument('--size',
                        type=int,
                        help='Train patch size',
                        required=True)

    parser.add_argument('--batch',
                        type=int,
                        help='Batch size to fit in GPU memory',
                        default=12)
    parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate')
    parser.add_argument('--valint',
                        type=int,
                        help='Validation interval (iterations)',
                        default=500)
    parser.add_argument(
        '--patience',
        type=int,
        help='Patience before dropping the LR [validation intervals]',
        default=10)
    parser.add_argument('--maxiter',
                        type=int,
                        help='Maximum number of iterations',
                        default=20000)
    parser.add_argument('--init', type=str, help='Weight initialization file')
    parser.add_argument('--scratch',
                        action='store_true',
                        help='Train from scratch')

    parser.add_argument('--traintriplets',
                        type=int,
                        help='Limit the number of train triplets per epoch',
                        default=-1)
    parser.add_argument(
        '--valtriplets',
        type=int,
        help='Limit the number of validation triplets per epoch',
        default=2000)

    parser.add_argument('--logint',
                        type=int,
                        help='Training log interval (iterations)',
                        default=100)
    parser.add_argument('--workers',
                        type=int,
                        help='Num workers for data loaders',
                        default=6)
    parser.add_argument('--device', type=int, help='GPU device id', default=0)
    parser.add_argument('--seed', type=int, help='Random seed', default=0)

    parser.add_argument('--debug', action='store_true', help='Activate debug')
    parser.add_argument('--suffix', type=str, help='Suffix to default tag')

    parser.add_argument('--attention',
                        action='store_true',
                        help='Enable Tensorboard log of attention masks')
    parser.add_argument('--embedding',
                        action='store_true',
                        help='Activate embedding visualization in TensorBoard')
    parser.add_argument('--embeddingint',
                        type=int,
                        help='Embedding visualization interval in TensorBoard',
                        default=5000)

    parser.add_argument('--log_dir',
                        type=str,
                        help='Directory for saving the training logs',
                        default='runs/triplet/')
    parser.add_argument('--models_dir',
                        type=str,
                        help='Directory for saving the models weights',
                        default='weights/triplet/')

    args = parser.parse_args()

    # Parse arguments
    net_class = getattr(tripletnet, args.net)
    train_datasets = args.traindb
    val_datasets = args.valdb
    face_policy = args.face
    face_size = args.size

    batch_size = args.batch
    initial_lr = args.lr
    validation_interval = args.valint
    patience = args.patience
    max_num_iterations = args.maxiter
    initial_model = args.init
    train_from_scratch = args.scratch

    max_train_triplets = args.traintriplets
    max_val_triplets = args.valtriplets

    log_interval = args.logint
    num_workers = args.workers
    device = torch.device('cuda:{:d}'.format(
        args.device)) if torch.cuda.is_available() else torch.device('cpu')
    seed = args.seed

    debug = args.debug
    suffix = args.suffix

    enable_attention = args.attention
    enable_embedding = args.embedding
    embedding_interval = args.embeddingint

    weights_folder = args.models_dir
    logs_folder = args.log_dir

    # Random initialization
    np.random.seed(seed)
    torch.random.manual_seed(seed)

    # Load net
    net: nn.Module = net_class().to(device)

    # Loss and optimizers
    criterion = nn.TripletMarginLoss()

    min_lr = initial_lr * 1e-5
    optimizer = optim.Adam(net.get_trainable_parameters(), lr=initial_lr)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer=optimizer,
        mode='min',
        factor=0.1,
        patience=patience,
        cooldown=2 * patience,
        min_lr=min_lr,
    )

    tag = utils.make_train_tag(
        net_class=net_class,
        traindb=train_datasets,
        face_policy=face_policy,
        patch_size=face_size,
        seed=seed,
        suffix=suffix,
        debug=debug,
    )

    # Model checkpoint paths
    bestval_path = os.path.join(weights_folder, tag, 'bestval.pth')
    last_path = os.path.join(weights_folder, tag, 'last.pth')
    periodic_path = os.path.join(weights_folder, tag, 'it{:06d}.pth')

    os.makedirs(os.path.join(weights_folder, tag), exist_ok=True)

    # Load model
    val_loss = min_val_loss = 20
    epoch = iteration = 0
    net_state = None
    opt_state = None
    if initial_model is not None:
        # If an initial model is given, load it
        print('Loading model from: {}'.format(initial_model))
        state = torch.load(initial_model, map_location='cpu')
        net_state = state['net']
    elif not train_from_scratch and os.path.exists(last_path):
        print('Loading model from: {}'.format(last_path))
        state = torch.load(last_path, map_location='cpu')
        net_state = state['net']
        opt_state = state['opt']
        iteration = state['iteration'] + 1
        epoch = state['epoch']
    if not train_from_scratch and os.path.exists(bestval_path):
        state = torch.load(bestval_path, map_location='cpu')
        min_val_loss = state['val_loss']
    if net_state is not None:
        adapt_binclass_model(net_state)
        incomp_keys = net.load_state_dict(net_state, strict=False)
        print(incomp_keys)
    if opt_state is not None:
        for param_group in opt_state['param_groups']:
            param_group['lr'] = initial_lr
        optimizer.load_state_dict(opt_state)

    # Initialize Tensorboard
    logdir = os.path.join(logs_folder, tag)
    if iteration == 0:
        # If training from scratch or from an initialization file, remove any existing log history
        shutil.rmtree(logdir, ignore_errors=True)

    # TensorboardX instance
    tb = SummaryWriter(logdir=logdir)
    if iteration == 0:
        dummy = torch.randn((1, 3, face_size, face_size), device=device)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            tb.add_graph(net, [dummy, dummy, dummy], verbose=False)

    transformer = utils.get_transformer(face_policy=face_policy,
                                        patch_size=face_size,
                                        net_normalizer=net.get_normalizer(),
                                        train=True)

    # Datasets and data loaders
    print('Loading data')
    splits = split.make_splits(dbs={
        'train': train_datasets,
        'val': val_datasets
    })
    train_dfs = [splits['train'][db][0] for db in splits['train']]
    train_roots = [splits['train'][db][1] for db in splits['train']]
    val_roots = [splits['val'][db][1] for db in splits['val']]
    val_dfs = [splits['val'][db][0] for db in splits['val']]

    train_dataset = FrameFaceTripletIterableDataset(
        roots=train_roots,
        dfs=train_dfs,
        scale=face_policy,
        num_triplets=max_train_triplets,
        transformer=transformer,
        size=face_size,
    )

    val_dataset = FrameFaceTripletIterableDataset(
        roots=val_roots,
        dfs=val_dfs,
        scale=face_policy,
        num_triplets=max_val_triplets,
        transformer=transformer,
        size=face_size,
    )

    train_loader = DataLoader(
        train_dataset,
        num_workers=num_workers,
        batch_size=batch_size,
    )

    val_loader = DataLoader(
        val_dataset,
        num_workers=num_workers,
        batch_size=batch_size,
    )

    print('Training triplets: {}'.format(len(train_dataset)))
    print('Validation triplets: {}'.format(len(val_dataset)))

    if len(train_dataset) == 0:
        print('No training triplets. Halt.')
        return

    if len(val_dataset) == 0:
        print('No validation triplets. Halt.')
        return

    # Embedding visualization
    if enable_embedding:
        train_dataset_embedding = FrameFaceIterableDataset(
            roots=train_roots,
            dfs=train_dfs,
            scale=face_policy,
            num_samples=64,
            transformer=transformer,
            size=face_size,
        )
        train_loader_embedding = DataLoader(
            train_dataset_embedding,
            num_workers=num_workers,
            batch_size=batch_size,
        )
        val_dataset_embedding = FrameFaceIterableDataset(
            roots=val_roots,
            dfs=val_dfs,
            scale=face_policy,
            num_samples=64,
            transformer=transformer,
            size=face_size,
        )
        val_loader_embedding = DataLoader(
            val_dataset_embedding,
            num_workers=num_workers,
            batch_size=batch_size,
        )

    else:
        train_loader_embedding = None
        val_loader_embedding = None

    stop = False
    while not stop:

        # Training
        optimizer.zero_grad()

        train_loss = train_num = 0
        for train_batch in tqdm(train_loader,
                                desc='Epoch {:03d}'.format(epoch),
                                leave=False,
                                total=len(train_loader) //
                                train_loader.batch_size):
            net.train()
            train_batch_num = len(train_batch[0])
            train_num += train_batch_num

            train_batch_loss = batch_forward(net, device, criterion,
                                             train_batch)

            if torch.isnan(train_batch_loss):
                raise ValueError('NaN loss')

            train_loss += train_batch_loss.item() * train_batch_num

            # Optimization
            train_batch_loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Logging
            if iteration > 0 and (iteration % log_interval == 0):
                train_loss /= train_num
                tb.add_scalar('train/loss', train_loss, iteration)
                tb.add_scalar('lr', optimizer.param_groups[0]['lr'], iteration)
                tb.add_scalar('epoch', epoch, iteration)

                # Checkpoint
                save_model(net, optimizer, train_loss, val_loss, iteration,
                           batch_size, epoch, last_path)
                train_loss = train_num = 0

            # Validation
            if iteration > 0 and (iteration % validation_interval == 0):

                # Validation
                val_loss = validation_routine(net,
                                              device,
                                              val_loader,
                                              criterion,
                                              tb,
                                              iteration,
                                              tag='val')
                tb.flush()

                # LR Scheduler
                lr_scheduler.step(val_loss)

                # Model checkpoint
                save_model(net, optimizer, train_loss, val_loss, iteration,
                           batch_size, epoch, periodic_path.format(iteration))
                if val_loss < min_val_loss:
                    min_val_loss = val_loss
                    shutil.copy(periodic_path.format(iteration), bestval_path)

                # Attention
                if enable_attention and hasattr(net, 'feat_ext') and hasattr(
                        net.feat_ext, 'get_attention'):
                    net.eval()
                    # For the first training dataframe, log attention maps for one real and one fake frame
                    # (the loop variable is named att_tag to avoid shadowing the training tag used elsewhere)
                    for df, root, sample_idx, att_tag in [
                        (train_dfs[0], train_roots[0],
                         train_dfs[0][train_dfs[0]['label'] == False].index[0],
                         'train/att/real'),
                        (train_dfs[0], train_roots[0],
                         train_dfs[0][train_dfs[0]['label'] == True].index[0],
                         'train/att/fake'),
                    ]:
                        record = df.loc[sample_idx]
                        tb_attention(tb, att_tag, iteration, net.feat_ext, device,
                                     face_size, face_policy, transformer, root,
                                     record)

                if optimizer.param_groups[0]['lr'] <= min_lr:
                    print('Reached minimum learning rate. Stopping.')
                    stop = True
                    break

            # Embedding visualization
            if enable_embedding:
                if iteration > 0 and (iteration % embedding_interval == 0):
                    embedding_routine(net=net,
                                      device=device,
                                      loader=train_loader_embedding,
                                      iteration=iteration,
                                      tb=tb,
                                      tag=tag + '/train')
                    embedding_routine(net=net,
                                      device=device,
                                      loader=val_loader_embedding,
                                      iteration=iteration,
                                      tb=tb,
                                      tag=tag + '/val')

            iteration += 1

            if iteration > max_num_iterations:
                print('Maximum number of iterations reached')
                stop = True
                break

            # End of iteration

        epoch += 1

    # Needed to flush out last events
    tb.close()

    print('Completed')
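
As in the first example, the entry-point guard and an invocation are not part of the excerpt; the sketch below uses an illustrative filename and placeholder values.

if __name__ == '__main__':
    main()

# Illustrative invocation (placeholder values):
#   python train_triplet.py --net <TripletNetClass> --traindb <dataset> --valdb <dataset> \
#       --face scale --size 224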