예제 #1
0
def main():
    """Run test-time generation for each identity sample image.

    Parses options with ``BaseOptions``, picks the video/image or audio
    generator model and matching ``Test_VideoFolder`` dataloader from
    ``opt.test_type``, then for every index in ``test_nums``:

    1. builds a dataloader for ``<opt.test_A_path>/test_sample<i>.jpg``,
    2. loads the checkpoint at ``opt.test_resume_path`` into the model via
       ``util.load_test_checkpoint``,
    3. calls ``model.test_train()`` on the first 5 batches,
    4. calls ``model.test()`` on every batch and saves the current visuals
       under ``web_dir`` with the visualizer, printing the elapsed time.

    Raises:
        ValueError: if ``opt.test_type`` is not 'video', 'image' or 'audio'.
    """
    opt = BaseOptions().parse()

    # Import lazily so only the model/dataloader for the requested modality
    # is loaded.
    if opt.test_type in ('video', 'image'):
        import Test_Gen_Models.Test_Video_Model as Gen_Model
        from Dataloader.Test_load_video import Test_VideoFolder
    elif opt.test_type == 'audio':
        import Test_Gen_Models.Test_Audio_Model as Gen_Model
        from Dataloader.Test_load_audio import Test_VideoFolder
    else:
        # Raising a plain string is itself a TypeError in Python 3 and hides
        # the real problem; raise a proper exception with the bad value.
        raise ValueError('test type select error: %r' % (opt.test_type,))

    opt.nThreads = 1  # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.sequence_length = 1
    test_nums = [1, 2, 3, 4]  # choose input identity images
    num_warmup_batches = 5  # batches passed to model.test_train() per identity

    model = Gen_Model.GenModel(opt)
    start_epoch = opt.start_epoch
    visualizer = Visualizer(opt)
    # Drop the last 19 characters of the checkpoint file name; presumably
    # this strips a '_checkpoint.pth.tar' suffix (19 chars) -- confirm.
    path_name = ntpath.basename(opt.test_resume_path)[:-19]
    web_dir = os.path.join(opt.results_dir, path_name,
                           '%s_%s' % ('test', start_epoch))

    for sample_idx in test_nums:
        A_path = os.path.join(opt.test_A_path,
                              'test_sample' + str(sample_idx) + '.jpg')
        test_folder = Test_VideoFolder(root=opt.test_root,
                                       A_path=A_path,
                                       config=opt)
        test_dataloader = DataLoader(test_folder,
                                     batch_size=1,
                                     shuffle=False,
                                     num_workers=1)
        # Reload the checkpoint for every identity. NOTE(review): presumably
        # so state changed by test_train() for one identity does not leak
        # into the next -- confirm.
        model, _, _ = util.load_test_checkpoint(opt.test_resume_path, model)

        # Only the first few batches are used here; stop as soon as we have
        # them instead of loading the rest of the dataset for nothing.
        for batch_idx, data in enumerate(test_dataloader):
            if batch_idx >= num_warmup_batches:
                break
            model.set_test_input(data)
            model.test_train()

        # Generate and save results for every batch.
        start = time.perf_counter()
        for batch_idx, data in enumerate(test_dataloader):
            model.set_test_input(data)
            model.test()
            visuals = model.get_current_visuals()
            img_path = model.get_image_paths()
            visualizer.save_images_test(web_dir, visuals, img_path, batch_idx,
                                        opt.test_num)
        end = time.perf_counter()
        print('finish processing in %03f seconds' % (end - start))
예제 #2
0
import torch
import torch.nn as nn
import torch.nn.functional as F
from Options_all import BaseOptions
# Module-level options object, parsed once when this module is imported.
# NOTE(review): parsing at import time presumably reads sys.argv, so merely
# importing this module depends on the command line -- confirm intended.
opt = BaseOptions().parse()


def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
    """Return a 3x3 ``nn.Conv2d`` layer.

    With the default ``padding=1`` and ``strd=1`` the spatial size of the
    input is preserved.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        strd: convolution stride.
        padding: zero-padding added on each side.
        bias: whether the layer has a learnable bias (off by default).
    """
    layer_kwargs = dict(kernel_size=3, stride=strd, padding=padding, bias=bias)
    return nn.Conv2d(in_planes, out_planes, **layer_kwargs)


class ConvBlock(nn.Module):
    def __init__(self, in_planes, out_planes):
        super(ConvBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = conv3x3(in_planes, int(out_planes / 2))
        self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
        self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
        self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
        self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))

        if in_planes != out_planes:
            self.downsample = nn.Sequential(
                nn.BatchNorm2d(in_planes),
                nn.ReLU(True),
예제 #3
0
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import inspect, re
import torch.nn as nn
import os
from Options_all import BaseOptions
import collections

# Module-level options object, parsed once when this module is imported.
# NOTE(review): parsing at import time presumably reads sys.argv, so merely
# importing this module depends on the command line -- confirm intended.
config = BaseOptions().parse()


def tensor2im(image_tensor, imtype=np.uint8):
    """Convert the first image of a batched tensor to an HWC array * 255.

    Args:
        image_tensor: batched image tensor; only ``image_tensor[0]`` is
            converted and it must be 3-D for the (1, 2, 0) transpose --
            presumably laid out (N, C, H, W) with values in [0, 1].
        imtype: dtype of the returned array (default ``np.uint8``).

    Returns:
        ``numpy.ndarray`` with the first image's axes reordered (1, 2, 0)
        (i.e. HWC for a CHW image), scaled by 255 and cast to ``imtype``.
        For integer ``imtype`` the values are clipped to that type's range
        first, so out-of-range inputs saturate instead of wrapping around.
    """
    array = image_tensor[0].cpu().float().numpy()
    array = np.transpose(array, (1, 2, 0)) * 255.0
    if np.issubdtype(imtype, np.integer):
        # Casting out-of-range floats to an integer dtype is undefined in
        # numpy (it typically wraps), so saturate to the valid range.
        bounds = np.iinfo(imtype)
        array = np.clip(array, bounds.min, bounds.max)
    return array.astype(imtype)


def tensor2image(image_tensor, imtype=np.uint8):
    """Convert a single (unbatched) image tensor to an HWC array * 255.

    Args:
        image_tensor: 3-D image tensor (required by the (1, 2, 0)
            transpose) -- presumably (C, H, W) with values in [0, 1].
        imtype: dtype of the returned array (default ``np.uint8``).

    Returns:
        ``numpy.ndarray`` with axes reordered (1, 2, 0) (i.e. HWC for a
        CHW input), scaled by 255 and cast to ``imtype``. For integer
        ``imtype`` the values are clipped to that type's range first, so
        out-of-range inputs saturate instead of wrapping around.
    """
    array = image_tensor.cpu().float().numpy()
    array = np.transpose(array, (1, 2, 0)) * 255.0
    if np.issubdtype(imtype, np.integer):
        # Casting out-of-range floats to an integer dtype is undefined in
        # numpy (it typically wraps), so saturate to the valid range.
        bounds = np.iinfo(imtype)
        array = np.clip(array, bounds.min, bounds.max)
    return array.astype(imtype)


def tensor2mfcc(image_tensor, imtype=np.uint8):
    image_numpy = image_tensor[0].cpu().float().numpy()
예제 #4
0
import matplotlib.pyplot as plt
from matplotlib import cm
from torch.utils.data import DataLoader
from Dataloader.Test_load_audio import Test_VideoFolder
import os
from Options_all import BaseOptions
import matlab.engine
# ``wave`` and ``np`` are used below but were never imported in this script.
import wave

import numpy as np

# Start a MATLAB engine session. NOTE(review): this rebinds the name
# ``matlab`` from the imported package to the engine instance.
matlab = matlab.engine.start_matlab()

# Read up to 16000 * 10 frames (10 s if the clip is 16 kHz) of raw audio.
wf = wave.open('0572_0019_0003.wav', 'rb')
wav = wf.readframes(16000 * 10)
# Keep the odd-indexed bytes of the raw frame data.
# NOTE(review): for 16-bit little-endian samples this keeps only the high
# byte of each sample before re-pairing bytes below -- confirm intent.
wav = wav[1::2]
# np.fromstring's binary mode is deprecated (and the numarray-style 'Int16'
# alias is a legacy name); frombuffer + copy gives the same writable int16
# array.
wav = np.frombuffer(wav, dtype=np.int16).copy()

opt = BaseOptions().parse()
opt.nThreads = 1  # test code only supports nThreads = 1
opt.batchSize = 1  # test code only supports batchSize = 1
opt.sequence_length = 1
A_path = os.path.join(opt.test_A_path, 'test_sample' + str(3) + '.jpg')
test_folder = Test_VideoFolder(root='./0572_0019_0003',
                               A_path=A_path,
                               config=opt)
test_dataloader = DataLoader(test_folder, batch_size=1)

# Materialize every (index, batch) pair of the dataloader.
enum = list(enumerate(test_dataloader))

# First batch's 'B_audio' entry, indexed down to a single feature map.
mfcc_bin = enum[0][1]['B_audio'].numpy()[0][0][0]

mfcc_feat = mfcc(wav[0:10000],
                 16000,
import time
from Options_all import BaseOptions
from util import util
from util.visualizer import Visualizer
from torch.utils.data import DataLoader
import os
import ntpath

opt = BaseOptions().parse()
# Select the generator model and dataloader for the requested modality.
# A membership test is required here: ``x == 'video' or 'image'`` is always
# truthy, which would make the audio branch unreachable.
if opt.test_type in ('video', 'image'):
    import Test_Gen_Models.Test_Video_Model as Gen_Model
    from Dataloader.Test_load_video import Test_VideoFolder
elif opt.test_type == 'audio':
    import Test_Gen_Models.Test_Audio_Model as Gen_Model
    from Dataloader.Test_load_audio import Test_VideoFolder
else:
    # Raising a plain string is itself a TypeError in Python 3; raise a
    # proper exception that reports the offending value.
    raise ValueError('test type select error: %r' % (opt.test_type,))

opt.nThreads = 1   # test code only supports nThreads = 1
opt.batchSize = 1  # test code only supports batchSize = 1
opt.sequence_length = 1
test_nums = [1, 2, 3, 4]  # choose input identity images

model = Gen_Model.GenModel(opt)
start_epoch = opt.start_epoch
visualizer = Visualizer(opt)
# Drop the last 19 characters of the checkpoint file name; presumably this
# strips a '_checkpoint.pth.tar' suffix (19 chars) -- confirm.
path_name = ntpath.basename(opt.test_resume_path)[:-19]
web_dir = os.path.join(opt.results_dir, path_name, '%s_%s' % ('test', start_epoch))
for i in test_nums: