示例#1
0
def train_fx_net(model, optimizer, train_loader, train_sampler, epoch, loss_function=nn.CrossEntropyLoss(), device='cpu'):
    """Run a single training epoch of the effect-classification network.

    Args:
        model: network mapping mel-spectrogram batches to class logits.
        optimizer: optimizer that updates ``model``'s parameters.
        train_loader: yields (mels, labels, settings, filenames, indeces) batches.
        train_sampler: sampler whose length gives the number of samples per epoch.
        epoch: current epoch number (used for logging only).
        loss_function: classification criterion (default: cross-entropy).
        device: device string/object the input tensors are moved to.

    Returns:
        (total_loss, total_correct, results) where ``results`` contains one
        (dataset_index, filename, predicted_label, true_label) tuple per sample.
    """
    model.train()
    train_set_size = len(train_sampler)
    total_loss = 0
    total_correct = 0
    results = []

    for batch_num, (mels, labels, settings, filenames, indeces) in enumerate(train_loader):
        mels = mels.to(device)
        labels = labels.to(device)

        preds = model(mels)                    # forward pass over the batch
        loss = loss_function(preds, labels)

        # standard train step: reset grads, backprop, update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_correct += utils.get_num_correct_labels(preds, labels)

        # per-sample bookkeeping: predicted class is the argmax over logits
        batch_label_preds = preds.argmax(dim=1)
        for pos, fname in enumerate(filenames):
            results.append(
                (indeces[pos].item(),
                 fname,
                 batch_label_preds[pos].item(),
                 labels[pos].item()))

        # periodic progress line (every 50 batches, skipping batch 0)
        if batch_num > 0 and batch_num % 50 == 0:
            print('Train Epoch: {}\t[{}/{} ({:.0f}%)]\tTotal Loss: {:.4f}\tAvg Loss: {:.4f}'.format(
                        epoch, # epoch
                        batch_num * len(labels),
                        train_set_size,
                        100. * batch_num / len(train_loader), # % completion
                        total_loss,
                        total_loss / (batch_num * len(labels))))

    print('====> Epoch: {}\tTotal Loss: {:.4f}\t Avg Loss: {:.4f}\tCorrect: {:.0f}/{}\tPercentage Correct: {:.2f}'.format(
            epoch,
            total_loss,
            total_loss / train_set_size,
            total_correct,
            train_set_size,
            100 * total_correct / train_set_size))

    return total_loss, total_correct, results
示例#2
0
def test_fx_net(model, test_loader, test_sampler, loss_function=nn.CrossEntropyLoss(), device='cpu'):
    """Evaluate the effect-classification network over the full test set.

    Args:
        model: network mapping mel-spectrogram batches to class logits.
        test_loader: yields (mels, labels, settings, filenames, indeces) batches.
        test_sampler: sampler whose length gives the number of test samples.
        loss_function: classification criterion (default: cross-entropy).
        device: device string/object the input tensors are moved to.

    Returns:
        (total_loss, total_correct, results) where ``results`` contains one
        (dataset_index, filename, predicted_label, true_label) tuple per sample.
    """
    model.eval()
    test_set_size = len(test_sampler)
    total_loss = 0
    total_correct = 0
    results = []

    # no gradients needed for evaluation
    with torch.no_grad():
        for batch_num, (mels, labels, settings, filenames, indeces) in enumerate(test_loader):
            mels = mels.to(device)
            labels = labels.to(device)

            preds = model(mels)                    # forward pass over the batch
            loss = loss_function(preds, labels)

            total_loss += loss.item()
            total_correct += utils.get_num_correct_labels(preds, labels)

            # per-sample bookkeeping: predicted class is the argmax over logits
            batch_label_preds = preds.argmax(dim=1)
            for pos, fname in enumerate(filenames):
                results.append(
                    (indeces[pos].item(),
                     fname,
                     batch_label_preds[pos].item(),
                     labels[pos].item()))

    print('====> Test Loss: {:.4f}\t Avg Loss: {:.4f}\tCorrect: {:.0f}/{}\tPercentage Correct: {:.2f}'.format(
        total_loss,
        total_loss / test_set_size,
        total_correct,
        test_set_size,
        100 * total_correct / test_set_size))

    return total_loss, total_correct, results
示例#3
0
def test_cond_nets(model_fx, model_set, 
                    test_loader, test_sampler,
                    loss_function_fx=nn.CrossEntropyLoss(), loss_function_set=nn.L1Loss(), 
                    device='cpu'):
    """Evaluate the conditioned network pair: FxNet classifies the effect,
    then SettingsNet predicts settings conditioned on FxNet's predicted label.

    Args:
        model_fx: network mapping mels to effect-class logits.
        model_set: network mapping (mels, predicted labels) to settings.
        test_loader: yields (mels, labels, settings, filenames, indeces) batches.
        test_sampler: sampler whose length gives the number of test samples.
        loss_function_fx: classification criterion (default: cross-entropy).
        loss_function_set: settings regression criterion (default: L1).
        device: device string/object the input tensors are moved to.

    Returns:
        (total_loss, total_correct, results) where ``results`` holds one
        (index, filename, pred_label, true_label, pred_settings, true_settings)
        tuple per sample, settings rounded to 3 decimals as numpy arrays.
    """
    model_fx.eval()
    model_set.eval()
    test_set_size = len(test_sampler)
    total_loss_fx = 0
    total_loss_set = 0
    total_loss = 0
    total_correct_fx = 0
    total_correct_set = 0
    total_correct = 0
    results = []
    
    with torch.no_grad():
        for batch_idx, data in enumerate(test_loader):
            mels, labels, settings, filenames, indeces = data

            mels = mels.to(device)
            labels = labels.to(device)
            settings = settings.to(device)

            # predictions and loss for FxNet 
            preds_fx = model_fx(mels)
            loss_fx = loss_function_fx(preds_fx, labels)

            total_loss_fx += loss_fx.item()
            correct_fx = utils.get_num_correct_labels(preds_fx, labels)
            total_correct_fx += correct_fx

            # predictions and loss for SettingsNet
            cond_set = preds_fx.argmax(dim=1) # calculate labels for conditioning of setnet 
            preds_set = model_set(mels, cond_set)  # pass batch and labels for conditioning
            loss_set = loss_function_set(preds_set, settings)  # calculate loss

            total_loss_set += loss_set.item()
            correct_set = utils.get_num_correct_settings(preds_set, settings)
            total_correct_set += correct_set

            # loss for both networks
            loss = loss_fx + loss_set
            total_loss += loss.item()

            correct = utils.get_num_correct(preds_fx, preds_set, labels, settings)
            total_correct += correct
            
            for idx, filename in enumerate(filenames):
                # .cpu() before .numpy(): tensors on a CUDA device cannot be
                # converted to numpy directly (was a TypeError when device != 'cpu')
                results.append(
                    (indeces[idx].item(), 
                    filename, 
                    preds_fx[idx].argmax().item(),
                    labels[idx].item(),
                    np.round(preds_set[idx].detach().cpu().numpy(), 3),
                    np.round(settings[idx].detach().cpu().numpy(), 3)))
    
    print('====> Test Loss: {:.4f}'
                '\t Avg Loss: {:.4f}'
                '\t Fx Loss: {:.4f}'
                '\t Set Loss: {:.4f}'
                '\n\t\tCorrect: {:.0f}/{:.0f}'
                '\tFx Correct: {:.0f}/{}'
                '\tSet Correct: {:.0f}/{}'
                '\n\t\tPercentage Correct: {:.2f}'
                '\tPercentage Fx Correct: {:.2f}'
                '\tPercentage Set Correct: {:.2f}'.format(
                    total_loss,
                    total_loss / test_set_size,
                    total_loss_fx,
                    total_loss_set,
                    total_correct,
                    test_set_size,
                    total_correct_fx,
                    test_set_size,
                    total_correct_set,
                    test_set_size,
                    100 * total_correct / test_set_size,
                    100 * total_correct_fx / test_set_size,
                    100 * total_correct_set / test_set_size))
    
    return total_loss, total_correct, results
示例#4
0
def train_cond_nets(model_fx, model_set, 
                    optimizer_fx, optimizer_set, 
                    train_loader, train_sampler, epoch,
                    loss_function_fx=nn.CrossEntropyLoss(), loss_function_set=nn.L1Loss(), 
                    device='cpu'):
    """Run one training epoch for the conditioned network pair: FxNet learns
    the effect class, then SettingsNet learns settings conditioned on FxNet's
    predicted label. Each network is stepped by its own optimizer.

    Args:
        model_fx: network mapping mels to effect-class logits.
        model_set: network mapping (mels, predicted labels) to settings.
        optimizer_fx / optimizer_set: optimizers for the respective models.
        train_loader: yields (mels, labels, settings, filenames, indeces) batches.
        train_sampler: sampler whose length gives samples per epoch.
        epoch: current epoch number (for logging only).
        loss_function_fx: classification criterion (default: cross-entropy).
        loss_function_set: settings regression criterion (default: L1).
        device: device string/object the input tensors are moved to.

    Returns:
        (total_loss, total_correct, results) where ``results`` holds one
        (index, filename, pred_label, true_label, pred_settings, true_settings)
        tuple per sample, settings rounded to 3 decimals as numpy arrays.
    """
    model_fx.train()
    model_set.train()
    train_set_size = len(train_sampler)
    total_loss_fx = 0
    total_loss_set = 0
    total_loss = 0
    total_correct_fx = 0
    total_correct_set = 0
    total_correct = 0
    results = []

    for batch_idx, data in enumerate(train_loader):
        mels, labels, settings, filenames, indeces = data
        
        mels = mels.to(device)
        labels = labels.to(device)
        settings = settings.to(device)
        
        # predictions, loss and gradient for FxNet 
        preds_fx = model_fx(mels)
        loss_fx = loss_function_fx(preds_fx, labels)

        optimizer_fx.zero_grad()
        loss_fx.backward()
        optimizer_fx.step()

        total_loss_fx += loss_fx.item()
        correct_fx = utils.get_num_correct_labels(preds_fx, labels)
        total_correct_fx += correct_fx
        
        # predictions, loss and gradient for SettingsNet
        # (argmax is non-differentiable, so no gradient flows back into FxNet)
        cond_set = preds_fx.argmax(dim=1) # calculate labels for conditioning of setnet
        preds_set = model_set(mels, cond_set)  # pass batch and labels for conditioning
        loss_set = loss_function_set(preds_set, settings)  # calculate loss

        optimizer_set.zero_grad()
        loss_set.backward()
        optimizer_set.step()

        total_loss_set += loss_set.item()
        correct_set = utils.get_num_correct_settings(preds_set, settings)
        total_correct_set += correct_set

        # predictions and loss for both networks
        loss = loss_fx + loss_set
        total_loss += loss.item()

        correct = utils.get_num_correct(preds_fx, preds_set, labels, settings)
        total_correct += correct
        
        for idx, filename in enumerate(filenames):
            # .cpu() before .numpy(): tensors on a CUDA device cannot be
            # converted to numpy directly (was a TypeError when device != 'cpu')
            results.append(
                (indeces[idx].item(), 
                filename, 
                preds_fx[idx].argmax().item(),
                labels[idx].item(),
                np.round(preds_set[idx].detach().cpu().numpy(), 3),
                np.round(settings[idx].detach().cpu().numpy(), 3)))
    
        if batch_idx > 0 and batch_idx % 50 == 0:
            print('Train Epoch: {}\t[{}/{} ({:.0f}%)]\tTotal Loss: {:.4f}\tAvg Loss: {:.4f}'.format(
                        epoch, # epoch
                        batch_idx * len(labels), 
                        train_set_size,
                        100. * batch_idx / len(train_loader), # % completion
                        total_loss,
                        total_loss / (batch_idx * len(labels))))

    print('====> Epoch: {}'
                        '\tTotal Loss: {:.4f}'
                        '\t Avg Loss: {:.4f}'
                        '\t Fx Loss: {:.4f}'
                        '\t Set Loss: {:.4f}'
                        '\n\t\tCorrect: {:.0f}/{}'
                        '\tFx Correct: {:.0f}/{}'
                        '\tSet Correct: {:.0f}/{}'
                        '\n\t\tPercentage Correct: {:.2f}'
                        '\tPercentage Fx Correct: {:.2f}'
                        '\tPercentage Set Correct: {:.2f}'.format(
                            epoch,
                            total_loss,
                            total_loss / train_set_size,
                            total_loss_fx,
                            total_loss_set,
                            total_correct,
                            train_set_size,
                            total_correct_fx,
                            train_set_size,
                            total_correct_set,
                            train_set_size,
                            100 * total_correct / train_set_size,
                            100 * total_correct_fx / train_set_size,
                            100 * total_correct_set / train_set_size))
    
    return total_loss, total_correct, results
示例#5
0
def val_multi_net(model, val_loader, val_sampler, loss_function_fx=nn.CrossEntropyLoss(), loss_function_set=nn.L1Loss(), device='cpu'):
    """Validate the multi-head network that jointly predicts the effect class
    and the settings from a single forward pass.

    Args:
        model: network returning (class logits, settings predictions).
        val_loader: yields (mels, labels, settings, filenames, indeces) batches.
        val_sampler: sampler whose length gives the number of validation samples.
        loss_function_fx: classification criterion (default: cross-entropy).
        loss_function_set: settings regression criterion (default: L1).
        device: device string/object the input tensors are moved to.

    Returns:
        (total_loss, total_correct, results) where ``results`` holds one
        (index, filename, pred_label, true_label, pred_settings, true_settings)
        tuple per sample, settings rounded to 3 decimals as numpy arrays.
    """
    model.eval()
    val_set_size = len(val_sampler)
    total_loss_fx = 0
    total_loss_set = 0
    total_loss = 0
    total_correct_fx = 0
    total_correct_set = 0
    total_correct = 0
    results = []

    with torch.no_grad():
        for batch_idx, data in enumerate(val_loader):
            mels, labels, settings, filenames, indeces = data
            
            mels = mels.to(device)
            labels = labels.to(device)
            settings = settings.to(device)
            
            # single forward pass yields both heads' outputs
            preds_fx, preds_set = model(mels)
            loss_fx = loss_function_fx(preds_fx, labels)
            loss_set = loss_function_set(preds_set, settings)
            loss = loss_fx + loss_set

            total_loss_fx += loss_fx.item()
            total_loss_set += loss_set.item()
            total_loss += loss.item()
            correct_fx = utils.get_num_correct_labels(preds_fx, labels)
            correct_set = utils.get_num_correct_settings(preds_set, settings)
            correct = utils.get_num_correct(preds_fx, preds_set, labels, settings)
            total_correct_fx += correct_fx
            total_correct_set += correct_set
            total_correct += correct
            
            for idx, filename in enumerate(filenames):
                # .cpu() before .numpy(): tensors on a CUDA device cannot be
                # converted to numpy directly (was a TypeError when device != 'cpu')
                results.append(
                    (indeces[idx].item(), 
                    filename, 
                    preds_fx[idx].argmax().item(),
                    labels[idx].item(),
                    np.round(preds_set[idx].detach().cpu().numpy(), 3),
                    np.round(settings[idx].detach().cpu().numpy(), 3)))
    
    print('====> Val Loss: {:.4f}'
                '\t Avg Loss: {:.4f}'
                '\t Fx Loss: {:.4f}'
                '\t Set Loss: {:.4f}'
                '\n\t\tCorrect: {:.0f}/{:.0f}'
                '\tFx Correct: {:.0f}/{}'
                '\tSet Correct: {:.0f}/{}'
                '\n\t\tPercentage Correct: {:.2f}'
                '\tPercentage Fx Correct: {:.2f}'
                '\tPercentage Set Correct: {:.2f}'.format(
                    total_loss,
                    total_loss / val_set_size,
                    total_loss_fx,
                    total_loss_set,
                    total_correct,
                    val_set_size,
                    total_correct_fx,
                    val_set_size,
                    total_correct_set,
                    val_set_size,
                    100 * total_correct / val_set_size,
                    100 * total_correct_fx / val_set_size,
                    100 * total_correct_set / val_set_size))
    
    return total_loss, total_correct, results
示例#6
0
def train_multi_net(model, optimizer, train_loader, train_sampler, epoch, 
                    loss_function_fx=nn.CrossEntropyLoss(), loss_function_set=nn.L1Loss(), device='cpu'):
    """Run one training epoch of the multi-head network that jointly predicts
    the effect class and the settings; both heads share one optimizer and are
    trained on the summed loss.

    Args:
        model: network returning (class logits, settings predictions).
        optimizer: optimizer updating all of ``model``'s parameters.
        train_loader: yields (mels, labels, settings, filenames, indeces) batches.
        train_sampler: sampler whose length gives samples per epoch.
        epoch: current epoch number (for logging only).
        loss_function_fx: classification criterion (default: cross-entropy).
        loss_function_set: settings regression criterion (default: L1).
        device: device string/object the input tensors are moved to.

    Returns:
        (total_loss, total_correct, results) where ``results`` holds one
        (index, filename, pred_label, true_label, pred_settings, true_settings)
        tuple per sample, settings rounded to 3 decimals as numpy arrays.
    """
    model.train()
    train_set_size = len(train_sampler)
    total_loss_fx = 0
    total_loss_set = 0
    total_loss = 0
    total_correct_fx = 0
    total_correct_set = 0
    total_correct = 0
    results = []

    for batch_idx, data in enumerate(train_loader):
        mels, labels, settings, filenames, indeces = data
        
        mels = mels.to(device)
        labels = labels.to(device)
        settings = settings.to(device)
        
        # single forward pass yields both heads' outputs
        preds_fx, preds_set = model(mels)

        loss_fx = loss_function_fx(preds_fx, labels)
        loss_set = loss_function_set(preds_set, settings)
        loss = loss_fx + loss_set   # joint loss trains both heads together

        optimizer.zero_grad() # zero gradients otherwise get accumulated
        loss.backward() # calculate gradient
        optimizer.step() # update weights

        total_loss_fx += loss_fx.item()
        total_loss_set += loss_set.item()
        total_loss += loss.item()
        correct_fx = utils.get_num_correct_labels(preds_fx, labels)
        correct_set = utils.get_num_correct_settings(preds_set, settings)
        correct = utils.get_num_correct(preds_fx, preds_set, labels, settings)
        total_correct_fx += correct_fx
        total_correct_set += correct_set
        total_correct += correct
        
        for idx, filename in enumerate(filenames):
            # .cpu() before .numpy(): tensors on a CUDA device cannot be
            # converted to numpy directly (was a TypeError when device != 'cpu')
            results.append(
                (indeces[idx].item(), 
                filename, 
                preds_fx[idx].argmax().item(),
                labels[idx].item(),
                np.round(preds_set[idx].detach().cpu().numpy(), 3),
                np.round(settings[idx].detach().cpu().numpy(), 3)))
        
        if batch_idx > 0 and batch_idx % 50 == 0:
            print('Train Epoch: {}\t[{}/{} ({:.0f}%)]\tTotal Loss: {:.4f}\tAvg Loss: {:.4f}'.format(
                        epoch, # epoch
                        batch_idx * len(labels), 
                        train_set_size,
                        100. * batch_idx / len(train_loader), # % completion
                        total_loss,
                        total_loss / (batch_idx * len(labels))))

    print('====> Epoch: {}'
                        '\tTotal Loss: {:.4f}'
                        '\t Avg Loss: {:.4f}'
                        '\t Fx Loss: {:.4f}'
                        '\t Set Loss: {:.4f}'
                        '\n\t\tCorrect: {:.0f}/{}'
                        '\tFx Correct: {:.0f}/{}'
                        '\tSet Correct: {:.0f}/{}'
                        '\n\t\tPercentage Correct: {:.2f}'
                        '\tPercentage Fx Correct: {:.2f}'
                        '\tPercentage Set Correct: {:.2f}'.format(
                            epoch,
                            total_loss,
                            total_loss / train_set_size,
                            total_loss_fx,
                            total_loss_set,
                            total_correct,
                            train_set_size,
                            total_correct_fx,
                            train_set_size,
                            total_correct_set,
                            train_set_size,
                            100 * total_correct / train_set_size,
                            100 * total_correct_fx / train_set_size,
                            100 * total_correct_set / train_set_size))
    
    return total_loss, total_correct, results