Exemplo n.º 1
0
def train(_run):
    """Train a character-level ``TextGenerationModel`` and save a checkpoint.

    Hyper-parameters are read from ``_run.config`` (wrapped in an
    ``argparse.Namespace``); ``_run`` is presumably a Sacred ``Run`` — it must
    expose ``config`` and ``log_scalar(name, value, step)``.

    Batches are drawn uniformly *with replacement*, so the loader yields
    ``train_steps`` batches regardless of the dataset length. Every
    ``print_every`` steps loss/accuracy are printed and logged, every
    ``sample_every`` steps a few sequences are sampled and printed, and at the
    end the weights, ``model.hparams`` and the index-to-char mapping are
    saved to ``SAVE_PATH/<timestamp>.pt``.

    Args:
        _run: run object providing ``config`` and ``log_scalar``.
    """
    config = argparse.Namespace(**_run.config)
    # ``train_steps`` may arrive as a float (e.g. ``1e6``); an int is needed
    # for the ``{:04d}`` log format below.
    train_steps = int(config.train_steps)

    device = torch.device(config.device)

    # Sample with replacement so that exactly ``total_samples`` examples
    # (i.e. ``train_steps`` batches) are produced.
    dataset = TextDataset(config.txt_file, config.seq_length)
    total_samples = int(config.train_steps * config.batch_size)
    sampler = RandomSampler(dataset,
                            replacement=True,
                            num_samples=total_samples)
    data_sampler = BatchSampler(sampler, config.batch_size, drop_last=False)
    data_loader = DataLoader(dataset,
                             num_workers=1,
                             batch_sampler=data_sampler)

    model = TextGenerationModel(dataset.vocab_size, config.lstm_num_hidden,
                                config.lstm_num_layers).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

    for step, (batch_inputs, batch_targets) in enumerate(data_loader):
        # Wall-clock time of one optimisation step (throughput reporting only).
        t1 = time.perf_counter()

        # The default collate returns a list of per-position tensors; inputs
        # are stacked position-major, targets are transposed to batch-major.
        # NOTE(review): the exact layouts expected by ``model`` / ``criterion``
        # are defined elsewhere — confirm there.
        batch_inputs = torch.stack(batch_inputs).to(device)
        batch_targets = torch.stack(batch_targets).t().to(device)

        optimizer.zero_grad()
        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_targets)
        batch_loss.backward()
        optimizer.step()

        t2 = time.perf_counter()
        # Guard against a zero interval on coarse clocks.
        examples_per_second = config.batch_size / max(t2 - t1, 1e-12)

        if step % config.print_every == 0:
            accuracy = eval_accuracy(logits, batch_targets)
            loss = batch_loss.item()
            log_str = ("[{}] Train Step {:04d}/{:04d}, "
                       "Batch Size = {}, Examples/Sec = {:.2f}, "
                       "Accuracy = {:.2f}, Loss = {:.3f}")
            print(
                log_str.format(datetime.now().strftime("%Y-%m-%d %H:%M"), step,
                               train_steps, config.batch_size,
                               examples_per_second, accuracy, loss))

            _run.log_scalar('loss', loss, step)
            _run.log_scalar('acc', accuracy, step)

        if step % config.sample_every == 0:
            # Generate some sentences by sampling from the model. The seed
            # characters must live on the model's device, and no autograd
            # graph is needed for sampling.
            print('-' * (config.sample_length + 1))
            model.eval()
            with torch.no_grad():
                x0 = torch.randint(low=0, high=dataset.vocab_size, size=(1, 5),
                                   device=device)
                samples = model.sample(x0, config.sample_length).detach().cpu()
            model.train()
            samples = samples.numpy()

            for sample in samples:
                print(dataset.convert_to_string(sample))

            print('-' * (config.sample_length + 1))

        # Defensive: the sampler above already bounds the number of batches.
        if step == config.train_steps:
            break

    print('Done training.')
    ckpt_path = os.path.join(SAVE_PATH, str(config.timestamp) + '.pt')
    torch.save(
        {
            'state_dict': model.state_dict(),
            'hparams': model.hparams,
            'ix_to_char': dataset.ix_to_char
        }, ckpt_path)
    print('Saved checkpoint to {}'.format(ckpt_path))
Exemplo n.º 2
0
def train(config):
    """Train a character-level ``TextGenerationModel`` for several epochs.

    The model (and optimizer) may be warm-started from saved state dicts. It
    is then trained with Adam, a per-step ``StepLR`` schedule and gradient-norm
    clipping. The LSTM state is carried across consecutive batches of an
    epoch (detached after each step) and reset at the start of every epoch.
    Every ``config.sample_every`` steps a sample is generated with
    ``predict`` and the model/optimizer state dicts are saved to the current
    working directory.

    Args:
        config: namespace-like object providing ``txt_file``, ``seq_length``,
            ``batch_size``, ``lstm_num_hidden``, ``lstm_num_layers``,
            ``learning_rate``, ``learning_rate_step``, ``learning_rate_decay``,
            ``max_norm``, ``print_every``, ``sample_every`` and
            ``train_steps``. Optional attributes:
            ``resume_model`` / ``resume_optim`` — checkpoint paths to load
            (default: the previously hard-coded ``grimm-results`` epoch-30
            files; set to ``None``/empty to train from scratch), and
            ``epochs`` (default 50).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = TextDataset(filename=config.txt_file,
                          seq_length=config.seq_length)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    vocab_size = dataset.vocab_size
    char2idx = dataset._char_to_ix
    idx2char = dataset._ix_to_char

    model = TextGenerationModel(batch_size=config.batch_size,
                                seq_length=config.seq_length,
                                vocabulary_size=vocab_size,
                                lstm_num_hidden=config.lstm_num_hidden,
                                lstm_num_layers=config.lstm_num_layers,
                                device=device)
    # Move the model BEFORE loading optimizer state:
    # ``Optimizer.load_state_dict`` casts the loaded state (e.g. Adam moments)
    # to the parameters' device at load time. Loading first and moving the
    # model afterwards leaves that state on the CPU while the parameters live
    # on the GPU, and ``optimizer.step()`` then fails.
    model = model.to(device)

    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    scheduler = scheduler_lib.StepLR(optimizer=optimizer,
                                     step_size=config.learning_rate_step,
                                     gamma=config.learning_rate_decay)

    # Optional warm start; the defaults reproduce the previously hard-coded
    # checkpoint so existing invocations behave the same.
    resume_model = getattr(
        config, 'resume_model',
        'grimm-results/intermediate-model-epoch-30-step-0.pth')
    resume_optim = getattr(
        config, 'resume_optim',
        'grimm-results/intermediate-optim-epoch-30-step-0.pth')
    if resume_model:
        model.load_state_dict(torch.load(resume_model, map_location='cpu'))
    if resume_optim:
        optimizer.load_state_dict(torch.load(resume_optim, map_location='cpu'))
    if resume_model or resume_optim:
        print("Loaded it!")

    epochs = int(getattr(config, 'epochs', 50))
    # ``train_steps`` may be configured as a float (e.g. ``1e6``).
    train_steps = int(config.train_steps)

    for epoch in range(epochs):
        # Fresh LSTM state at the start of every epoch.
        h, c = model.reset_lstm(config.batch_size)
        h = h.to(device)
        c = c.to(device)

        for step, (batch_inputs, batch_targets) in enumerate(data_loader):
            # Wall-clock time of one optimisation step (throughput only).
            t1 = time.perf_counter()

            # Re-enable training mode each step, in case ``predict`` below
            # switched the model to eval mode.
            model.train()

            optimizer.zero_grad()

            x = torch.stack(batch_inputs, dim=1).to(device)

            # The carried LSTM state was created for ``config.batch_size``
            # sequences, so a smaller trailing batch cannot be processed;
            # end the epoch instead.
            if x.size(0) != config.batch_size:
                print("We're breaking because something is wrong")
                print("Current batch is of size {}".format(x.size(0)))
                print("Supposed batch size is {}".format(config.batch_size))
                break

            y = torch.stack(batch_targets, dim=1).to(device)

            x = one_hot_encode(x, vocab_size)

            output, (h, c) = model(x=x, prev_state=(h, c))

            # CrossEntropyLoss expects the class dimension at position 1.
            # NOTE(review): assumes ``output`` is (batch, seq, vocab) — confirm.
            loss = criterion(output.transpose(1, 2), y)

            accuracy = calculate_accuracy(output, y)
            # Truncated BPTT: keep the state values, cut the autograd history.
            h = h.detach()
            c = c.detach()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=config.max_norm)
            optimizer.step()
            scheduler.step()

            t2 = time.perf_counter()
            examples_per_second = config.batch_size / max(t2 - t1, 1e-12)

            if step % config.print_every == 0:
                # NOTE(review): assumes ``calculate_accuracy`` returns a scalar.
                print(f"Epoch {epoch} Train Step {step}/{train_steps}, "
                      f"Examples/Sec = {examples_per_second:.2f}, "
                      f"Accuracy = {float(accuracy):.2f}, "
                      f"Loss = {loss.item():.3f}")

            if step % config.sample_every == 0:
                # Generate some sentences by sampling from the model.
                first_char = 'I'  # Is randomized within the prediction, actually
                predict(device, model, first_char, vocab_size, idx2char,
                        char2idx)
                # Intermediate checkpoints go to the working directory.
                path_model = 'intermediate-model-epoch-{}-step-{}.pth'.format(
                    epoch, step)
                path_optimizer = 'intermediate-optim-epoch-{}-step-{}.pth'.format(
                    epoch, step)
                torch.save(model.state_dict(), path_model)
                torch.save(optimizer.state_dict(), path_optimizer)

            if step == train_steps:
                # If you receive a PyTorch data-loader error, check this bug report:
                # https://github.com/pytorch/pytorch/pull/9655
                break

    print('Done training.')