Example 1
def make_loss(self, pi, sigma, mu, labels=None, warmup=None, warmup_threshold=10):
    """
    The special loss for the MDN
    :param pi: mixture weights output by the network
    :param sigma: standard deviations output by the network
    :param mu: means output by the network
    :param labels: the ground-truth labels
    :param warmup: current warmup step; warming up helps the means reach the target range faster
    :param warmup_threshold: number of warmup steps before switching to the regular MDN loss
    :return: the total loss
    """
    if warmup is None or warmup > warmup_threshold:     # no warmup: use the regular MDN loss
        return mdn.mdn_loss(pi, sigma, mu, labels)
    else:
        print('warmup mode on')
        # Push sigma towards 1 (wide components)
        loss = 10 * torch.mean(torch.pow(sigma - 1, 2))
        # Push the component-averaged mu towards the labels
        loss += 10 * torch.mean(torch.pow(torch.mean(mu, 1) - labels, 2))
        return loss
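
A minimal sketch of how this warmup schedule might be driven from a training loop. The trainer object, network, optimizer, and loader names below are hypothetical, not from the original:

# Hypothetical usage: pass the epoch index so the warmup branch is active
# for the first `warmup_threshold` epochs, then the regular MDN loss kicks in.
for epoch in range(num_epochs):
    for x, labels in loader:
        pi, sigma, mu = net(x)
        loss = trainer.make_loss(pi, sigma, mu, labels, warmup=epoch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()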
Example 2
def train(model, train_loader, num_ep, lear_rate):
    """
    Train MDN model using train dataset

    Args:
        model:          pytorch MDN model to be trained
        train_loader:   loader for the train dataset
        num_ep:         number of epochs for the training
        lear_rate:      learning rate for the Adam optimizer

    Returns:
        model:          pytorch MDN trained model

    """

    optimizer = optim.Adam(model.parameters(), lr=lear_rate)

    # train the model
    for epoch in range(num_ep):

        for batch_idx, (minibatch, labels) in enumerate(train_loader):

            # rearrange data shape: flatten each 28x28 image to a 784-vector
            # (use -1 so the reshape does not depend on a global batch_size)
            minibatch = minibatch.reshape(-1, 1, 784)
            labels = labels.reshape(-1, 1)

            model.zero_grad()
            pi, sigma, mu = model(minibatch)
            loss = mdn.mdn_loss(pi, sigma, mu, labels)
            loss.backward()
            optimizer.step()

            print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(minibatch), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))

    return model
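
A sketch of how train might be invoked, assuming the model is an MDN over flattened MNIST digits; the dataset, transform, and hyperparameter choices here are illustrative, not from the original:

# Hypothetical usage with torchvision's MNIST dataset
from torchvision import datasets, transforms

dataset = datasets.MNIST('data', train=True, download=True,
                         transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(dataset, batch_size=64,
                                           shuffle=True, drop_last=True)
model = train(model, train_loader, num_ep=5, lear_rate=1e-3)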
Example 3
    # train the model
    for epoch in range(num_epochs):

        for batch_idx, (labels, minibatch) in enumerate(train_loader):

            # Here `labels` holds the images and `minibatch` the digit classes:
            # the MDN is conditioned on the class and generates the image.
            gt = labels[0:3]

            # One-hot encode the digit classes
            minibatch = [[numb == int(minibatch[i]) for numb in range(10)]
                         for i in range(len(minibatch))]
            minibatch = torch.FloatTensor(minibatch).unsqueeze(1)

            # Flatten each target image to a 784-vector
            labels = labels.reshape(-1, 784)

            model.zero_grad()
            pi, sigma, mu = model(minibatch)
            loss = mdn.mdn_loss(pi, sigma, mu, labels)
            loss.backward()
            optimizer.step()

            print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(minibatch), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

            # Visualize
            if batch_idx == 0:
                samples = mdn.sample(pi, sigma, mu).int()
                images = samples[0:3]
                plot(gt, images, epoch)

    print(
        "The training is complete!\nVisualization results have been saved in the folder 'figures'")
Example 4
training_set = torch.cat([cluster1, cluster2, cluster3])
print('Done')

print("Initializing model... ", end='')
model = nn.Sequential(nn.Linear(input_dims, 5), nn.Tanh(),
                      mdn.MDN(5, output_dims, num_gaussians))

optimizer = optim.Adam(model.parameters())
print('Done')

print('Training model... ', end='')
sys.stdout.flush()
for epoch in range(1000):
    model.zero_grad()
    # Condition on the first input_dims columns, fit the remaining ones
    pi, sigma, mu = model(training_set[:, 0:input_dims])
    loss = mdn.mdn_loss(pi, sigma, mu, training_set[:, input_dims:])
    loss.backward()
    optimizer.step()
    if epoch % 100 == 99:
        # Progress indicator: prints 10%, 20%, ..., 100%
        print(f' {round(epoch/10)}%', end='')
        sys.stdout.flush()
print(' Done')

print('Generating samples... ', end='')
pi, sigma, mu = model(training_set[:, 0:input_dims])
samples = mdn.sample(pi, sigma, mu)
print('Done')

print('Saving samples.png... ', end='')
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
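
The snippet is truncated after the 3D axes are created. A plausible completion is sketched below; it assumes input_dims + output_dims == 3 and that mdn.sample returns a (batch, output_dims) tensor, neither of which is confirmed by the excerpt:

# Hypothetical completion: overlay training points and model samples in 3D
ax.scatter(training_set[:, 0].numpy(), training_set[:, 1].numpy(),
           training_set[:, 2].numpy(), label='target')
ax.scatter(training_set[:, 0].numpy(), samples[:, 0].detach().numpy(),
           samples[:, 1].detach().numpy(), label='samples')
ax.legend()
fig.savefig('samples.png')
print('Done')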