Example #1
with torch.no_grad():
    # Generate an initial sequence
    # text = "The old woman had only pretend"
    text = "Then first came two white dove"
    # text = "At first Rapunzel was terribly"
    print(text)
    init_seq = torch.tensor(dataset.convert_to_index(text))

    # init_seq = t
    init_seq = torch.zeros(config.batch_size, config.seq_length,
                           vocab_size).scatter_(
                               2,
                               torch.unsqueeze(torch.unsqueeze(init_seq, 0),
                                               2), 1).to(device)

    model.hidden = model.init_hidden(config.batch_size)

    # send the first sequence
    predictions, hidden = model(init_seq)

    _, idx = torch.max(predictions, dim=2)

    idx = idx.squeeze(0)[-1]
    idx = idx.reshape(1, 1, 1)
    chars_ix = [idx.item()]
    in_char = torch.zeros(1, 1, vocab_size, device=device).scatter_(2, idx, 1)

    # model.hidden = hidden
    in_char = in_char.to(device)

    for i in range(config.generate_length):
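        # NOTE: the original snippet is truncated here; the loop body below is a
        # minimal sketch of the remaining greedy-decoding steps, reusing the
        # (predictions, hidden) interface and one-hot encoding shown above.
        predictions, hidden = model(in_char)
        _, idx = torch.max(predictions, dim=2)
        idx = idx.reshape(1, 1, 1)
        chars_ix.append(idx.item())
        in_char = torch.zeros(1, 1, vocab_size, device=device).scatter_(2, idx, 1)

    # assumption: the dataset exposes convert_to_string (as in the examples below)
    print(text + dataset.convert_to_string(chars_ix))
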
Example #2
def train(config):
    
    
    # Initialize the device which to run the model on
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    # Initialize the dataset and data loader (note the +1)
    dataset = TextDataset(config.txt_file, config.seq_length)   # fixme
    data_loader = DataLoader(dataset, batch_size = config.batch_size, shuffle=True, num_workers=1)
    vocab_size = dataset.vocab_size
    # char2i = dataset._char_to_ix
    # i2char = dataset._ix_to_char
    # ----------------------------------------
    
    # Initialize the model that we are going to use
    model = TextGenerationModel(config.batch_size, config.seq_length, vocab_size, \
                                config.lstm_num_hidden, config.lstm_num_layers, device)  # fixme
    model.to(device)

    # Setup the loss and optimizer
    criterion = nn.NLLLoss()  # fixme
    optimizer = optim.RMSprop(model.parameters(), lr = config.learning_rate)  # fixme
    logSoftmax = nn.LogSoftmax(dim=2)
    
    # Learning rate scheduler
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, \
                  step_size=config.learning_rate_step, gamma=config.learning_rate_decay)
    step = 1
    
    if config.resume:
        if os.path.isfile(config.resume):
            print("Loading checkpoint '{}'".format(config.resume))
            checkpoint = torch.load(config.resume)
            step = checkpoint['step']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            print("Checkpoint loaded '{}', steps {}".format(config.resume, checkpoint['step']))

    if not os.path.isdir(config.summary_path):
        os.makedirs(config.summary_path)

    if config.sampling == "greedy":
        f = open(os.path.join(config.summary_path, "sampled_" + config.sampling + ".txt"), "w+")
    else:
        f = open(os.path.join(config.summary_path, "sampled_" + config.sampling + "_" + str(config.temp) + ".txt"), "w+")

    best_accuracy = 0.0
    pl_loss = []
    average_loss = []
    acc = []

    for epochs in range(30):

        if step == config.train_steps:
            print('Done training.')
            break

        for (batch_inputs, batch_targets) in data_loader:

            if config.batch_size!=batch_inputs.size()[0]:
                print("batch mismatch")
                break

            # Only for time measurement of step through network
            t1 = time.time()
            model.hidden = model.init_hidden(config.batch_size)

            model.zero_grad()
            #######################################################
            # Add more code here ...
            
            # convert batch inputs to one-hot vectors
            batch_inputs = torch.zeros(config.batch_size, config.seq_length, vocab_size).scatter_(
                2, batch_inputs.unsqueeze(-1), 1.0)
            
            batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)

            predictions, _ = model(batch_inputs)
            if config.sampling=="greedy":
                predictions = logSoftmax(predictions)
            else:
                predictions = logSoftmax(predictions/config.temp)

            loss = criterion(predictions.transpose(2,1), batch_targets)   # fixme

            _, predictions = torch.max(predictions, dim=2, keepdim=True)
            predictions = (predictions.squeeze(-1) == batch_targets).float()
            accuracy = torch.mean(predictions)
            
            
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.max_norm)
            
            optimizer.step()
            lr_scheduler.step()

            #######################################################

            # Just for time measurement
            t2 = time.time()
            examples_per_second = config.batch_size/float(t2-t1)
            pl_loss.append(loss.item())
            average_loss.append(np.mean(pl_loss[:-100:-1]))
            acc.append(accuracy.item())


            if step % config.print_every == 0:

                print("[{}] Train Step {}/{}, Batch Size = {}, Examples/Sec = {:.2f}, "
                    "Accuracy = {:.2f}, Loss = {:.3f}".format(
                        datetime.now().strftime("%Y-%m-%d %H:%M"), step,
                        config.train_steps, config.batch_size, examples_per_second,
                        accuracy, loss.item()
                ))
                
                

            if step % config.sample_every == 0:

                model.eval()

                with torch.no_grad():
                    char_ix = generate_sample(model, vocab_size, config.seq_length, device, config)
                    sentence = dataset.convert_to_string(char_ix)

                f.write("--------------" + str(step) + "----------------\n")
                f.write(sentence + "\n")
                print(sentence)
                print()
                model.train()
                # ###########################################################################
                # plot and save the training loss curve
                plt.plot(pl_loss,'r-', label="Batch loss", alpha=0.5)
                plt.plot(average_loss,'g-', label="Average loss", alpha=0.5)
                plt.legend()
                plt.xlabel("Iterations")
                plt.ylabel("Loss")  
                plt.title("Training Loss")
                plt.grid(True)
                # plt.show()
                if config.sampling == "greedy":
                    plt.savefig("loss_"+config.sampling+".png")
                else:
                    plt.savefig("loss_"+config.sampling+"_"+str(config.temp)+".png")

                plt.close()
                # plot and save the training accuracy curve
                plt.plot(acc,'g-', alpha=0.5)
                plt.xlabel("Iterations")
                plt.ylabel("Accuracy")
                plt.title("Train Accuracy")
                plt.grid(True)
                if config.sampling == "greedy":
                    plt.savefig("accuracy_"+config.sampling+".png")
                else:
                    plt.savefig("accuracy_"+config.sampling+"_"+str(config.temp)+".png")
                plt.close()

            if step == config.train_steps:
                # If you receive a PyTorch data-loader error, check this bug report:
                # https://github.com/pytorch/pytorch/pull/9655
                break
            
            step += 1
            
        save_checkpoint({
            'epoch': epochs + 1,
            'step': step,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'accuracy': accuracy
        }, config)
        
    f.close()
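
Example #2 also depends on two helpers, generate_sample and save_checkpoint, that are not included in the snippet. The sketch below is one possible implementation, assuming the model call returns (predictions, hidden) as above, sampling is either greedy or temperature-scaled via config.sampling and config.temp, and checkpoints are written under config.summary_path; none of these details are confirmed by the original code.

def generate_sample(model, vocab_size, seq_length, device, config):
    # Hypothetical helper (not shown in the original snippet): sample seq_length
    # characters, starting from a random one-hot encoded character.
    model.hidden = model.init_hidden(1)
    idx = torch.randint(vocab_size, (1, 1, 1), device=device)
    char_ix = [idx.item()]
    for _ in range(seq_length):
        in_char = torch.zeros(1, 1, vocab_size, device=device).scatter_(2, idx, 1.0)
        predictions, _ = model(in_char)
        if config.sampling == "greedy":
            idx = torch.argmax(predictions, dim=2, keepdim=True)
        else:
            probs = torch.softmax(predictions.reshape(-1) / config.temp, dim=0)
            idx = torch.multinomial(probs, 1).reshape(1, 1, 1)
        char_ix.append(idx.item())
    return char_ix


def save_checkpoint(state, config):
    # Hypothetical helper (not shown in the original snippet): persist training state.
    path = os.path.join(config.summary_path, "checkpoint_" + config.sampling + ".pth.tar")
    torch.save(state, path)
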
Example #3
def train(config):

    # Initialize the device which to run the model on
    device = torch.device(config.device)

    # Initialize the dataset and data loader (note the +1)
    dataset = TextDataset(config.txt_file, config.seq_length)  # fixme
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    # Initialize the model that we are going to use
    model = TextGenerationModel(config.batch_size, config.seq_length,
                                dataset.vocab_size).to(device=device)  # fixme

    # Setup the loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()  # fixme
    optimizer = optim.Adam(model.parameters(),
                           lr=config.learning_rate)  # fixme

    saved_step = 0

    # make sure the checkpoint directory exists before listing or saving to it
    os.makedirs('checkpoints', exist_ok=True)
    checkpoints = sorted([int(cp) for cp in os.listdir('checkpoints')])
    if checkpoints:
        state = torch.load('checkpoints/{}'.format(checkpoints[-1]))
        saved_step = state['step'] + 1
        model.load_state_dict(state['state_dict'])
        optimizer.load_state_dict(state['optimizer'])

    for step, (batch_inputs, batch_targets) in enumerate(data_loader):
        step = step + saved_step
        # Only for time measurement of step through network
        t1 = time.time()

        #######################################################
        # Add more code here ...
        #######################################################
        model.zero_grad()
        model.hidden = model.init_hidden(config.batch_size)

        inputs = torch.unsqueeze(torch.stack(batch_inputs),
                                 2).float().to(device=device)
        targets = torch.cat(batch_targets).to(device=device)

        out = model(inputs)
        batch_loss = criterion(out, targets)

        # optimizer.zero_grad()
        batch_loss.backward()

        optimizer.step()

        loss = batch_loss  # fixme

        o = torch.max(out, 1)[1].cpu().numpy()
        t = targets.cpu().numpy()
        compared = np.equal(o, t)
        correct = np.sum(compared)
        accuracy = correct / len(compared)

        # Just for time measurement
        t2 = time.time()
        examples_per_second = config.batch_size / float(t2 - t1)

        if step % config.print_every == 0:
            print(
                "[{}] Train Step {:04d}/{:04d}, Batch Size = {}, Examples/Sec = {:.2f}, "
                "Accuracy = {:.2f}, Loss = {:.3f}".format(
                    datetime.now().strftime("%Y-%m-%d %H:%M"), step,
                    int(config.train_steps), config.batch_size,
                    examples_per_second, accuracy, loss))

        if step % config.sample_every == 0:
            # Generate some sentences by sampling from the model
            random_ix = torch.tensor(
                [random.choice(list(dataset._ix_to_char.keys()))])
            ix_list = [random_ix]

            model.hidden = model.init_hidden(1)
            for i in range(config.seq_length):
                tensor = torch.unsqueeze(torch.unsqueeze(ix_list[-1], 0),
                                         0).float().to(device=config.device)
                out = model(tensor)
                o = torch.max(out, 1)[1]
                ix_list.append(o)
            char_ix = [x.cpu().numpy()[0] for x in ix_list]
            gen_sen = dataset.convert_to_string(char_ix)
            with open('generated_sentences.txt', 'a') as file:
                file.write('{}: {}\n'.format(step, gen_sen))

        if step % config.save_every == 0:
            state = {
                'step': step,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(state, 'checkpoints/{}'.format(step))

        if step == config.train_steps:
            # If you receive a PyTorch data-loader error, check this bug report:
            # https://github.com/pytorch/pytorch/pull/9655
            break

    print('Done training.')