Code example #1
def train(config, input_length):

    # Initialize the model that we are going to use
    model = VanillaRNN(input_length, config.input_dim, config.num_hidden,
                       config.num_classes, config.batch_size)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Initialize the dataset and data loader (leave the +1)
    dataset = PalindromeDataset(input_length + 1)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    # Setup the loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config.learning_rate)

    losses = []
    accuracies = []
    running_loss = 0.0

    for step, (batch_inputs, batch_targets) in enumerate(data_loader):

        # Move the batch to the device and reset gradients
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)
        optimizer.zero_grad()

        outputs = model(batch_inputs)
        loss = criterion(outputs, batch_targets)
        loss.backward()

        # Clip gradients between backward() and step() to deal with
        # exploding gradients; clip_grad_norm is deprecated in favor of
        # the in-place clip_grad_norm_
        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                       max_norm=config.max_norm)
        optimizer.step()

        running_loss += loss.item()

        if step % 10 == 0:
            # Print accuracy/loss here
            print('[step: %5d] loss: %.4f' % (step, running_loss / 10))
            losses.append(running_loss / 10)
            running_loss = 0.0
            accu = accuracy(outputs, batch_targets)
            accuracies.append(accu)
            print('Accuracy on training dataset: %.3f %%' % accu)

        if step == config.train_steps:
            # If you receive a PyTorch data-loader error, check this bug report:
            # https://github.com/pytorch/pytorch/pull/9655
            break

    print('Done training.')

    return model, losses, accuracies
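
The `accuracy` helper called in the logging branch above is not defined in this snippet. A minimal sketch of what it plausibly computes, assuming `outputs` holds raw logits of shape (batch, num_classes) and the result is a percentage to match the print format:

def accuracy(outputs, targets):
    # Hypothetical helper: share of examples whose argmax over the class
    # logits matches the target index, returned as a percentage.
    predictions = outputs.argmax(dim=1)
    return (predictions == targets).float().mean().item() * 100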
Code example #2
def train(config):

    assert config.model_type in ('RNN', 'LSTM')

    # Initialize the device which to run the model on
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize the model that we are going to use
    if config.model_type == 'RNN':
        model = VanillaRNN(seq_length=config.input_length,
                           input_dim=config.input_dim,
                           num_hidden=config.num_hidden,
                           num_classes=config.num_classes,
                           batch_size=config.batch_size,
                           device=device)
    elif config.model_type == 'LSTM':
        model = LSTM(seq_length=config.input_length,
                     input_dim=config.input_dim,
                     num_hidden=config.num_hidden,
                     num_classes=config.num_classes,
                     batch_size=config.batch_size,
                     device=device)

    model.to(device)
    # Initialize the dataset and data loader (note the +1)
    dataset = PalindromeDataset(config.input_length + 1)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    # Setup the loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(params=model.parameters(),
                                    lr=config.learning_rate)

    # evaluation metrics
    results = []

    print_setting(config)

    for step, (batch_inputs, batch_targets) in enumerate(data_loader):

        # Only for time measurement of step through network
        t1 = time.time()
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        # Forward pass
        predictions = model(batch_inputs)

        # Compute the loss
        loss = criterion(predictions, batch_targets)

        #backward pass & updates
        # set gradients to zero
        optimizer.zero_grad()
        loss.backward()
        ############################################################################
        # QUESTION: what happens here and why?
        # Prevents exploding gradients by rescaling to a limit specified by config.max_norm
        # Forcing gradients to be within a certain norm to ensure reasonable updates
        ############################################################################
        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                       max_norm=config.max_norm)
        ############################################################################

        optimizer.step()

        accuracy = (predictions.argmax(dim=1)
                    == batch_targets).sum().float() / (config.batch_size)

        # Just for time measurement
        t2 = time.time()
        examples_per_second = config.batch_size / float(t2 - t1)

        if step % config.eval_freq == 0:

            print(
                "[{}] Train Step {:04d}/{:04d}, Batch Size = {}, Examples/Sec = {:.2f}, "
                "Accuracy = {:.2f}, Loss = {:.3f}".format(
                    datetime.now().strftime("%Y-%m-%d %H:%M"), step,
                    config.train_steps, config.batch_size, examples_per_second,
                    accuracy, loss))

            results.append([step, accuracy.item(), loss.item()])

        if step == config.train_steps:
            # If you receive a PyTorch data-loader error, check this bug report:
            # https://github.com/pytorch/pytorch/pull/9655
            break

    print('Done training.\n')

    return results
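
As the answer inside this snippet notes, `torch.nn.utils.clip_grad_norm_` rescales all gradients in place when their combined norm exceeds `max_norm`, bounding the update size while preserving the gradient direction. A small self-contained illustration (the toy values are ours, not from the assignment):

import torch
import torch.nn as nn

# A toy parameter with a deliberately large gradient.
w = nn.Parameter(torch.ones(3))
w.grad = torch.tensor([30.0, 40.0, 0.0])  # gradient norm = 50

# Rescale in place so the total norm is at most 10; the function
# returns the norm measured before clipping.
total_norm = nn.utils.clip_grad_norm_([w], max_norm=10.0)
print(total_norm)  # tensor(50.)
print(w.grad)      # ~tensor([6., 8., 0.]): norm 10, same direction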
Code example #3
File: train.py Project: jamie0725/Deep-Learning
def train(config):

    assert config.model_type in ('RNN', 'LSTM')

    # Print all configs to confirm parameter settings
    print_flags()

    # Initialize the device which to run the model on
    device = torch.device(config.device)

    # Initialize the model that we are going to use
    if config.model_type == 'RNN':
        model = VanillaRNN(config.input_length, config.input_dim,
                           config.num_hidden, config.num_classes,
                           config.batch_size, device)
    else:
        model = LSTM(config.input_length, config.input_dim, config.num_hidden,
                     config.num_classes, config.batch_size, device)
    model.to(device)

    # Initialize the dataset and data loader (note the +1)
    dataset = PalindromeDataset(config.input_length + 1)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    # Setup the loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.RMSprop(model.parameters(),
                              lr=config.learning_rate,
                              weight_decay=config.weight_decay,
                              momentum=config.momentum)

    # Store some measures
    best_acc = 0.
    los = list()
    iteration = list()
    tmp_acc = list()
    acc = list()

    for step, (batch_inputs, batch_targets) in enumerate(data_loader):

        # Only for time measurement of step through network
        t1 = time.time()

        optimizer.zero_grad()
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)
        pred = model(batch_inputs)
        accuracy = compute_accuracy(pred, batch_targets)
        tmp_acc.append(accuracy)
        loss = criterion(pred, batch_targets)
        loss.backward()
        ############################################################################
        # QUESTION: what happens here and why?
        # ANSWER: the gradient norm is rescaled in place whenever it exceeds
        # config.max_norm, which prevents exploding gradients in the RNN.
        ############################################################################
        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                       max_norm=config.max_norm)
        ############################################################################
        optimizer.step()

        # Just for time measurement; guard against a zero time delta
        t2 = time.time()
        examples_per_second = config.batch_size / max(float(t2 - t1), 1e-9)

        if step % 10 == 0:

            print(
                "[{}] Train Step {:04d}/{:04d}, Batch Size = {}, Examples/Sec = {:.2f}, "
                "Accuracy = {:.2f}, Loss = {:.3f}".format(
                    datetime.now().strftime("%Y-%m-%d %H:%M"), step,
                    config.train_steps, config.batch_size, examples_per_second,
                    accuracy, loss))
            iteration.append(step)
            acc.append(accuracy)
            los.append(loss.item())
            if accuracy > best_acc:
                best_acc = accuracy

        if step == config.train_steps:
            # If you receive a PyTorch data-loader error, check this bug report:
            # https://github.com/pytorch/pytorch/pull/9655
            break

    print('Done training.')
    tmp_acc.sort(reverse=True)
    avg_acc = sum(tmp_acc[:50]) / 50
    print('Average of 50 best accuracies: {}'.format(avg_acc))
    with open('result/{}_acc.txt'.format(config.model_type), 'a') as file:
        file.write('{} {}\n'.format(config.input_length, avg_acc))
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].plot(iteration, acc)
    axs[0].set_xlabel('Iteration')
    axs[0].set_ylabel('Accuracy')
    axs[1].plot(iteration, los)
    axs[1].set_xlabel('Iteration')
    axs[1].set_ylabel('Loss')
    fig.tight_layout()
    plt.show()
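
Every snippet here constructs `PalindromeDataset(config.input_length + 1)`: the +1 is needed because the generated palindrome must contain one digit beyond the network input, namely the final digit to be predicted (which, for a palindrome, equals the first digit). A rough sketch of such a dataset, assuming this common formulation (the actual assignment code may differ in details):

import numpy as np
import torch
from torch.utils.data import Dataset

class PalindromeDataset(Dataset):
    # Hypothetical reconstruction: random palindromic digit sequences;
    # the input is every digit but the last, the target is the last digit.

    def __init__(self, seq_length):
        self.seq_length = seq_length

    def __len__(self):
        return int(1e6)  # effectively a stream of random samples

    def __getitem__(self, idx):
        half = np.random.randint(0, 10, size=(self.seq_length + 1) // 2)
        # Mirror the first half to obtain a palindrome of seq_length digits.
        full = np.concatenate((half, half[::-1][self.seq_length % 2:]))
        inputs = torch.tensor(full[:-1], dtype=torch.float)
        target = torch.tensor(full[-1], dtype=torch.long)
        return inputs, target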
Code example #4
def train(config):

    assert config.model_type in ('RNN', 'LSTM')

    # Initialize the device which to run the model on
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    print('Currently using: ', device)
    # Initialize the model that we are going to use
    input_length = config.input_length
    input_dim = config.input_dim
    num_classes = config.num_classes
    num_hidden = config.num_hidden
    batch_size = config.batch_size
    learning_rate = config.learning_rate
    
    if config.model_type == 'RNN':
        model = VanillaRNN(input_length, input_dim, num_hidden, num_classes,
                           batch_size, device).double()
    elif config.model_type == 'LSTM':
        model = LSTM(input_length, input_dim, num_hidden, num_classes,
                     batch_size, device).double()

    model = model.to(device)

    # Initialize the dataset and data loader (note the +1)
    dataset = PalindromeDataset(input_length + 1)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    # Setup the loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
    accuracy_list = []
    loss_list = []

    for step, (batch_inputs, batch_targets) in enumerate(data_loader):
        # Only for time measurement of step through network
        t1 = time.time()

        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        output = model(batch_inputs.transpose(0, 1).double())

        optimizer.zero_grad()

        # Argmax over the class dimension gives the predicted digits
        output_indices = torch.argmax(output, dim=1)
        loss_for_backward = criterion(output, batch_targets)
        loss_for_backward.backward()
        
        ############################################################################
        # QUESTION: what happens here and why?
        ############################################################################
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.max_norm)
        ############################################################################

        optimizer.step()

        correct_indices = output_indices == batch_targets
        accuracy = correct_indices.float().mean().item()

        # Just for time measurement
        t2 = time.time()
        examples_per_second = config.batch_size/float(t2-t1)

        if step % 10 == 0:

            print("[{}] Train Step {:04d}/{:04d}, Batch Size = {}, Examples/Sec = {:.2f}, "
                  "Accuracy = {:.2f}, Loss = {:.3f}".format(
                    datetime.now().strftime("%Y-%m-%d %H:%M"), step,
                    config.train_steps, config.batch_size, examples_per_second,
                    accuracy, loss_for_backward
            ))
            accuracy_list.append(accuracy)
            loss_list.append(loss_for_backward.item())

        if step == config.train_steps or (len(accuracy_list) > 10 and
                                          sum(accuracy_list[-3:]) / 3 == 1.0):
            # If you receive a PyTorch data-loader error, check this bug report:
            # https://github.com/pytorch/pytorch/pull/9655
            break

    print('Done training.')
    line = ' '.join((str(config.model_type), 'Palindrome length:', str(input_length),
                     'Accuracy:', str(accuracy_list), 'Loss', str(loss_list)))
    with open('LSTMMMMM.txt', 'a') as file:
        file.write(line + '\n')
Code example #5
File: train.py Project: PhilLint/Deep-Learning
def train(config):

    assert config.model_type in ('RNN', 'LSTM')

    # Initialize the device which to run the model on
    device = torch.device(config.device)

    # Initialize the model that we are going to use
    if config.model_type == "RNN":
        model = VanillaRNN(seq_length=config.input_length,
                           input_dim=config.input_dim,
                           num_hidden=config.num_hidden,
                           batch_size=config.batch_size,
                           num_classes=config.num_classes,
                           device=device)

    elif config.model_type == "LSTM":
        model = LSTM(seq_length=config.input_length,
                     input_dim=config.input_dim,
                     num_hidden=config.num_hidden,
                     num_classes=config.num_classes,
                     device=device,
                     batch_size=config.batch_size)
    # send model to device
    model.to(device)
    # Initialize the dataset and data loader (note the +1)
    dataset = PalindromeDataset(config.input_length + 1)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    # Setup the loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=config.learning_rate)

    # track training statistics
    train_accuracies = []
    train_losses = []

    for step, (batch_inputs, batch_targets) in enumerate(data_loader):

        # Only for time measurement of step through network
        t1 = time.time()

        # Move the batches to the device with the dtypes the model expects
        input_sequences = batch_inputs.to(device=device, dtype=torch.float)
        targets = batch_targets.to(device=device, dtype=torch.long)

        # Reset gradients
        optimizer.zero_grad()

        # Forward pass: predict classes for the input sequences
        predictions = model(input_sequences)
        # Accuracy: fraction of correct argmax predictions in the batch
        accuracy = (predictions.argmax(dim=1) == targets).float().mean()
        # Compute the batch loss and backpropagate
        loss = criterion(predictions, targets)
        loss.backward()

        ############################################################################
        # QUESTION: what happens here and why?
        # ANSWER: gradients can grow very large as they are propagated back
        # through the recurrent steps, which destabilizes learning. Clipping
        # the gradient norm to a limit overcomes that issue.
        ############################################################################
        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                       max_norm=config.max_norm)
        ############################################################################
        # update weights according to optimizer
        optimizer.step()

        # Just for time measurement
        t2 = time.time()
        examples_per_second = config.batch_size / float(t2 - t1)
        # Save stats for each step
        train_accuracies.append(accuracy.item())
        train_losses.append(loss.item())

        if step % 10 == 0:

            print(
                "[{}] Train Step {:04d}/{:04d}, Batch Size = {}, Examples/Sec = {:.2f}, "
                "Accuracy = {:.2f}, Loss = {:.3f}".format(
                    datetime.now().strftime("%Y-%m-%d %H:%M"), step,
                    config.train_steps, config.batch_size, examples_per_second,
                    accuracy, loss))

            # If the mean of the last 50 accuracies is 1, stop training:
            # convergence is reached and further computation is unnecessary.
            avg_accuracies = np.sum(train_accuracies[-50:]) / 50
            print(avg_accuracies)
            if avg_accuracies == 1:
                print(
                    "\nTraining finished for length: {} after {} steps".format(
                        config.input_length, step))
                print("Avg Accuracy : {:.3f}".format(avg_accuracies))
                break

        if step == config.train_steps:
            # If you receive a PyTorch data-loader error, check this bug report:
            # https://github.com/pytorch/pytorch/pull/9655
            break

    print('Done training.')

    return max(train_accuracies), step
Code example #6
File: train.py Project: EliasKassapis/Deep-Learning
def train(config, n_run):

    assert config.model_type in ('RNN', 'LSTM')

    # Initialize the device which to run the model on
    device = torch.device(config.device)

    # Train on T-1 first digits
    config.input_length = config.input_length - 1

    # Initialize the model that we are going to use
    if config.model_type == 'RNN':
        model = VanillaRNN(config.input_length, config.input_dim, config.num_hidden, config.num_classes, config.batch_size, device=device)
    elif config.model_type == 'LSTM':
        model = LSTM(config.input_length, config.input_dim, config.num_hidden, config.num_classes, config.batch_size, device=device)


    # Initialize the dataset and data loader (note the +1)
    dataset = PalindromeDataset(config.input_length+1)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    # Setup the loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(model.parameters(), lr=config.learning_rate)

    model.to(device)

    train_loss = []
    train_acc = []
    t_loss = []
    t_acc = []

    #Convergence condition
    eps = 1e-6

    for step, (batch_inputs, batch_targets) in enumerate(data_loader):

        # Clear stored gradient
        model.zero_grad()

        # Only for time measurement of step through network
        t1 = time.time()

        # Convert inputs and labels into tensors on the device
        x = batch_inputs.to(device)
        y = batch_targets.to(device)


        # Forward pass (gradients were already cleared by model.zero_grad())
        pred = model(x)
        loss = criterion(pred, y)
        t_loss.append(loss.item())

        # Backward pass
        loss.backward()

        ############################################################################
        # QUESTION: what happens here and why?

        # ANSWER: the function torch.nn.utils.clip_grad_norm_() is used to prevent
        # exploding gradients by clipping the norm of the gradients, restraining
        # the gradient values to a certain threshold. This effectively limits
        # the size of the parameter updates in every layer, ensuring that the
        # parameter values don't change too much from their previous values.

        ############################################################################
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.max_norm)
        ############################################################################

        optimizer.step()
        accuracy = get_accuracy(pred,y, config.batch_size)
        t_acc.append(accuracy.item())

        # Just for time measurement
        t2 = time.time()
        examples_per_second = config.batch_size/float(t2-t1)

        if step % 1000 == 0:

            print("[{}] Train Step {:04d}/{:04d}, Batch Size = {}, Examples/Sec = {:.2f}, "
                  "Accuracy = {:.2f}, Loss = {:.3f}".format(
                    datetime.now().strftime("%Y-%m-%d %H:%M"), step,
                    config.train_steps, config.batch_size, examples_per_second,
                    accuracy, loss
            ))

        if step % 100 == 0:
            #Get loss and accuracy averages over 100 steps
            train_loss.append(np.mean(t_loss))
            train_acc.append(np.mean(t_acc))
            t_loss = []
            t_acc = []

            if step > 0 and abs(train_loss[-1] - train_loss[-2]) < eps:
                break


        if step == config.train_steps:
            # If you receive a PyTorch data-loader error, check this bug report:
            # https://github.com/pytorch/pytorch/pull/9655
            break


    print('\nDone training.\n')
    # Save trained model and results
    if config.model_type == 'RNN':
        #save model
        torch.save(model, "./Results/RNN/" + str(config.input_length) + "_RNN_model")
        #save train accuracy and loss
        np.save("./Results/RNN/" + str(config.input_length) + "_RNN_accuracy", train_acc)
        np.save("./Results/RNN/" + str(config.input_length) + "_RNN_loss", train_loss)

        # #save model ####################################################################### For SURFsara
        # torch.save(model, str(config.input_length+1) + "_RNN_model_" + str(n_run))
        # #save train accuracy and loss
        # np.save(str(config.input_length+1) + "_RNN_accuracy_" + str(n_run), train_acc)
        # np.save(str(config.input_length+1) + "_RNN_loss_" + str(n_run), train_loss)

    elif config.model_type == 'LSTM':
        #save model
        torch.save(model, "./Results/LSTM/" + str(config.input_length) + "_LSTM_model")
        #save train accuracy and loss
        np.save("./Results/LSTM/" + str(config.input_length) + "_LSTM_accuracy", train_acc)
        np.save("./Results/LSTM/" + str(config.input_length) + "_LSTM_loss", train_loss)
Code example #7
def train(config):
    

    np.random.seed(42)
    torch.manual_seed(42)
    assert config.model_type in ('RNN', 'LSTM')

    # Initialize the device which to run the model on
    device = torch.device(config.device)

    print(device)

    # Initialize the model that we are going to use
    if config.model_type=="RNN":
        
        print("Training VanillaRNN")
        print()
        model = VanillaRNN(config.input_length, config.input_dim,\
                                config.num_hidden, config.num_classes, config.batch_size, config.device)  # fixme
    else:
        print("Training LSTM")
        print()
        model = LSTM(config.input_length, config.input_dim,\
                                config.num_hidden, config.num_classes, config.batch_size, config.device)

    model = model.to(device)
    
    # Initialize the dataset and data loader (note the +1)
    dataset = PalindromeDataset(config.input_length+1)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)
    
    # Setup the loss and optimizer
    criterion = nn.CrossEntropyLoss()
    if config.optimizer == "adam":
        optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    else:
        optimizer = optim.RMSprop(model.parameters(), lr=config.learning_rate)
    pl_loss = []
    average_loss = []
    acc = []
    
    for step, (batch_inputs, batch_targets) in enumerate(data_loader):

        # Only for time measurement of step through network
        t1 = time.time()
        
        batch_targets = batch_targets.long()
        batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)
        
        
        # zero the parameter gradients
        model.zero_grad()
        
        # Forward pass
        output = model(batch_inputs)

        out_loss = criterion(output, batch_targets)
        out_loss.backward()
        
        ############################################################################
        # QUESTION: what happens here and why?
        # ANSWER: helps prevent the exploding gradient problem in RNNs / LSTMs.
        ############################################################################
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.max_norm)
        ############################################################################
        optimizer.step()
        
        loss = out_loss.item()
        # Get the predicted class per example; softmax is monotonic, so
        # taking the argmax of the raw logits gives the same result
        predictions = torch.argmax(output, dim=1)
        correct = (predictions == batch_targets).sum().item()
        accuracy = correct / config.batch_size
        # Just for time measurement
        t2 = time.time()
        examples_per_second = config.batch_size/float(t2-t1)
        
        pl_loss.append(loss)
        average_loss.append(np.mean(pl_loss[:-100:-1]))
        acc.append(accuracy)
        
        if step % 10 == 0:

            print("[{}] Train Step {:04d}/{:04d}, Batch Size = {}, Examples/Sec = {:.2f}, "
                  "Accuracy = {:.2f}, Loss = {:.3f}".format(
                    datetime.now().strftime("%Y-%m-%d %H:%M"), step,
                    config.train_steps, config.batch_size, examples_per_second,
                    accuracy, loss
            ))

        if step == config.train_steps:
            # If you receive a PyTorch data-loader error, check this bug report:
            # https://github.com/pytorch/pytorch/pull/9655
            break

        # if step%100==0:
        #     # save training loss
        #     plt.plot(pl_loss,'r-', label="Batch loss", alpha=0.5)
        #     plt.plot(average_loss,'g-', label="Average loss", alpha=0.5)
        #     plt.legend()
        #     plt.xlabel("Iterations")
        #     plt.ylabel("Loss")  
        #     plt.title("Training Loss")
        #     plt.grid(True)
        #     # plt.show()
        #     plt.savefig(config.optimizer+"_loss_"+config.model_type+"_"+str(config.input_length)+".png")
        #     plt.close()
    ################################training##################################################
    # plt.plot(acc,'g-', alpha=0.5)
    # plt.xlabel("Iterations")
    # plt.ylabel("Accuracy")
    # plt.title("Train Accuracy")
    # plt.grid(True)
    # plt.savefig("accuracy_"+config.sampling+"_"+str(config.temp)+".png")
    #  plt.close()
    # fl = config.optimizer+"_acc_"+config.model_type+"_"+str(config.input_length)
   
    
    # np.savez(fl, acc=acc)
    print('Done training.')
Code example #8
    train_steps = config.train_steps
    max_norm = config.max_norm

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    acc_list = []
    loss_list = []
    epoch_list = []

    # Run the experiment 3 times for more reliable results
    for _ in range(3):
        # Initialize the model that we are going to use
        if model_type == 'RNN':
            model = VanillaRNN(input_length, input_dim, num_hidden, num_classes, batch_size, device=device)
            model.to(device)
        elif model_type == 'LSTM':
            model = LSTM(input_length, input_dim, num_hidden, num_classes, batch_size, device=device)
            model.to(device)

        # Initialize the dataset and data loader (note the +1)
        dataset = PalindromeDataset(input_length+1)
        data_loader = DataLoader(dataset, batch_size, num_workers=0)

        # Setup the loss and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr, alpha=0.99,
                                        eps=1e-08, weight_decay=0, momentum=0,
                                        centered=False
                                        )
        print('start training')
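
The keyword arguments spelled out for RMSprop in this snippet (alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False) are PyTorch's defaults, so the call is equivalent to the shorter form:

        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)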