示例#1
0
        'frame': row.frame,
        'path': os.path.join(path_test_dir, row.path)
    } for _, row in test_dataset_paths.iterrows()]

    image_dataset = TestAntispoofDataset(paths=paths,
                                         output_shape=output_shape)
    dataloader = DataLoader(image_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=False,
                            num_workers=8)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # load model
    # model = Model(base_model=resnet34(pretrained=False))
    model = DoubleLossModelTwoHead(
        base_model=EfficientNet.from_name('efficientnet-b3')).to(device)
    model.load_state_dict(torch.load(PATH_MODEL, map_location=device))
    # model = torch.load(PATH_MODEL)
    model = model.to(device)
    model.eval()

    if USE_TTA:
        samples, frames, probabilities1, probabilities2, probabilities3 = [], [], [], [], []
        with torch.no_grad():
            for video, frame, batch in dataloader:
                batch = batch.to(device)
                _, prob1 = model(batch)
                _, prob2 = model(batch.flip(2))  # Vertical
                _, prob3 = model(batch.flip(3))  # Horizontal
                # _, prob4 = model(batch)
def train(model_name, optim='adam'):
    """Fine-tune a pretrained two-head model with SWA and log to TensorBoard.

    The training data comes from ``PretrainDataset``; validation uses fold 0
    of ``IDRND_dataset_CV`` with the mode taken from ``config['mode']`` where
    ``'train'`` is replaced by ``'val'``. The model is a
    ``DoubleLossModelTwoHead`` whose weights are restored from a checkpoint
    under ``../models_weights/pretrained/``. The optimizer is wrapped in
    ``SWA`` and driven by an ``ExponentialLR`` schedule (gamma 0.9).

    Epochs 0-4 are not trained: only the scheduler is stepped for them.
    Every trained epoch ends with a call to ``evaluate``.

    Relies on module-level names defined elsewhere in the file: ``config``,
    ``device``, ``train_writer`` and ``evaluate``.

    Args:
        model_name: ``'EF'`` (``EfficientNet`` backbone) or ``'EFGAP'``
            (``EfficientNetGAP`` backbone), both ``efficientnet-b3``.
        optim: ``'adam'`` selects Adam, ``'sgd'`` selects plain SGD, and any
            other value selects SGD with Nesterov momentum 0.9.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # Function-scope import keeps this block self-contained.
    import logging

    logger = logging.getLogger(__name__)

    train_dataset = PretrainDataset(output_shape=config['image_resolution'])
    train_loader = DataLoader(train_dataset,
                              batch_size=config['batch_size'],
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True,
                              drop_last=True)

    val_dataset = IDRND_dataset_CV(fold=0,
                                   mode=config['mode'].replace('train', 'val'),
                                   double_loss_mode=True,
                                   output_shape=config['image_resolution'])
    val_loader = DataLoader(val_dataset,
                            batch_size=config['batch_size'],
                            shuffle=True,
                            num_workers=4,
                            drop_last=False)

    # Build the backbone and restore the pretrained checkpoint.
    # map_location makes the load work even when the checkpoint was saved on
    # a device that is not available here (e.g. GPU checkpoint on a CPU box).
    if model_name == 'EF':
        model = DoubleLossModelTwoHead(base_model=EfficientNet.from_pretrained(
            'efficientnet-b3')).to(device)
        model.load_state_dict(
            torch.load(
                f"../models_weights/pretrained/{model_name}_{4}_2.0090592697255896_1.0.pth",
                map_location=device))
    elif model_name == 'EFGAP':
        model = DoubleLossModelTwoHead(
            base_model=EfficientNetGAP.from_pretrained('efficientnet-b3')).to(
                device)
        model.load_state_dict(
            torch.load(
                f"../models_weights/pretrained/{model_name}_{4}_2.3281182915644134_1.0.pth",
                map_location=device))
    else:
        # Without this branch `model` would be unbound and the failure would
        # surface later as an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected 'EF' or 'EFGAP'")

    # One loss per model head; they are averaged with equal weight below.
    criterion = FocalLoss(add_weight=False).to(device)
    criterion4class = CrossEntropyLoss().to(device)

    if optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config['learning_rate'],
                                     weight_decay=config['weight_decay'])
    elif optim == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=config['learning_rate'],
                                    weight_decay=config['weight_decay'],
                                    nesterov=False)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    momentum=0.9,
                                    lr=config['learning_rate'],
                                    weight_decay=config['weight_decay'],
                                    nesterov=True)

    # NOTE(review): the reason for subtracting 15 steps is not visible in
    # this file - confirm before changing.
    steps_per_epoch = len(train_loader) - 15
    swa = SWA(optimizer,
              swa_start=config['swa_start'] * steps_per_epoch,
              swa_freq=int(config['swa_freq'] * steps_per_epoch),
              swa_lr=config['learning_rate'] / 10)
    scheduler = ExponentialLR(swa, gamma=0.9)
    # scheduler = StepLR(swa, step_size=5*steps_per_epoch, gamma=0.5)

    # Metric failures are reported once instead of on every step so that the
    # progress bar is not flooded.
    metrics_failure_logged = False

    global_step = 0
    # NOTE(review): the epoch count (10) is hard-coded while the SWA condition
    # below reads config['number_epochs'] - confirm the two are meant to agree.
    for epoch in trange(10):
        if epoch < 5:
            # Only advance the LR schedule for the first five epochs;
            # presumably they were already trained in the loaded checkpoint
            # (its file name carries epoch index 4) - TODO confirm.
            scheduler.step()
            continue
        model.train()
        train_bar = tqdm(train_loader)
        train_bar.set_description_str(desc=f"N epochs - {epoch}")

        for step, batch in enumerate(train_bar):
            global_step += 1
            image = batch['image'].to(device)
            label4class = batch['label0'].to(device)
            label = batch['label1'].to(device)

            output4class, output = model(image)
            loss4class = criterion4class(output4class, label4class)
            loss = criterion(output.squeeze(), label)
            swa.zero_grad()
            total_loss = loss4class * 0.5 + loss * 0.5
            total_loss.backward()
            swa.step()
            train_writer.add_scalar(tag="learning_rate",
                                    scalar_value=scheduler.get_lr()[0],
                                    global_step=global_step)
            train_writer.add_scalar(tag="BinaryLoss",
                                    scalar_value=loss.item(),
                                    global_step=global_step)
            train_writer.add_scalar(tag="SoftMaxLoss",
                                    scalar_value=loss4class.item(),
                                    global_step=global_step)
            train_bar.set_postfix_str(f"Loss = {loss.item()}")
            try:
                train_writer.add_scalar(tag="idrnd_score",
                                        scalar_value=idrnd_score_pytorch(
                                            label, output),
                                        global_step=global_step)
                train_writer.add_scalar(tag="far_score",
                                        scalar_value=far_score(label, output),
                                        global_step=global_step)
                train_writer.add_scalar(tag="frr_score",
                                        scalar_value=frr_score(label, output),
                                        global_step=global_step)
                train_writer.add_scalar(tag="accuracy",
                                        scalar_value=bce_accuracy(
                                            label, output),
                                        global_step=global_step)
            except Exception:
                # Metric logging stays best-effort (training must not stop
                # because of it), but the first failure is surfaced so that a
                # permanently broken metric does not go unnoticed.
                if not metrics_failure_logged:
                    logger.warning(
                        "Skipping train metrics at step %d; further metric "
                        "failures will not be reported",
                        global_step,
                        exc_info=True)
                    metrics_failure_logged = True

        if (epoch > config['swa_start']
                and epoch % 2 == 0) or (epoch == config['number_epochs'] - 1):
            # Swap weights, run bn_update over the training data, then swap
            # back so training continues from the weights it had before.
            swa.swap_swa_sgd()
            swa.bn_update(train_loader, model, device)
            swa.swap_swa_sgd()

        scheduler.step()
        evaluate(model, val_loader, epoch, model_name)