Beispiel #1
0
def evaluate(epoch, ema):
    """Return the mean per-batch loss of the EMA-weighted model.

    A fresh ``FFTNetModel`` is built, populated through
    ``ema.assign_ema_model`` from the module-level ``model``, switched to
    eval mode and run over every batch of the module-level ``train_loader``
    with gradient tracking disabled.

    Args:
        epoch: Index of the current epoch. It does not influence the
            returned value; it is kept so existing callers keep working.
        ema: Object exposing ``assign_ema_model(model, ema_model, use_cuda)``.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    ema_model = FFTNetModel(hid_channels=256,
                            out_channels=256,
                            n_layers=c.num_quant,
                            cond_channels=80)
    ema_model = ema.assign_ema_model(model, ema_model, use_cuda)
    ema_model.eval()

    total_loss = 0.0
    num_batches = 0
    # NOTE(review): this iterates ``train_loader`` although the function is
    # an evaluation pass (older commented-out code referenced
    # ``val_loader``) -- confirm which loader is intended.
    with torch.no_grad():
        for batch in train_loader:
            # presumably batch = (wav, mel, lens, target) -- verify against
            # the dataset's collate function.
            wav = batch[0].unsqueeze(1)
            mel = batch[1].transpose(1, 2)
            lens = batch[2]
            target = batch[3]
            if use_cuda:
                wav = wav.cuda()
                mel = mel.cuda()
                target = target.cuda()
            out = ema_model(wav, mel)
            # The criterion returns three values; only the loss is needed.
            loss, _, _ = criterion(out, target, lens)
            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("evaluate() received an empty data loader")
    # Divide by the batch count, not by the last zero-based batch index,
    # which over-estimated the mean and failed for a single batch.
    return total_loss / num_batches
from tqdm import tqdm
from model import FFTNet, FFTNetModel
from generic_utils import count_parameters

# Seed the RNG before any random tensors or model weights are created.
torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
if use_cuda:
    torch.backends.cudnn.benchmark = False

print(" ---- Test FFTNetModel step forward ----")
# Constructor arguments shared by every model instance built in this script.
model_kwargs = dict(hid_channels=256,
                    out_channels=256,
                    n_layers=11,
                    cond_channels=80)
net = FFTNetModel(**model_kwargs)
net.eval()
print(" > Number of model params: ", count_parameters(net))

# One random input and one random conditioning tensor, reused on every step.
x = torch.rand(1, 1, 1)
cx = torch.rand(1, 80, 1)

# Time repeated ``forward_step`` calls; neither the model nor the inputs are
# moved to ``device`` in this section.
n_steps = 20000
time_start = time.time()
with torch.no_grad():
    for i in tqdm(range(n_steps)):
        out = net.forward_step(x, cx)
    time_avg = (time.time() - time_start) / n_steps
print("> Avg time per step inference on CPU: {}".format(time_avg))

# on GPU
# A fresh model instance with the same configuration.
net = FFTNetModel(**model_kwargs)