Example 1
import time
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

# TextDataset, TextGenerationModel and the acc() helper are assumed to be
# provided by accompanying dataset/model modules of this project.


def train(config):

    # Initialize the device to run the model on
    device = torch.device(config.device)

    # Initialize the dataset and data loader (note the +1)
    dataset = TextDataset(config.txt_file, config.seq_length)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    # Initialize the model that we are going to use
    model = TextGenerationModel(config.batch_size, config.seq_length,
                                dataset.vocab_size, config.lstm_num_hidden,
                                config.lstm_num_layers, device)

    # Setup the loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.RMSprop(model.parameters(), config.learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=config.learning_rate_step,
                                          gamma=config.learning_rate_decay)

    accuracy_train = []
    loss_train = []

    if config.samples_out_file != "STDOUT":
        samples_out_file = open(config.samples_out_file, 'w')

    epochs = config.train_steps // len(data_loader) + 1
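    # The final epoch may only run the remaining train_steps % len(data_loader) batches (handled below)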

    print(
        "Will train on {} batches in {} epochs, max {} batches/epoch.".format(
            config.train_steps, epochs, len(data_loader)))

    for epoch in range(epochs):
        data_loader_iter = iter(data_loader)

        if epoch == config.train_steps // len(data_loader):
            batches = config.train_steps % len(data_loader)
        else:
            batches = len(data_loader)

        for step in range(batches):
            batch_inputs, batch_targets = next(data_loader_iter)
            model.zero_grad()

            # Only for time measurement of step through network
            t1 = time.time()

            batch_inputs = F.one_hot(
                batch_inputs,
                num_classes=dataset.vocab_size,
            ).float().to(device)
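            # One-hot encode the character indices so each position becomes a vocab_size-dimensional vector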
            batch_targets = batch_targets.to(device)

            optimizer.zero_grad()

            pred, _ = model(batch_inputs)
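            # nn.CrossEntropyLoss expects the class (vocab) dimension second, hence transpose(2, 1) on the predictions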
            loss = criterion(pred.transpose(2, 1), batch_targets)
            accuracy = acc(
                pred.transpose(2, 1),
                F.one_hot(batch_targets,
                          num_classes=dataset.vocab_size).float(),
                dataset.vocab_size)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=config.max_norm)
            optimizer.step()

            # Just for time measurement
            t2 = time.time()
            examples_per_second = config.batch_size / float(t2 - t1)

            scheduler.step()

            if (epoch * len(data_loader) + step + 1) % config.seval_every == 0:
                accuracy_train.append(accuracy)
                loss_train.append(loss.item())

            if (epoch * len(data_loader) + step + 1) % config.print_every == 0:
                print(
                    "[{}] Epoch: {:04d}/{:04d}, Train Step {:04d}/{:04d}, Batch Size = {}, Examples/Sec = {:.2f}, "
                    "Accuracy = {:.2f}, Loss = {:.3f}".format(
                        datetime.now().strftime("%Y-%m-%d %H:%M"), epoch + 1,
                        epochs, (epoch * len(data_loader) + step + 1),
                        config.train_steps, config.batch_size,
                        examples_per_second, accuracy, loss))

            if (epoch * len(data_loader) + step +
                    1) % config.sample_every == 0:
                with torch.no_grad():
                    codes = []

                    input_tensor = torch.zeros((1, 1, dataset.vocab_size),
                                               device=device)
                    input_tensor[0, 0,
                                 np.random.randint(0, dataset.vocab_size)] = 1

                    for i in range(config.seq_length - 1):
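                        # model.step keeps the LSTM hidden state across calls (cleared by reset_stepper below);
                        # the logits are scaled by config.temp before sampling the next character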
                        response = model.step(input_tensor)
                        logits = F.log_softmax(config.temp * response, dim=1)
                        dist = torch.distributions.one_hot_categorical.OneHotCategorical(
                            logits=logits)
                        code = dist.sample().argmax().item()
                        input_tensor *= 0
                        input_tensor[0, 0, code] = 1
                        codes.append(code)

                    string = dataset.convert_to_string(codes)
                    model.reset_stepper()

                    if config.samples_out_file != "STDOUT":
                        samples_out_file.write("Step {}: ".format(
                            epoch * len(data_loader) + step + 1) + string +
                                               "\n")
                    else:
                        print(string)

    if config.samples_out_file != "STDOUT":
        samples_out_file.close()

    if config.model_out_file is not None:
        torch.save(model, config.model_out_file)

    if config.curves_out_file is not None:
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        fig.suptitle(
            'Training curves for Pytorch 2-layer LSTM.\nFinal loss: {:.4f}. Final accuracy: {:.4f}\nSequence length: {}, Hidden units: {}, LSTM layers: {}, Learning rate: {:.4f}'
            .format(loss_train[-1], accuracy_train[-1], config.seq_length,
                    config.lstm_num_hidden, config.lstm_num_layers,
                    config.learning_rate))
        plt.subplots_adjust(top=0.8)

        ax[0].set_title('Loss')
        ax[0].set_ylabel('Loss value')
        ax[0].set_xlabel('No of batches seen x{}'.format(config.seval_every))
        ax[0].plot(loss_train, label='Train')
        ax[0].legend()

        ax[1].set_title('Accuracy')
        ax[1].set_ylabel('Accuracy value')
        ax[1].set_xlabel('No of batches seen x{}'.format(config.seval_every))
        ax[1].plot(accuracy_train, label='Train')
        ax[1].legend()

        plt.savefig(config.curves_out_file)

    print('Done training.')
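
For reference, Example 1 reads a fixed set of fields from its config object. A minimal argparse driver that supplies those fields might look like the sketch below; the argument names match the attributes accessed in train(), but every default value is an illustrative guess rather than the original author's setting.

import argparse
import torch


def parse_config():
    # All fields below are read somewhere inside train(); the defaults are assumptions.
    parser = argparse.ArgumentParser()
    parser.add_argument('--txt_file', type=str, required=True,
                        help='Path to the training text file')
    parser.add_argument('--seq_length', type=int, default=30)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--lstm_num_hidden', type=int, default=128)
    parser.add_argument('--lstm_num_layers', type=int, default=2)
    parser.add_argument('--device', type=str,
                        default='cuda:0' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--learning_rate', type=float, default=2e-3)
    parser.add_argument('--learning_rate_step', type=int, default=5000)
    parser.add_argument('--learning_rate_decay', type=float, default=0.96)
    parser.add_argument('--train_steps', type=int, default=10000)
    parser.add_argument('--max_norm', type=float, default=5.0)
    parser.add_argument('--temp', type=float, default=1.0)
    parser.add_argument('--seval_every', type=int, default=100)
    parser.add_argument('--print_every', type=int, default=100)
    parser.add_argument('--sample_every', type=int, default=500)
    parser.add_argument('--samples_out_file', type=str, default='STDOUT')
    parser.add_argument('--model_out_file', type=str, default=None)
    parser.add_argument('--curves_out_file', type=str, default=None)
    return parser.parse_args()


if __name__ == '__main__':
    train(parse_config())

This assumes the driver lives in the same file as train(); with it in place, training can be started from the command line by pointing --txt_file at a text corpus.
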
Example 2
import pickle
import time
from datetime import datetime

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

# TextDataset and TextGenerationModel are assumed to be provided by
# accompanying dataset/model modules of this project.


def train(config):

    # Initialize the device to run the model on
    # device = torch.device(config.device)  # the device is taken from the model below instead

    # Initialize the dataset and data loader (note the +1)
    dataset = TextDataset(config.txt_file, config.seq_length)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    # print(dataset._char_to_ix)  # the vocabulary order changes between runs, but with the seeds set earlier the batches contain the same sentence examples

    # Initialize the model that we are going to use
    model = TextGenerationModel(config.batch_size, config.seq_length,
                                dataset.vocab_size, config.lstm_num_hidden,
                                config.lstm_num_layers, config.device)

    device = model.device
    model = model.to(device)

    # Setup the loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=config.learning_rate)
    print("Len dataset:", len(dataset))
    print("Amount of steps for dataset:", len(dataset) / config.batch_size)

    current_step = 0
    not_max = True

    list_train_acc = []
    list_train_loss = []
    acc_average = []
    loss_average = []

    file = open("sentences.txt", 'w', encoding='utf-8')
    '''
    file_greedy = open("sentences_greedy.txt",'w',encoding='utf-8')
    file_tmp_05 = open("sentences_tmp_05.txt", 'w', encoding='utf-8')
    file_tmp_1 = open("sentences_tmp_1.txt", 'w', encoding='utf-8')
    file_tmp_2 = open("sentences_tmp_2.txt", 'w', encoding='utf-8')
    '''

    while not_max:

        for (batch_inputs, batch_targets) in data_loader:

            # Only for time measurement of step through network
            t1 = time.time()

            #######################################################
            # Add more code here ...

            # The batch holds character indices (the dataset's char-to-ID mapping) that are
            # looked up in the model's embedding layer
            embed = model.embed  # Embedding of shape (dataset.vocab_size, config.lstm_num_hidden)

            #Preprocess input to embeddings to give to LSTM all at once
            all_embed = []
            #sentence = []
            for batch_letter in batch_inputs:
                batch_letter_to = batch_letter.to(
                    device)  #torch.tensor(batch_letter,device = device)
                embedding = embed(batch_letter_to)
                all_embed.append(embedding)

                #sentence.append(batch_letter_to[0].item())
            all_embed = torch.stack(all_embed)
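            # all_embed has shape (seq_length, batch_size, lstm_num_hidden)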

            #Print first example sentence of batch along with target
            #print(dataset.convert_to_string(sentence))
            #sentence = []
            #for batch_letter in batch_targets:
            #    sentence.append(batch_letter[0].item())
            #print(dataset.convert_to_string(sentence))

            all_embed = all_embed.to(device)
            outputs = model(all_embed)  # [30, 64, vocab_size]; vocab_size is 87 for the fairy tales text

            #######################################################

            # For the loss, the predictions must be (batch_size, vocab_size, seq_length) and the targets (batch_size, seq_length)
            batch_first_output = outputs.transpose(0, 1).transpose(1, 2)
            batch_targets = torch.stack(batch_targets).to(device)
            loss = criterion(batch_first_output, torch.t(batch_targets))

            #Backpropagate
            model.zero_grad()
            loss.backward()
            loss = loss.item()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=config.max_norm)
            optimizer.step()

            #Accuracy
            number_predictions = torch.argmax(outputs, dim=2)
            result = number_predictions == batch_targets
            accuracy = result.sum().item() / (batch_targets.shape[0] *
                                              batch_targets.shape[1])
            '''
            #Generate sentences for all settings on every step
            sentence_id = model.generate_sentence(config.gsen_length, -1)
            sentence = dataset.convert_to_string(sentence_id)
            #print(sentence)
            file_greedy.write( (str(current_step) + ": " + sentence + "\n"))

            sentence_id = model.generate_sentence(config.gsen_length, 0.5)
            sentence = dataset.convert_to_string(sentence_id)
            #print(sentence)
            file_tmp_05.write( (str(current_step) + ": " + sentence + "\n"))

            sentence_id = model.generate_sentence(config.gsen_length, 1)
            sentence = dataset.convert_to_string(sentence_id)
            #print(sentence)
            file_tmp_1.write( (str(current_step) + ": " + sentence + "\n"))

            sentence_id = model.generate_sentence(config.gsen_length, 2)
            sentence = dataset.convert_to_string(sentence_id)
            #print(sentence)
            file_tmp_2.write( (str(current_step) + ": " + sentence + "\n"))
            '''

            if config.measure_type == 2:
                acc_average.append(accuracy)
                loss_average.append(loss)

            # Just for time measurement
            t2 = time.time()
            examples_per_second = config.batch_size / float(t2 - t1)

            if current_step % config.print_every == 0:

                # Average accuracy and loss over the last print_every steps (5 by default)
                if config.measure_type == 2:
                    accuracy = sum(acc_average) / config.print_every
                    loss = sum(loss_average) / config.print_every
                    acc_average = []
                    loss_average = []

                # Either the accuracy and loss at this print_every interval, or the average over that interval as computed above
                list_train_acc.append(accuracy)
                list_train_loss.append(loss)

                print(
                    "[{}] Train Step {:04d}/{:04d}, Batch Size = {}, Examples/Sec = {:.2f}, "
                    "Accuracy = {:.2f}, Loss = {:.3f}".format(
                        datetime.now().strftime("%Y-%m-%d %H:%M"),
                        current_step, config.train_steps, config.batch_size,
                        examples_per_second, accuracy, loss))
            elif config.measure_type == 0:
                # Track accuracy and loss for every step
                list_train_acc.append(accuracy)
                list_train_loss.append(loss)

            if current_step % config.sample_every == 0:
                # Generate sentence
                sentence_id = model.generate_sentence(config.gsen_length,
                                                      config.temperature)
                sentence = dataset.convert_to_string(sentence_id)
                print(sentence)
                file.write((str(current_step) + ": " + sentence + "\n"))

            if current_step == config.train_steps:
                # If you receive a PyTorch data-loader error, check this bug report:
                # https://github.com/pytorch/pytorch/pull/9655
                not_max = False
                break

            current_step += 1

    # Close the file and make sure the sentences and measures are saved
    file.close()
    pickle.dump((list_train_acc, list_train_loss),
                open("loss_and_train.p", "wb"))

    #Plot
    print(len(list_train_acc))

    if config.measure_type == 0:
        eval_steps = list(range(config.train_steps + 1))  # accuracy logged at every step
    else:
        eval_steps = list(
            range(0, config.train_steps + config.print_every,
                  config.print_every))

    if config.measure_type == 2:
        plt.plot(eval_steps[:-1], list_train_acc[1:], label="Train accuracy")
    else:
        plt.plot(eval_steps, list_train_acc, label="Train accuracy")

    plt.xlabel("Step")
    plt.ylabel("Accuracy")
    plt.title("Training accuracy LSTM", fontsize=18, fontweight="bold")
    plt.legend()
    # plt.savefig('accuracies.png', bbox_inches='tight')
    plt.show()

    if config.measure_type == 2:
        plt.plot(eval_steps[:-1], list_train_loss[1:], label="Train loss")
    else:
        plt.plot(eval_steps, list_train_loss, label="Train loss")
    plt.xlabel("Step")
    plt.ylabel("Loss")
    plt.title("Training loss LSTM", fontsize=18, fontweight="bold")
    plt.legend()
    # plt.savefig('loss.png', bbox_inches='tight')
    plt.show()
    print('Done training.')
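
Example 2 delegates text sampling to model.generate_sentence(length, temperature), whose implementation is not part of this listing. Judging from the commented-out calls above (temperature -1 for the greedy file, then 0.5, 1 and 2), such a sampler could be sketched roughly as follows. This is only a sketch: the embed, lstm and linear attribute names are assumptions about the model class, not the author's actual code.

import torch
import torch.nn.functional as F


def generate_sentence_sketch(model, vocab_size, length, temperature, device='cpu'):
    # Start from a random character, then repeatedly feed the last sampled
    # character back into the network. temperature == -1 is treated as greedy
    # decoding; otherwise the logits are divided by the temperature before sampling.
    codes = [torch.randint(vocab_size, (1,)).item()]
    hidden = None
    with torch.no_grad():
        for _ in range(length - 1):
            inp = torch.tensor([[codes[-1]]], device=device)  # (seq=1, batch=1)
            emb = model.embed(inp)                 # assumed embedding layer, as used above
            out, hidden = model.lstm(emb, hidden)  # assumed LSTM attribute
            logits = model.linear(out[-1, 0])      # assumed output projection
            if temperature == -1:
                code = logits.argmax().item()
            else:
                probs = F.softmax(logits / temperature, dim=0)
                code = torch.multinomial(probs, 1).item()
            codes.append(code)
    return codes

The returned list of character IDs would then be passed through dataset.convert_to_string, exactly as the training loop above does with its samples.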