Example #1
import argparse

import torch

# TextGenerationModel is the project's model class; the import path below
# is assumed.
from model import TextGenerationModel


def generate(_run):
    config = argparse.Namespace(**_run.config)

    # Load trained model and vocabulary mappings
    checkpoint = torch.load(config.model_path, map_location='cpu')
    model = TextGenerationModel(**checkpoint['hparams'])
    model.load_state_dict(checkpoint['state_dict'])
    ix_to_char = checkpoint['ix_to_char']
    char_to_ix = {v: k for k, v in ix_to_char.items()}

    # Prepare initial sequence
    x0 = torch.tensor([char_to_ix[char] for char in config.initial_seq],
                      dtype=torch.long).view(-1, 1)

    # Generate
    samples = model.sample(x0, config.sample_length,
                           config.temperature).detach().cpu().squeeze()

    print('\n\nSample:')
    print('-------')
    text = ''.join(ix_to_char[ix.item()] for ix in samples)
    print(text)
    print('-------\n\n')
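
For context, here is a minimal sketch of the save side that would produce a checkpoint in the format generate() expects. The 'hparams', 'state_dict', and 'ix_to_char' keys are taken from the loader above; save_checkpoint itself and its argument names are hypothetical:

import torch

def save_checkpoint(model, hparams, ix_to_char, path):
    # Bundle everything generate() needs to rebuild the model: the
    # constructor kwargs, the learned weights, and the vocabulary mapping.
    # NOTE: hypothetical helper; only the dict keys come from the loader.
    torch.save({
        'hparams': hparams,              # kwargs for TextGenerationModel(**hparams)
        'state_dict': model.state_dict(),
        'ix_to_char': ix_to_char,        # {int: str} index-to-character map
    }, path)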
Example #2
import time
from datetime import datetime

import torch
from torch.optim import lr_scheduler as scheduler_lib
from torch.utils.data import DataLoader

# TextDataset, TextGenerationModel, one_hot_encode, calculate_accuracy,
# and predict are project-local helpers; their import paths are assumed.


def train(config):
    # Run on the GPU when available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = TextDataset(filename=config.txt_file,
                          seq_length=config.seq_length)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    VOCAB_SIZE = dataset.vocab_size
    CHAR2IDX = dataset._char_to_ix
    IDX2CHAR = dataset._ix_to_char

    # Initialize the model that we are going to use
    model = TextGenerationModel(batch_size=config.batch_size,
                                seq_length=config.seq_length,
                                vocabulary_size=VOCAB_SIZE,
                                lstm_num_hidden=config.lstm_num_hidden,
                                lstm_num_layers=config.lstm_num_layers,
                                device=device)

    # Setup the loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    scheduler = scheduler_lib.StepLR(optimizer=optimizer,
                                     step_size=config.learning_rate_step,
                                     gamma=config.learning_rate_decay)

    # Resume from a hard-coded intermediate checkpoint; flip the guard to
    # False to train from scratch.
    if True:
        model.load_state_dict(
            torch.load('grimm-results/intermediate-model-epoch-30-step-0.pth',
                       map_location='cpu'))
        optimizer.load_state_dict(
            torch.load('grimm-results/intermediate-optim-epoch-30-step-0.pth',
                       map_location='cpu'))

        print("Loaded checkpoint.")

    model = model.to(device)

    EPOCHS = 50

    for epoch in range(EPOCHS):
        # Initialize the LSTM state passed to the forward pass; it is
        # reset at the start of every epoch.
        h, c = model.reset_lstm(config.batch_size)
        h = h.to(device)
        c = c.to(device)

        for step, (batch_inputs, batch_targets) in enumerate(data_loader):

            # Record the start time to measure throughput for this step.
            t1 = time.time()

            model.train()

            optimizer.zero_grad()

            x = torch.stack(batch_inputs, dim=1).to(device)

            # The final batch of an epoch can be smaller than batch_size and
            # would not match the pre-allocated hidden state, so skip it.
            if x.size(0) != config.batch_size:
                print("Skipping incomplete batch: got {}, expected {}".format(
                    x.size(0), config.batch_size))
                break

            y = torch.stack(batch_targets, dim=1).to(device)

            x = one_hot_encode(x, VOCAB_SIZE)

            output, (h, c) = model(x=x, prev_state=(h, c))

            # CrossEntropyLoss expects logits as (batch, classes, seq_len),
            # hence the transpose of the model's (batch, seq_len, classes) output.
            loss = criterion(output.transpose(1, 2), y)

            accuracy = calculate_accuracy(output, y)

            # Detach the hidden state so gradients do not propagate across
            # batches (truncated backpropagation through time).
            h = h.detach()
            c = c.detach()

            loss.backward()

            # Clip gradient norms to guard against exploding gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=config.max_norm)
            optimizer.step()
            scheduler.step()

            # Compute throughput (examples per second) for this step.
            t2 = time.time()
            examples_per_second = config.batch_size / float(t2 - t1)

            if step % config.print_every == 0:
                print("[{}] Epoch {:02d} Train Step {:04d}/{:04d}, "
                      "Examples/Sec = {:.2f}, Accuracy = {:.2f}, "
                      "Loss = {:.3f}".format(
                          datetime.now().strftime("%Y-%m-%d %H:%M"), epoch,
                          step, int(config.train_steps), examples_per_second,
                          float(accuracy), loss.item()))

            if step % config.sample_every == 0:
                # Generate some sentences by sampling from the model. The
                # first character is actually randomized inside predict().
                FIRST_CHAR = 'I'
                predict(device, model, FIRST_CHAR, VOCAB_SIZE, IDX2CHAR,
                        CHAR2IDX)

                # Save intermediate model and optimizer checkpoints.
                path_model = 'intermediate-model-epoch-{}-step-{}.pth'.format(
                    epoch, step)
                path_optimizer = 'intermediate-optim-epoch-{}-step-{}.pth'.format(
                    epoch, step)
                torch.save(model.state_dict(), path_model)
                torch.save(optimizer.state_dict(), path_optimizer)

            if step == config.train_steps:
                # If you receive a PyTorch data-loader error, check this bug report:
                # https://github.com/pytorch/pytorch/pull/9655
                break

    print('Done training.')
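
The training loop above relies on two project-local helpers whose definitions are not shown. A minimal sketch of what they could look like, with signatures inferred from the call sites above and tensor shapes assumed to follow the (batch, seq_len, vocab_size) convention implied by the loss:

import torch
import torch.nn.functional as F

def one_hot_encode(x, vocab_size):
    # (batch, seq_len) int64 indices -> (batch, seq_len, vocab_size) floats,
    # the input layout the model's forward pass appears to consume.
    # NOTE: assumed implementation; the project's helper may differ.
    return F.one_hot(x, num_classes=vocab_size).float()

def calculate_accuracy(output, targets):
    # output: (batch, seq_len, vocab_size) logits; targets: (batch, seq_len).
    # NOTE: assumed implementation; the project's helper may differ.
    predictions = output.argmax(dim=-1)
    return (predictions == targets).float().mean().item()

Every config attribute the loop reads is visible in the code, so a matching command-line parser can be reconstructed. A sketch, with all default values chosen arbitrarily for illustration:

import argparse

def parse_config():
    # Flag names mirror the config.* attributes accessed in train();
    # defaults are illustrative, not the original project's values.
    parser = argparse.ArgumentParser()
    parser.add_argument('--txt_file', type=str, required=True)
    parser.add_argument('--seq_length', type=int, default=30)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--lstm_num_hidden', type=int, default=128)
    parser.add_argument('--lstm_num_layers', type=int, default=2)
    parser.add_argument('--learning_rate', type=float, default=2e-3)
    parser.add_argument('--learning_rate_step', type=int, default=5000)
    parser.add_argument('--learning_rate_decay', type=float, default=0.96)
    parser.add_argument('--max_norm', type=float, default=5.0)
    parser.add_argument('--train_steps', type=int, default=1000000)
    parser.add_argument('--print_every', type=int, default=100)
    parser.add_argument('--sample_every', type=int, default=1000)
    return parser.parse_args()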