Example #1
def main():
    # Evaluate a trained Xception or MesoInception model on the validation split
    # (the test-set evaluation below is left commented out).
    weight_path = "weight"
    filename = opt.data + "_" + opt.compression + "_" + opt.mode
    if opt.data == "All":
        if opt.model == "xception":
            model = net(num_class=5)
        else:
            model = MesoInception(out_channel=5)
    else:
        if opt.model == "xception":
            model = net(num_class=2)
        else:
            model = MesoInception(out_channel=2)

    model = torch.nn.DataParallel(model, device_ids=[0]).to(device)
    if opt.model == "xception":
        model.load_state_dict(
            torch.load(os.path.join(weight_path, "model_%s_%d.pth" % (filename, opt.checkpoint)), map_location=device))

    else:
        model.load_state_dict(
            torch.load(os.path.join(weight_path, "model_Mesonet_%s_%d.pth" % (filename, opt.checkpoint)),
                       map_location=device))

    dataset_val = Face(mode='val', resize=256 if opt.model == "mesonet" else 299, filename=filename)
    dataset_test = Face(mode='test', resize=256 if opt.model == "mesonet" else 299, filename=filename)
    loader_val = DataLoader(dataset_val, batch_size=opt.batchsize, num_workers=8)
    loader_test = DataLoader(dataset_test, batch_size=opt.batchsize, num_workers=8)

    # print(next(iter(loader_val))[0].shape, next(iter(loader_val))[1].shape)
    # torch.Size([32, 5, 3, 224, 224])
    # torch.Size([32])

    if opt.data == "All":
        TP, TN, FP, FN, correct_0, correct_1, correct_2, correct_3, correct_4, total_0, total_1, total_2, total_3, total_4 = evaluate_all(model, loader_val)
        print("model_%s_%d val:\nOri: %d/%d = %.6f\nDeepfakes: %d/%d = %.6f\nFace2Face: %d/%d = %.6f\nFaceSwap: %d/%d = %.6f\nNeuralTextures: %d/%d = %.6f\nTP:%d, TN:%d, FP:%d, FN:%d\n acc:%.6f" % (
                filename, opt.checkpoint,
                correct_0, total_0, correct_0 / total_0, correct_1, total_1, correct_1 / total_1, correct_2, total_2,
                correct_2 / total_2, correct_3, total_3, correct_3 / total_3, correct_4, total_4, correct_4 / total_4,
                TP, TN, FP, FN, (TP + TN) / (TP + TN + FP + FN)))  # accuracy = (TP+TN) / all samples
        # TP, TN, FP, FN, correct_0, correct_1, correct_2, correct_3, correct_4, total_0, total_1, total_2, total_3, total_4 = evaluate_all(model, loader_test)
        # print("model_%s_%d test:\nOri: %d/%d = %.6f\nDeepfakes: %d/%d = %.6f\nFace2Face: %d/%d = %.6f\nFaceSwap: %d/%d = %.6f\nNeuralTextures: %d/%d = %.6f\nTP:%d, TN:%d, FP:%d, FN:%d\n acc:%.6f" % (
        #         filename, opt.checkpoint,
        #         correct_0, total_0, correct_0 / total_0, correct_1, total_1, correct_1 / total_1, correct_2, total_2,
        #         correct_2 / total_2, correct_3, total_3, correct_3 / total_3, correct_4, total_4, correct_4 / total_4,
        #         TP, TN, FP, FN, (TP + TN) / (TP + TN + FP + FN)))
    else:
        val_acc = evaluate(model, loader_val)
        # val_test = evaluate(model, loader_test)
        val_test = 100  # placeholder; the test-set evaluation above is commented out
        print("model_%s_%d:   val_acc:%.6f     test_acc:%.6f" % (filename, opt.checkpoint, val_acc, val_test))
Example #2
def main():
    # Train the classifier, log loss/accuracy curves to Visdom, and save a
    # checkpoint after every epoch.
    if opt.data == "All":
        model = net(num_class=5)
    else:
        model = net(num_class=2)
    model = torch.nn.DataParallel(model, device_ids=[0, 1]).to(device)
    optimizer = optim.Adam(model.parameters(), lr=5e-4)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 40], gamma=0.1)
    criterion = nn.CrossEntropyLoss()

    filename = opt.data + "_" + opt.compression + "_" + opt.mode
    dataset_train = Face(mode='train', filename=filename)
    dataset_val = Face(mode='val', filename=filename)
    loader_train = DataLoader(dataset_train, batch_size=opt.batchsize, shuffle=True, num_workers=8, drop_last=True)
    loader_val = DataLoader(dataset_val, batch_size=opt.batchsize, num_workers=8)

    dataset_step = len(loader_train)  # number of batches per epoch (drop_last=True)

    viz = visdom.Visdom(port=13680)
    weight_path = "weight"
    if opt.checkpoint != 0:
        # resume from an earlier checkpoint
        model.load_state_dict(torch.load(os.path.join(weight_path, "model_%s_%d.pth" % (filename, opt.checkpoint))))
    print('checkpoint: %d' % (opt.checkpoint))

    # initialise the Visdom curves at the (possibly resumed) starting point
    viz.line([0.2], [dataset_step * opt.checkpoint], win='loss', opts=dict(title='loss'))
    viz.line([0.9], [opt.checkpoint], win='val_acc', opts=dict(title='val_acc'))
    viz.line([0.9], [opt.checkpoint], win='train_acc', opts=dict(title='train_acc'))
    global_step = dataset_step*opt.checkpoint

    model.train()
    for epoch in range(opt.checkpoint, opt.epochs):
        train_acc_all = 0
        for step, (x,y) in enumerate(loader_train):
            x, y = x.to(device), y.to(device)
            y_hat = model(x)

            loss = criterion(y_hat, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pred = y_hat.argmax(dim=1)
            train_acc_all += torch.eq(pred, y).sum().float().item()

            if step % 10 == 0 and step != 0:
                viz.line([loss.item()], [global_step], win='loss', update='append')
                print("[epoch %d][%d/%d]  \n all_loss: %.6f \n" %(epoch + 1, step, dataset_step, loss.item()))
            global_step +=1

        if epoch % 1 == 0:
            print("****************************\nepoch:{} train_acc:{}\n***************************************\n".format(epoch + 1, train_acc_all/((step+1)*opt.batchsize)))
            val_acc = evaluate(model, loader_val)
            print("****************************\nepoch:{} val_acc:{}\n***************************************\n".format(epoch + 1, val_acc))
            viz.line([val_acc], [epoch+1], win='val_acc', update='append')
            viz.line([train_acc_all/((step+1)*opt.batchsize)], [epoch+1], win='train_acc', update='append')
            torch.save(model.state_dict(), os.path.join(weight_path, "model_%s_%d.pth" %(filename, epoch+1)))
            with open("log/%s.txt" %(filename), "a+") as f:
                f.write("epoch%d:   train_acc:%.6f,   val_acc:%.6f \n" % (epoch+1, train_acc_all/((step+1)*opt.batchsize), val_acc))
            scheduler.step()
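
None of the examples show how opt is built; presumably it comes from argparse. A hypothetical sketch with attribute names mirroring the opt.* fields these scripts read (the defaults are only illustrative guesses):

import argparse

def parse_opt():
    # Hypothetical parser; only the attribute names are taken from the examples.
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", default="Deepfakes")      # or "All", "Face2Face", ...
    parser.add_argument("--compression", default="c23")
    parser.add_argument("--mode", default="face")
    parser.add_argument("--model", default="xception")      # or "mesonet"
    parser.add_argument("--batchsize", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--checkpoint", type=int, default=0)
    parser.add_argument("--checkpoint_rgb", type=int, default=0)
    parser.add_argument("--checkpoint_diff", type=int, default=0)
    parser.add_argument("--checkpoint_xcep", type=int, default=0)
    return parser.parse_args()

opt = parse_opt()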
Example #3
def main():
    # Evaluate the temporal (TSM) RGB and RGB-diff models together with the
    # Xception model, per manipulation class or as a binary ensemble.
    weight_path = "weight"
    filename = opt.data + "_" + opt.compression + "_" + opt.mode
    if opt.data == "All":
        model_rgb = TSM(num_classes=5, n_segment=5)
        model_diff = TSM(num_classes=5, n_segment=4)
        model_xcep = net(num_class=5)
    else:
        model_rgb = TSM(num_classes=2, n_segment=5)
        model_diff = TSM(num_classes=2, n_segment=4)
        model_xcep = net(num_class=2)

    model_rgb = torch.nn.DataParallel(model_rgb, device_ids=[0]).to(device)
    model_diff = torch.nn.DataParallel(model_diff, device_ids=[0]).to(device)
    model_xcep = torch.nn.DataParallel(model_xcep, device_ids=[0]).to(device)

    model_rgb.load_state_dict(
        torch.load(os.path.join(
            weight_path,
            "tsm_%s_%s_%d.pth" % ('rgb', filename, opt.checkpoint_rgb)),
                   map_location=device))
    model_diff.load_state_dict(
        torch.load(os.path.join(
            weight_path,
            "tsm_%s_%s_%d.pth" % ('rgbdiff', filename, opt.checkpoint_diff)),
                   map_location=device))
    model_xcep.load_state_dict(
        torch.load(os.path.join(
            weight_path, "model_%s_%d.pth" % (filename, opt.checkpoint_xcep)),
                   map_location=device))

    dataset_val_rgb = Face(mode='val',
                           resize=224,
                           filename=filename,
                           modality='rgb')
    dataset_val_diff = Face(mode='val',
                            resize=224,
                            filename=filename,
                            modality='rgbdiff')
    dataset_val_xcep = Face(mode='val', resize=299, filename=filename)
    loader_val_rgb = DataLoader(dataset_val_rgb,
                                batch_size=opt.batchsize,
                                num_workers=8)
    loader_val_diff = DataLoader(dataset_val_diff,
                                 batch_size=opt.batchsize,
                                 num_workers=8)
    loader_val_xcep = DataLoader(dataset_val_xcep,
                                 batch_size=opt.batchsize,
                                 num_workers=8)

    if opt.data == "All":
        pass
        TP, TN, FP, FN, correct_0, correct_1, correct_2, correct_3, correct_4, total_0, total_1, total_2, total_3, total_4 = evaluate_all(
            model_rgb, model_diff, loader_val_rgb, loader_val_diff)
        print(
            "model_%s_%d val:\nOri: %d/%d = %.6f\nDeepfakes: %d/%d = %.6f\nFace2Face: %d/%d = %.6f\nFaceSwap: %d/%d = %.6f\nNeuralTextures: %d/%d = %.6f\nTP:%d, TN:%d, FP:%d, FN:%d\n acc:%.6f"
            %
            (filename, opt.checkpoint_rgb, correct_0, total_0, correct_0 /
             total_0, correct_1, total_1, correct_1 / total_1, correct_2,
             total_2, correct_2 / total_2, correct_3, total_3, correct_3 /
             total_3, correct_4, total_4, correct_4 / total_4, TP, TN, FP, FN,
             (TP + TN) / (TP + TN + FP + FN)))  # accuracy = (TP+TN) / all samples

    else:
        rgb_acc, diff_acc, xcep_acc, avg_acc = evaluate(
            model_rgb, model_diff, model_xcep, loader_val_rgb, loader_val_diff,
            loader_val_xcep)
        print(
            "model_%s_%d:   rgb_acc:%.6f  diff_acc:%.6f  xcep_acc:%.6f avg_acc:%.6f"
            % (filename, opt.checkpoint_rgb, rgb_acc, diff_acc, xcep_acc,
               avg_acc))
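
The three-model evaluate used in the else branch is also not shown. A minimal sketch, assuming the unshuffled loaders iterate the same samples in the same order, the combined prediction averages the three softmax outputs, and device is the script's module-level torch.device:

import torch
import torch.nn.functional as F

@torch.no_grad()
def evaluate(model_rgb, model_diff, model_xcep, loader_rgb, loader_diff, loader_xcep):
    # Assumed helper: per-model accuracy plus the accuracy of the averaged
    # softmax scores; relies on the three loaders being aligned sample-for-sample.
    for m in (model_rgb, model_diff, model_xcep):
        m.eval()
    correct, total = [0, 0, 0, 0], 0
    for (x_r, y), (x_d, _), (x_x, _) in zip(loader_rgb, loader_diff, loader_xcep):
        x_r, x_d, x_x, y = x_r.to(device), x_d.to(device), x_x.to(device), y.to(device)
        probs = [F.softmax(model_rgb(x_r), dim=1),
                 F.softmax(model_diff(x_d), dim=1),
                 F.softmax(model_xcep(x_x), dim=1)]
        preds = [p.argmax(dim=1) for p in probs] + [sum(probs).argmax(dim=1)]
        for i, p in enumerate(preds):
            correct[i] += (p == y).sum().item()
        total += y.size(0)
    return tuple(c / total for c in correct)  # rgb_acc, diff_acc, xcep_acc, avg_acc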