Exemplo n.º 1
0
def train(train_loader,
          test_loader,
          gradient_clipping=1,
          hidden_state_size=10,
          lr=0.001,
          epochs=100,
          classify=True):
    """Train an ``EncoderDecoder`` on 28x28 inputs, validating and checkpointing per epoch.

    Args:
        train_loader: Iterable yielding ``(data, target)`` batches for training.
        test_loader: Loader forwarded to ``validation`` after every epoch.
        gradient_clipping: Max gradient norm passed to ``clip_grad_norm_``;
            a falsy value (``0``/``None``) disables clipping.
        hidden_state_size: ``hidden_size`` given to the ``EncoderDecoder``.
        lr: Learning rate for the Adam optimizer.
        epochs: Upper bound of the epoch loop. The loop is ``range(1, epochs)``,
            so ``epochs - 1`` passes are executed.
        classify: When true, the model is built with ``is_prediction=True`` and
            the loss additionally receives the targets and the model's second
            output; otherwise only the reconstruction is used.

    Side effects:
        Logs ``train_loss`` to the tensorboard writer, prints the epoch loss,
        saves the whole model under
        ``<results_path>/saved_models/MNIST_task/<task_name>/`` every fifth
        epoch or whenever the validation loss improves, and finally plots the
        validation loss (and accuracy when ``classify`` is true).
    """
    if classify:
        model = EncoderDecoder(input_size=28, hidden_size=hidden_state_size, output_size=28,
                               is_prediction=True, labels_num=10)
    else:
        model = EncoderDecoder(input_size=28, hidden_size=hidden_state_size, output_size=28,
                               labels_num=10)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_name = "mse"
    min_loss = float("inf")
    task_name = "classify" if classify else "reconstruct"
    validation_losses = []
    validation_accuracies = []
    tensorboard_writer = init_writer(results_path, lr, classify,
                                     hidden_state_size, epochs)
    checkpoint_dir = os.path.join(results_path, "saved_models", "MNIST_task",
                                  task_name)

    # NOTE(review): range(1, epochs) performs epochs - 1 passes. Kept as-is
    # because the plotting helpers receive `epochs` together with the collected
    # validation history and may rely on this length — confirm before changing.
    for epoch in range(1, epochs):
        total_loss = 0
        total_batches = 0
        for data, target in train_loader:
            data = data.to(device)
            target = target.to(device)
            # Reshape every sample of the batch to a 28x28 grid for the model.
            data_sequential = data.view(data.shape[0], 28, 28)
            optimizer.zero_grad()
            if classify:
                reconstructed_batch, batch_pred_probs = model(data_sequential)
                loss = model.loss(data_sequential, reconstructed_batch, target,
                                  batch_pred_probs)
            else:
                reconstructed_batch = model(data_sequential)
                loss = model.loss(data_sequential, reconstructed_batch)
            total_loss += loss.item()
            loss.backward()
            if gradient_clipping:
                nn.utils.clip_grad_norm_(model.parameters(),
                                         max_norm=gradient_clipping)
            optimizer.step()
            total_batches += 1

        epoch_loss = total_loss / total_batches

        tensorboard_writer.add_scalar('train_loss', epoch_loss, epoch)
        print(f'Train Epoch: {epoch} \t loss: {epoch_loss}')

        validation_loss = validation(model, test_loader, validation_losses,
                                     device, classify, validation_accuracies,
                                     tensorboard_writer, epoch)
        # Switch back to training mode in case validation put the model in eval mode.
        model.train()

        if epoch % 5 == 0 or validation_loss < min_loss:
            file_name = f"ae_toy_{loss_name}_lr={lr}_hidden_size={hidden_state_size}_epoch={epoch}_gradient_clipping={gradient_clipping}.pt"
            # torch.save does not create missing parent directories.
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(model, os.path.join(checkpoint_dir, file_name))

        min_loss = min(validation_loss, min_loss)

    plot_validation_loss(epochs, gradient_clipping, lr, loss_name,
                         validation_losses, hidden_state_size, task_name)
    if classify:
        plot_validation_acc(epochs, gradient_clipping, lr, loss_name,
                            validation_accuracies, hidden_state_size,
                            task_name)
Exemplo n.º 2
0
def train(train_loader,
          test_loader,
          gradient_clipping=1,
          hidden_state_size=10,
          lr=0.001,
          epochs=3000,
          is_prediction=False):
    """Train an ``EncoderDecoder`` on univariate sequences, validating and checkpointing per epoch.

    Args:
        train_loader: Loader yielding ``(data, target)`` batches; must support
            ``len()`` because the epoch loss is averaged over ``len(train_loader)``.
        test_loader: Loader forwarded to ``validation`` after every epoch.
        gradient_clipping: Max gradient norm passed to ``clip_grad_norm_``;
            a falsy value (``0``/``None``) disables clipping.
        hidden_state_size: ``hidden_size`` given to the ``EncoderDecoder``.
        lr: Learning rate for the Adam optimizer.
        epochs: Upper bound of the epoch loop. The loop is ``range(1, epochs)``,
            so ``epochs - 1`` passes are executed.
        is_prediction: When true, the model is built with
            ``is_prediction=True, is_snp=True`` and the loss additionally
            receives the targets and the model's predictions; otherwise only
            the reconstruction is used.

    Side effects:
        Logs ``train_loss`` to the tensorboard writer, prints the epoch loss,
        saves the whole model under
        ``<results_path>/saved_models/s&p500_task/<task_name>/`` every fifth
        epoch or whenever the validation loss improves, and finally plots the
        validation loss.
    """
    if is_prediction:
        model = EncoderDecoder(input_size=1, hidden_size=hidden_state_size, output_size=1,
                               is_prediction=True, labels_num=1, is_snp=True)
    else:
        model = EncoderDecoder(input_size=1, hidden_size=hidden_state_size, output_size=1,
                               labels_num=1)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_name = "mse"
    min_loss = float("inf")
    task_name = "classify" if is_prediction else "reconstruct"
    validation_losses = []
    tensorboard_writer = init_writer(lr, is_prediction, hidden_state_size,
                                     epochs, task_name)
    checkpoint_dir = os.path.join(results_path, "saved_models", "s&p500_task",
                                  task_name)

    # NOTE(review): range(1, epochs) performs epochs - 1 passes. Kept as-is
    # because plot_validation_loss receives `epochs` together with the collected
    # validation history and may rely on this length — confirm before changing.
    for epoch in range(1, epochs):
        total_loss = 0
        for data, target in train_loader:
            # Append a trailing feature axis of size 1: (d0, d1) -> (d0, d1, 1).
            data_sequential = (data.view(data.shape[0], data.shape[1],
                                         1)).to(device)
            target = target.to(device)
            optimizer.zero_grad()
            if is_prediction:
                reconstructed_batch, batch_preds = model(data_sequential)
                # Keep only the first two axes of the predictions
                # (presumably dropping a trailing size-1 axis — TODO confirm).
                batch_preds = batch_preds.view(batch_preds.shape[0],
                                               batch_preds.shape[1])
                loss = model.loss(data_sequential, reconstructed_batch, target,
                                  batch_preds)
            else:
                reconstructed_batch = model(data_sequential)
                loss = model.loss(data_sequential, reconstructed_batch)
            total_loss += loss.item()
            loss.backward()
            if gradient_clipping:
                nn.utils.clip_grad_norm_(model.parameters(),
                                         max_norm=gradient_clipping)
            optimizer.step()

        epoch_loss = total_loss / len(train_loader)
        tensorboard_writer.add_scalar('train_loss', epoch_loss, epoch)
        print(f'Train Epoch: {epoch} \t loss: {epoch_loss}')

        validation_loss = validation(model, test_loader, validation_losses,
                                     device, is_prediction, tensorboard_writer,
                                     epoch)
        # Switch back to training mode in case validation put the model in
        # eval mode; otherwise later epochs would train with eval behaviour.
        model.train()

        if epoch % 5 == 0 or validation_loss < min_loss:
            file_name = f"ae_s&p500_{loss_name}_lr={lr}_hidden_size={hidden_state_size}_epoch={epoch}_gradient_clipping={gradient_clipping}.pt"
            # torch.save does not create missing parent directories.
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(model, os.path.join(checkpoint_dir, file_name))

        min_loss = min(validation_loss, min_loss)

    plot_validation_loss(epochs, gradient_clipping, lr, loss_name,
                         validation_losses, hidden_state_size, task_name)