Example #1
def train(config):
    
    
    # Initialize the device to run the model on
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    # Initialize the dataset and data loader (note the +1)
    dataset = TextDataset(config.txt_file, config.seq_length)
    data_loader = DataLoader(dataset, batch_size = config.batch_size, shuffle=True, num_workers=1)
    vocab_size = dataset.vocab_size
    # char2i = dataset._char_to_ix
    # i2char = dataset._ix_to_char
    
    # Initialize the model that we are going to use
    model = TextGenerationModel(config.batch_size, config.seq_length, vocab_size,
                                config.lstm_num_hidden, config.lstm_num_layers, device)
    model.to(device)

    # Setup the loss and optimizer
    # NLLLoss expects log-probabilities, so outputs go through LogSoftmax first
    criterion = nn.NLLLoss()
    optimizer = optim.RMSprop(model.parameters(), lr=config.learning_rate)
    logSoftmax = nn.LogSoftmax(dim=2)
    
    # Learning rate scheduler
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=config.learning_rate_step, gamma=config.learning_rate_decay)
    step = 1
    
    if config.resume:
        if os.path.isfile(config.resume):
            print("Loading checkpoint '{}'".format(config.resume))
            checkpoint = torch.load(config.resume)
            step = checkpoint['step']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            print("Checkpoint loaded '{}', steps {}".format(config.resume, checkpoint['step']))

    if not os.path.isdir(config.summary_path):
        os.makedirs(config.summary_path)

    # Write sampled text to a file named after the sampling strategy
    if config.sampling == "greedy":
        f = open(os.path.join(config.summary_path, "sampled_" + config.sampling + ".txt"), "w+")
    else:
        f = open(os.path.join(config.summary_path, "sampled_" + config.sampling + "_" + str(config.temp) + ".txt"), "w+")

    best_accuracy = 0.0
    pl_loss = []
    average_loss = []
    acc = []

    for epoch in range(30):

        if step == config.train_steps:
            print('Done training.')
            break

        for (batch_inputs, batch_targets) in data_loader:

            # skip incomplete final batches (alternatively, pass drop_last=True
            # to the DataLoader)
            if config.batch_size != batch_inputs.size(0):
                print("batch mismatch")
                break

            # Only for time measurement of step through network
            t1 = time.time()
            model.hidden = model.init_hidden(config.batch_size)

            model.zero_grad()
            #######################################################

            # convert batch inputs to one-hot vectors of shape
            # (batch_size, seq_length, vocab_size)
            batch_inputs = torch.zeros(config.batch_size, config.seq_length,
                                       vocab_size).scatter_(2, batch_inputs.unsqueeze(-1), 1.0)
            
            batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)

            predictions, _ = model(batch_inputs)
            if config.sampling=="greedy":
                predictions = logSoftmax(predictions)
            else:
                predictions = logSoftmax(predictions/config.temp)

            loss = criterion(predictions.transpose(2,1), batch_targets)   # fixme

            _, pred_ix = torch.max(predictions, dim=2, keepdim=True)
            correct = (pred_ix.squeeze(-1) == batch_targets).float()
            accuracy = torch.mean(correct).item()

            loss.backward()
            # clip_grad_norm is deprecated; the in-place variant is clip_grad_norm_
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.max_norm)
            
            optimizer.step()
            lr_scheduler.step()

            #######################################################

            # Just for time measurement
            t2 = time.time()
            examples_per_second = config.batch_size/float(t2-t1)
            pl_loss.append(loss.item())
            average_loss.append(np.mean(pl_loss[-100:]))
            acc.append(accuracy)


            if step % config.print_every == 0:

                print("[{}] Train Step {}/{}, Batch Size = {}, Examples/Sec = {:.2f}, "
                    "Accuracy = {:.2f}, Loss = {:.3f}".format(
                        datetime.now().strftime("%Y-%m-%d %H:%M"), step,
                        config.train_steps, config.batch_size, examples_per_second,
                        accuracy, loss.item()
                ))
                
                

            if step % config.sample_every == 0:

                model.eval()

                with torch.no_grad():
                    char_ix = generate_sample(model, vocab_size, config.seq_length, device, config)
                    sentence = dataset.convert_to_string(char_ix)

                f.write("--------------" + str(step) + "----------------\n")
                f.write(sentence + "\n")
                print(sentence)
                print()
                model.train()
                # ###########################################################################
                # save training loss
                plt.plot(pl_loss,'r-', label="Batch loss", alpha=0.5)
                plt.plot(average_loss,'g-', label="Average loss", alpha=0.5)
                plt.legend()
                plt.xlabel("Iterations")
                plt.ylabel("Loss")  
                plt.title("Training Loss")
                plt.grid(True)
                # plt.show()
                if config.sampling == "greedy":
                    plt.savefig("loss_"+config.sampling+".png")
                else:
                    plt.savefig("loss_"+config.sampling+"_"+str(config.temp)+".png")

                plt.close()
                ################################ train accuracy ##################################
                plt.plot(acc,'g-', alpha=0.5)
                plt.xlabel("Iterations")
                plt.ylabel("Accuracy")
                plt.title("Train Accuracy")
                plt.grid(True)
                if config.sampling == "greedy":
                    plt.savefig("accuracy_"+config.sampling+".png")
                else:
                    plt.savefig("accuracy_"+config.sampling+"_"+str(config.temp)+".png")
                plt.close()

            if step == config.train_steps:
                # If you receive a PyTorch data-loader error, check this bug report:
                # https://github.com/pytorch/pytorch/pull/9655
                break
            
            step += 1
            
        save_checkpoint({
            'epoch': epoch + 1,
            'step': step,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'accuracy': accuracy
        }, config)
        
    f.close()
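
The helpers generate_sample and save_checkpoint are called above but not shown in the listing. A minimal sketch of what they might look like, assuming the model exposes init_hidden and returns per-step logits as in the training loop (the exact signatures and the checkpoint filename are assumptions, not part of the original):

def generate_sample(model, vocab_size, seq_length, device, config):
    # hypothetical sketch: roll the model forward one character at a time,
    # starting from a random character
    model.hidden = model.init_hidden(1)
    ix = torch.randint(0, vocab_size, (1,)).item()
    char_ix = [ix]
    for _ in range(seq_length):
        # one-hot encode the previous character as a (1, 1, vocab_size) input
        x = torch.zeros(1, 1, vocab_size, device=device)
        x[0, 0, ix] = 1.0
        logits, _ = model(x)
        if config.sampling == "greedy":
            ix = logits[0, -1].argmax().item()
        else:
            # temperature sampling, mirroring the training-time scaling above
            probs = torch.softmax(logits[0, -1] / config.temp, dim=0)
            ix = torch.multinomial(probs, 1).item()
        char_ix.append(ix)
    return char_ix

def save_checkpoint(state, config):
    # hypothetical sketch: write the state next to the sampled-text summaries
    torch.save(state, os.path.join(config.summary_path, "checkpoint.pth.tar"))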
Example #2
with torch.no_grad():
    # Generate an initial sequence to prime the model
    # text = "The old woman had only pretend"
    text = "Then first came two white dove"
    # text = "At first Rapunzel was terribly"
    print(text)
    init_seq = torch.tensor(dataset.convert_to_index(text))

    # one-hot encode the priming sequence into row 0 of a
    # (batch_size, seq_length, vocab_size) tensor
    init_seq = torch.zeros(config.batch_size, config.seq_length,
                           vocab_size).scatter_(
                               2,
                               torch.unsqueeze(torch.unsqueeze(init_seq, 0),
                                               2), 1).to(device)

    model.hidden = model.init_hidden(config.batch_size)

    # feed the priming sequence through the model
    predictions, hidden = model(init_seq)

    _, idx = torch.max(predictions, dim=2)

    # keep only the prediction for the last time step as the first generated
    # character, one-hot encoded as the next input
    idx = idx.squeeze(0)[-1]
    idx = idx.reshape(1, 1, 1)
    chars_ix = [idx.item()]
    in_char = torch.zeros(1, 1, vocab_size, device=device).scatter_(2, idx, 1)

    # the listing is cut off here; a minimal completion of the generation
    # loop, assuming the model carries its hidden state internally and the
    # same greedy decoding as above
    for i in range(config.generate_length):
        predictions, hidden = model(in_char)
        _, idx = torch.max(predictions, dim=2)
        idx = idx.reshape(1, 1, 1)
        chars_ix.append(idx.item())
        in_char = torch.zeros(1, 1, vocab_size, device=device).scatter_(2, idx, 1)

    print(text + dataset.convert_to_string(chars_ix))
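
If temperature sampling is wanted instead of the greedy argmax inside this loop, the last step's logits can be scaled and sampled from; a sketch, assuming config.temp holds the temperature as in Example #1:

    # sample the next index from the tempered softmax over the last time step
    probs = torch.softmax(predictions[0, -1] / config.temp, dim=0)
    idx = torch.multinomial(probs, 1).reshape(1, 1, 1)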
Example #3
def train(config):

    # Initialize the device to run the model on
    device = torch.device(config.device)

    # Initialize the dataset and data loader (note the +1)
    dataset = TextDataset(config.txt_file, config.seq_length)
    data_loader = DataLoader(dataset,
                             config.batch_size,
                             num_workers=1,
                             drop_last=True)
    vocab_size = dataset.vocab_size

    # Initialize the model that we are going to use
    model = TextGenerationModel(config.batch_size, config.seq_length,
                                vocab_size, config.lstm_num_hidden,
                                config.lstm_num_layers, config.device)
    model = model.to(device)

    print(model)

    # Setup the loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    # if a pickle file is available, load the recorded steps (index -1 gives
    # the last step) and the tracked metric lists, so training resumes where
    # it left off
    if os.path.isfile("steps.p"):
        print('Pre-trained model available...')
        print('Resuming training...')

        # load lists
        step_intervals = pickle.load(open("steps.p", "rb"))
        all_sentences = pickle.load(open("sentences.p", "rb"))
        accuracy_list = pickle.load(open("accuracies.p", "rb"))
        loss_list = pickle.load(open("loss.p", "rb"))
        model_info = pickle.load(open("model_info.p", "rb"))

        # start where we left off
        all_steps = step_intervals[-1]

        # load model
        Modelname = 'TrainIntervalModel' + model_info[0] + 'acc:' + model_info[1] + '.pt'
        model = torch.load(Modelname)
        model = model.to(device)

    # otherwise start training from a clean slate
    else:
        print('No pre-trained model available...')
        print('Initializing training...')

        # create lists to keep track of data while training
        all_sentences = []
        step_intervals = []
        accuracy_list = []
        loss_list = []

        # initialize total step counter
        all_steps = 0

    # initialize the optimizer here, after the model is final: building it in
    # the else branch only would raise a NameError when resuming, and building
    # it before torch.load would bind it to the discarded initial model
    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=config.learning_rate)

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=config.learning_rate_step,
        gamma=config.learning_rate_decay)

    # the inner loop ends after one full pass through data_loader, so wrap it
    # in an epoch loop
    for epoch in range(config.epochs):
        for step, (batch_inputs, batch_targets) in enumerate(data_loader):

            # Only for time measurement of step through network
            t1 = time.time()

            batch_inputs = batch_inputs.to(device)

            h, c = model.init_hidden()
            out, (h, c) = model(batch_inputs, h, c)

            # transpose to match cross entropy input dimensions
            out.transpose_(1, 2)

            batch_targets = batch_targets.to(device)

            loss = criterion(out, batch_targets)

            # fraction of correctly predicted characters over the whole batch
            predictions = torch.argmax(out, dim=1)
            correct = (predictions == batch_targets)
            accuracy = torch.sum(
                correct).item() / correct.size(0) / correct.size(1)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # step the LR scheduler after the optimizer step; calling it
            # before optimizer.step(), as the original did, makes newer
            # PyTorch warn and skips the first value of the schedule
            scheduler.step()

            # Just for time measurement
            t2 = time.time()
            examples_per_second = config.batch_size / float(t2 - t1)

            if step % config.print_every == 0:

                print(
                    "[{}] Train Step {:04d}/{:04d}, Batch Size = {}, Examples/Sec = {:.2f}, "
                    "Accuracy = {:.2f}, Loss = {:.3f}".format(
                        datetime.now().strftime("%Y-%m-%d %H:%M"), step,
                        int(config.train_steps), config.batch_size,
                        examples_per_second, accuracy, loss.item()))

            if all_steps % config.sample_every == 0:

                ###############################
                # Generate generated sequence #
                ###############################

                # do not keep track of gradients during model evaluation
                with torch.no_grad():

                    # create random character to start sentence with
                    random_input = torch.randint(0,
                                                 vocab_size,
                                                 (config.batch_size, ),
                                                 dtype=torch.long).view(-1, 1)
                    x_input = random_input.to(device)

                    # initialize hidden state and cell state
                    h, c = model.init_hidden()
                    h = h.to(device)
                    c = c.to(device)

                    sentences = x_input

                    # loop through sequence length to set generated output as input for next sequence
                    for i in range(config.seq_length):

                        # get randomly generated sentence
                        out, (h, c) = model(x_input, h, c)

                        ####################
                        # Temperature here #
                        ####################

                        # check whether user wants to apply temperature sampling
                        if config.temperature:

                            # apply temperature sampling
                            out = out / config.tempvalue
                            out = F.softmax(out, dim=2)

                            # create a torch distribution of the calculated softmax probabilities and sample from that distribution
                            distribution = torch.distributions.categorical.Categorical(
                                out.view(config.batch_size, vocab_size))
                            out = distribution.sample().view(-1, 1)

                        # check whether user wants to apply greedy sampling
                        else:
                            # load new datapoint by taking the predicted previous letter using greedy approach
                            out = torch.argmax(out, dim=2)

                        # append generated character to total sentence
                        sentences = torch.cat((sentences, out), 1)
                        x_input = out

                    # pick a random sentence (from the batch of created sentences)
                    index = np.random.randint(0, config.batch_size, 1)
                    sentence = sentences[index, :]

                    # squeeze sentence into 1D
                    sentence = sentence.view(-1).cpu()

                    # print sentence
                    print(dataset.convert_to_string(sentence.data.numpy()))

                    # save sentence
                    all_sentences.append(sentence.data.numpy())

                    ##########################
                    # Save loss and accuracy #
                    ##########################

                    # save loss value
                    loss = loss.cpu()
                    loss_list.append(loss.data.numpy())

                    # save accuracy value
                    accuracy_list.append(accuracy)

                    # save step interval
                    step_intervals.append(all_steps)

            if step == config.train_steps:
                # If you receive a PyTorch data-loader error, check this bug report:
                # https://github.com/pytorch/pytorch/pull/9655
                break

            # counter of total amounts of steps (keep track over multiple training sessions)
            all_steps += 1

        if config.savefiles:
            # pickle sentences and steps
            pickle.dump(all_sentences, open('sentences.p', 'wb'))
            pickle.dump(step_intervals, open('steps.p', 'wb'))

            # pickle accuracy and loss
            pickle.dump(accuracy_list, open('accuracies.p', 'wb'))
            pickle.dump(loss_list, open('loss.p', 'wb'))

            # save model

            Modelname = 'TrainIntervalModel' + str(epoch) + 'acc:' + str(
                accuracy) + '.pt'
            torch.save(model, Modelname)

            model_info = [str(epoch), str(accuracy)]
            pickle.dump(model_info, open('model_info.p', 'wb'))

    print('Done training.')
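
Example #3 assumes a model with an init_hidden() method and a forward(x, h, c) that returns (out, (h, c)) with logits of shape (batch, seq, vocab). The class itself is not shown; a minimal sketch of a model matching that interface, under the assumption that inputs are integer character indices fed through an embedding:

import torch
import torch.nn as nn

class TextGenerationModel(nn.Module):
    # hypothetical sketch matching the interface used in Example #3
    def __init__(self, batch_size, seq_length, vocab_size,
                 lstm_num_hidden, lstm_num_layers, device):
        super().__init__()
        self.batch_size = batch_size
        self.lstm_num_hidden = lstm_num_hidden
        self.lstm_num_layers = lstm_num_layers
        self.device = device
        self.embed = nn.Embedding(vocab_size, lstm_num_hidden)
        self.lstm = nn.LSTM(lstm_num_hidden, lstm_num_hidden,
                            lstm_num_layers, batch_first=True)
        self.out = nn.Linear(lstm_num_hidden, vocab_size)

    def init_hidden(self, batch_size=None):
        # zero-initialized hidden and cell states
        batch_size = batch_size or self.batch_size
        shape = (self.lstm_num_layers, batch_size, self.lstm_num_hidden)
        return (torch.zeros(shape, device=self.device),
                torch.zeros(shape, device=self.device))

    def forward(self, x, h, c):
        # x: (batch, seq) of character indices -> (batch, seq, vocab) logits
        emb = self.embed(x)
        out, (h, c) = self.lstm(emb, (h, c))
        return self.out(out), (h, c)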
Example #4
def train(config):

    # Initialize the device to run the model on
    device = torch.device(config.device)

    # Initialize the dataset and data loader (note the +1)
    dataset = TextDataset(config.txt_file, config.seq_length)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    # Initialize the model that we are going to use
    model = TextGenerationModel(config.batch_size, config.seq_length,
                                dataset.vocab_size).to(device=device)

    # Setup the loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

    saved_step = 0

    # resume from the most recent checkpoint, if one exists; create the
    # directory first so os.listdir does not fail on a fresh run
    os.makedirs('checkpoints', exist_ok=True)
    checkpoints = sorted([int(cp) for cp in os.listdir('checkpoints')])
    if checkpoints:
        state = torch.load('checkpoints/{}'.format(checkpoints[-1]))
        saved_step = state['step'] + 1
        model.load_state_dict(state['state_dict'])
        optimizer.load_state_dict(state['optimizer'])

    for step, (batch_inputs, batch_targets) in enumerate(data_loader):
        step = step + saved_step
        # Only for time measurement of step through network
        t1 = time.time()

        model.zero_grad()
        model.hidden = model.init_hidden(config.batch_size)

        # the loader yields a list of per-timestep batches; stack them into a
        # (seq_length, batch_size, 1) float tensor
        inputs = torch.unsqueeze(torch.stack(batch_inputs),
                                 2).float().to(device=device)
        targets = torch.cat(batch_targets).to(device=device)

        out = model(inputs)
        batch_loss = criterion(out, targets)

        batch_loss.backward()

        optimizer.step()

        loss = batch_loss

        o = torch.max(out, 1)[1].cpu().numpy()
        t = targets.cpu().numpy()
        compared = np.equal(o, t)
        correct = np.sum(compared)
        accuracy = correct / len(compared)

        # Just for time measurement
        t2 = time.time()
        examples_per_second = config.batch_size / float(t2 - t1)

        if step % config.print_every == 0:
            print(
                "[{}] Train Step {:04d}/{:04d}, Batch Size = {}, Examples/Sec = {:.2f}, "
                "Accuracy = {:.2f}, Loss = {:.3f}".format(
                    datetime.now().strftime("%Y-%m-%d %H:%M"), step,
                    int(config.train_steps), config.batch_size,
                    examples_per_second, accuracy, loss.item()))

        if step % config.sample_every == 0:
            # Generate some sentences by sampling from the model
            random_ix = torch.tensor(
                [random.choice(list(dataset._ix_to_char.keys()))])
            ix_list = [random_ix]

            model.hidden = model.init_hidden(1)
            for i in range(config.seq_length):
                tensor = torch.unsqueeze(torch.unsqueeze(ix_list[-1], 0),
                                         0).float().to(device=config.device)
                out = model(tensor)
                o = torch.max(out, 1)[1]
                ix_list.append(o)
            char_ix = [x.cpu().numpy()[0] for x in ix_list]
            gen_sen = dataset.convert_to_string(char_ix)
            with open('generated_sentences.txt', 'a') as file:
                file.write('{}: {}\n'.format(step, gen_sen))

        if step % config.save_every == 0:
            state = {
                'step': step,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(state, 'checkpoints/{}'.format(step))

        if step == config.train_steps:
            # If you receive a PyTorch data-loader error, check this bug report:
            # https://github.com/pytorch/pytorch/pull/9655
            break

    print('Done training.')
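
All four examples lean on the same TextDataset helper (vocab_size, convert_to_string, convert_to_index, _char_to_ix, _ix_to_char), which none of them show. A minimal sketch of such a character-level dataset; the examples disagree slightly on batch format (Example #4 stacks a list of per-timestep tensors), so this follows the tensor-based variant used in Examples #1 to #3, and the exact field names are assumptions:

import torch
import torch.utils.data as data

class TextDataset(data.Dataset):
    # hypothetical sketch of the character-level dataset used above
    def __init__(self, filename, seq_length):
        self._seq_length = seq_length
        with open(filename, 'r') as f:
            self._data = f.read()
        chars = sorted(set(self._data))
        self._char_to_ix = {ch: ix for ix, ch in enumerate(chars)}
        self._ix_to_char = {ix: ch for ix, ch in enumerate(chars)}
        self.vocab_size = len(chars)

    def __len__(self):
        # the "+1" the examples mention: each item needs seq_length inputs
        # plus one extra character for the shifted targets
        return len(self._data) - self._seq_length - 1

    def __getitem__(self, index):
        # inputs are chars [index, index+seq_length); targets are the same
        # window shifted one character to the right
        chunk = self._data[index:index + self._seq_length + 1]
        ix = [self._char_to_ix[ch] for ch in chunk]
        return torch.tensor(ix[:-1]), torch.tensor(ix[1:])

    def convert_to_string(self, char_ix):
        return ''.join(self._ix_to_char[int(ix)] for ix in char_ix)

    def convert_to_index(self, text):
        return [self._char_to_ix[ch] for ch in text]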