Example #1
def sdn_loss(output, label, coeffs=None):
    # `output` is a list of logits: one per internal classifier (IC), with the
    # final classifier's logits last. `af` is the project's auxiliary-functions
    # module, assumed to be imported elsewhere (e.g. `import aux_funcs as af`).
    total_loss = 0.0
    if coeffs is None:
        coeffs = [1 for _ in range(len(output) - 1)]

    # weighted loss of every internal classifier
    for ic_id in range(len(output) - 1):
        total_loss += float(coeffs[ic_id]) * af.get_loss_criterion()(
            output[ic_id], label)

    # the final classifier's loss is added with weight 1
    total_loss += af.get_loss_criterion()(output[-1], label)
    return total_loss
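For illustration, the following self-contained sketch reproduces the same weighted multi-exit loss with torch.nn.CrossEntropyLoss standing in for af.get_loss_criterion() and random tensors in place of real model outputs; the stand-in criterion, the coefficients, and the tensor shapes are assumptions made for the sketch, not part of the example above.

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()   # assumed stand-in for af.get_loss_criterion()

# three internal-classifier outputs plus one final output; 10 classes, batch of 4
outputs = [torch.randn(4, 10, requires_grad=True) for _ in range(4)]
labels = torch.randint(0, 10, (4,))
coeffs = [0.15, 0.45, 0.75]         # hypothetical per-IC weights

total_loss = sum(float(c) * criterion(o, labels)
                 for c, o in zip(coeffs, outputs[:-1]))
total_loss = total_loss + criterion(outputs[-1], labels)   # final output, weight 1
total_loss.backward()               # gradients flow back to every exit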
Example #2
def sdn_training_step(optimizer, model, coeffs, batch, device):
    b_x = batch[0].to(device)       # batch inputs
    b_y = batch[1].to(device)       # batch labels
    output = model(b_x)             # list of logits: one per IC, final classifier last
    optimizer.zero_grad()           # clear gradients for this training step
    total_loss = 0.0

    # weighted loss of every internal classifier
    for ic_id in range(model.num_output - 1):
        cur_output = output[ic_id]
        cur_loss = float(coeffs[ic_id]) * af.get_loss_criterion()(cur_output, b_y)
        total_loss += cur_loss

    total_loss += af.get_loss_criterion()(output[-1], b_y)   # final classifier, weight 1
    total_loss.backward()           # backpropagation, compute gradients
    optimizer.step()                # apply gradients

    return total_loss
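As a hedged sketch of how a step of this shape could be driven from a training loop: the toy two-exit model, the nn.CrossEntropyLoss stand-in for af.get_loss_criterion(), and the random batches below are all illustrative assumptions, not part of the original code.

import torch
import torch.nn as nn

class ToyTwoExitNet(nn.Module):
    # illustrative stand-in for an SDN: forward() returns [IC logits, final logits]
    def __init__(self, num_classes=10):
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.ic_head = nn.Linear(16, num_classes)
        self.final_head = nn.Linear(16, num_classes)
        self.num_output = 2

    def forward(self, x):
        h = torch.relu(self.backbone(x))
        return [self.ic_head(h), self.final_head(h)]

criterion = nn.CrossEntropyLoss()   # assumed stand-in for af.get_loss_criterion()

def toy_sdn_training_step(optimizer, model, coeffs, batch, device):
    # same structure as sdn_training_step above, with the stand-in criterion
    b_x, b_y = batch[0].to(device), batch[1].to(device)
    output = model(b_x)
    optimizer.zero_grad()
    total_loss = sum(float(c) * criterion(o, b_y)
                     for c, o in zip(coeffs, output[:-1]))
    total_loss = total_loss + criterion(output[-1], b_y)
    total_loss.backward()
    optimizer.step()
    return total_loss

model = ToyTwoExitNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(3):                  # a few random batches in place of a DataLoader
    batch = (torch.randn(4, 8), torch.randint(0, 10, (4,)))
    loss = toy_sdn_training_step(optimizer, model, [0.5], batch, device='cpu')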
Example #3
def cnn_training_step(model, optimizer, data, labels, device='cpu'):
    b_x = data.to(device)   # batch x
    b_y = labels.to(device)   # batch y
    output = model(b_x)            # cnn final output
    criterion = af.get_loss_criterion()
    loss = criterion(output, b_y)   # cross entropy loss
    optimizer.zero_grad()           # clear gradients for this training step
    loss.backward()                 # backpropagation, compute gradients
    optimizer.step()                # apply gradients
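Unlike the SDN steps, this step expects the model to return a single logits tensor. Below is a minimal self-contained sketch of the same sequence of calls, where nn.CrossEntropyLoss stands in for af.get_loss_criterion() and a single nn.Linear stands in for the CNN (both assumptions made for illustration).

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()   # assumed stand-in for af.get_loss_criterion()
model = nn.Linear(8, 10)            # stand-in for the CNN
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

data = torch.randn(4, 8)
labels = torch.randint(0, 10, (4,))

output = model(data)                # single logits tensor, not a list of exits
loss = criterion(output, labels)    # cross entropy loss
optimizer.zero_grad()               # clear gradients for this training step
loss.backward()                     # backpropagation, compute gradients
optimizer.step()                    # apply gradients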
Example #4
def sdn_ic_only_step(optimizer, model, batch, device):
    b_x = batch[0].to(device)       # batch inputs
    b_y = batch[1].to(device)       # batch labels
    output = model(b_x)             # list of logits: one per IC, final classifier last
    optimizer.zero_grad()           # clear gradients for this training step

    # sum the losses of the internal classifiers only
    total_loss = 0.0
    for output_id, cur_output in enumerate(output):
        if output_id == model.num_output - 1:   # the last output is the final classifier; skip it
            break

        cur_loss = af.get_loss_criterion()(cur_output, b_y)
        total_loss += cur_loss

    total_loss.backward()           # backpropagation, compute gradients
    optimizer.step()                # apply gradients

    return total_loss
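To make the IC-only behaviour concrete, here is a self-contained sketch showing that summing the losses of every output except the last one sends gradients to the internal classifier but not to the final classifier's head; the two-exit layout and the nn.CrossEntropyLoss stand-in for af.get_loss_criterion() are assumptions made for illustration.

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()   # assumed stand-in for af.get_loss_criterion()

# hypothetical two-exit layout: a shared backbone with an IC head and a final head
backbone = nn.Linear(8, 16)
ic_head = nn.Linear(16, 10)
final_head = nn.Linear(16, 10)

b_x = torch.randn(4, 8)
b_y = torch.randint(0, 10, (4,))

h = torch.relu(backbone(b_x))
output = [ic_head(h), final_head(h)]        # ICs first, final classifier last

# IC-only loss: every exit except the last one, as in the loop above
total_loss = sum(criterion(o, b_y) for o in output[:-1])
total_loss.backward()

print(ic_head.weight.grad is None)          # False: the IC head receives gradients
print(final_head.weight.grad is None)       # True: the final head is untouched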