Example #1
import time
from datetime import datetime

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import RMSprop
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# TextDataset and TextGenerationModel are assumed to live in the accompanying
# assignment modules; adjust these import paths to your own file layout.
from dataset import TextDataset
from model import TextGenerationModel

# TensorBoard writer used for the Train/Loss and Train/Accuracy scalars below
writer = SummaryWriter()


def train(config):

    # Initialize the device on which to run the model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize the dataset and data loader
    dataset = TextDataset(config.txt_file, config.seq_length)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)
    model = TextGenerationModel(config.batch_size,
                                config.seq_length,
                                dataset.vocab_size,
                                lstm_num_hidden=config.lstm_num_hidden,
                                lstm_num_layers=config.lstm_num_layers,
                                device=device)

    # Setup the loss and optimizer
    criterion = CrossEntropyLoss()
    optimizer = RMSprop(model.parameters(), lr=config.learning_rate)

    realsteps = 0  # global step counter across epochs
    for epoch in range(1000):
        for step, (batch_inputs, batch_targets) in enumerate(data_loader):
            realsteps += 1
            step = realsteps  # log and checkpoint against the global step
            t1 = time.time()

            # Stack the per-timestep target list into one tensor on the device
            batch_targets = torch.stack(batch_targets).to(device)
            optimizer.zero_grad()
            # Skip the incomplete final batch of an epoch
            if len(batch_inputs[0]) < config.batch_size:
                continue
            probs = model(batch_inputs)

            loss = 0
            accuracy = 0
            # Accumulate loss and accuracy over the time steps of the sequence
            for prob, target in zip(probs, batch_targets):
                loss += criterion(prob, target)
                predictions = prob.argmax(dim=1).float()
                accuracy += float(torch.sum(
                    predictions == target.float())) / config.batch_size
            # Average over the sequence length before logging
            loss = loss / config.seq_length
            accuracy = accuracy / config.seq_length
            loss.backward()
            writer.add_scalar('Train/Loss', loss, realsteps)
            writer.add_scalar('Train/Accuracy', accuracy, realsteps)
            optimizer.step()

            # Just for time measurement
            t2 = time.time()
            examples_per_second = config.batch_size / float(t2 - t1)

            if step % 10000 == 0:
                # Periodically checkpoint the full model
                torch.save(model, './' + str(step))
            if step % config.print_every == 0:

                print(
                    "[{}] Train Step {:04d}/{:04d}, Batch Size = {}, Examples/Sec = {:.2f}, "
                    "Accuracy = {:.2f}, Loss = {:.3f}".format(
                        datetime.now().strftime("%Y-%m-%d %H:%M"), step,
                        int(config.train_steps), config.batch_size,
                        examples_per_second, accuracy, loss))

            # if step % config.sample_every == 0:
            #     # Generate some sentences by sampling from the model
            #     # (a sketch of greedy_sampling_model is given after this function)
            #     greedy_sampling_model(model, dataset)
            if realsteps > config.train_steps:
                break

        if realsteps > config.train_steps:
            # If you receive a PyTorch data-loader error, check this bug report:
            # https://github.com/pytorch/pytorch/pull/9655
            break

    print('Done training.')
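The commented-out greedy_sampling_model call in the loop above is never defined in this snippet. Below is a minimal sketch of what greedy sampling from the trained model could look like. It assumes the model accepts a (t, 1) LongTensor of character indices and returns per-step logits of shape (t, 1, vocab_size), and that the dataset exposes a convert_to_string helper; neither interface is confirmed by the code above, so treat this purely as an illustration. It reuses the torch import added at the top of this example.

def greedy_sampling_model(model, dataset, sample_length=30, device='cpu'):
    """Greedily generate sample_length characters, starting from a random one."""
    model.eval()
    with torch.no_grad():
        # Start from a random character index from the vocabulary
        sequence = [torch.randint(dataset.vocab_size, (1,)).item()]
        for _ in range(sample_length - 1):
            # Shape (t, 1): the sequence generated so far as a single "batch"
            inputs = torch.tensor(sequence, device=device).unsqueeze(1)
            logits = model(inputs)  # assumed shape (t, 1, vocab_size)
            # Greedy step: pick the most likely next character
            sequence.append(logits[-1].argmax(dim=-1).item())
    model.train()
    # convert_to_string is assumed to map index sequences back to text
    return dataset.convert_to_string(sequence)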
Example #2
import time
from datetime import datetime

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import RMSprop
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# PalindromeDataset, VanillaRNN, and LSTM are assumed to live in the
# accompanying assignment modules; adjust these import paths to your layout.
from dataset import PalindromeDataset
from vanilla_rnn import VanillaRNN
from lstm import LSTM

writer = SummaryWriter()


def train(config):

    assert config.model_type in ('RNN', 'LSTM')

    # Initialize the device on which to run the model
    device = torch.device(config.device)
    # Initialize the model that we are going to use
    if config.model_type == 'RNN':
        model = VanillaRNN(config.input_length, config.input_dim,
                           config.num_hidden, config.num_classes,
                           config.batch_size, device)
    else:
        model = LSTM(config.input_length,
                     config.input_dim,
                     config.num_hidden,
                     config.num_classes,
                     config.batch_size,
                     device=device)
    # Initialize the dataset and data loader (note the +1)
    dataset = PalindromeDataset(config.input_length + 1)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    # Setup the loss and optimizer
    criterion = CrossEntropyLoss()
    optimizer = RMSprop(model.parameters(), lr=config.learning_rate)

    for step, (batch_inputs, batch_targets) in enumerate(data_loader):

        # Only for time measurement of step through network
        t1 = time.time()

        # Move the batch to the same device as the model
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        prob = model(batch_inputs)
        loss = criterion(prob, batch_targets)
        accuracy = float(
            torch.sum(prob.argmax(dim=1) == batch_targets)) / config.batch_size

        optimizer.zero_grad()
        loss.backward()
        ############################################################################
        # QUESTION: what happens here and why?
        # Gradient clipping rescales the gradients whenever their norm exceeds
        # max_norm. This is done to avoid exploding gradients in the recurrent
        # network; it must run after backward() and before optimizer.step().
        ############################################################################
        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                       max_norm=config.max_norm)
        optimizer.step()
        # Just for time measurement
        t2 = time.time()
        examples_per_second = config.batch_size / float(t2 - t1)
        writer.add_scalar('Train/Accuracy', accuracy, step)

        if step % 10 == 0:

            print(
                "[{}] Train Step {:04d}/{:04d}, Batch Size = {}, Examples/Sec = {:.2f}, "
                "Accuracy = {:.2f}, Loss = {:.3f}".format(
                    datetime.now().strftime("%Y-%m-%d %H:%M"), step,
                    config.train_steps, config.batch_size, examples_per_second,
                    accuracy, loss))

        if step == config.train_steps:
            # If you receive a PyTorch data-loader error, check this bug report:
            # https://github.com/pytorch/pytorch/pull/9655
            break
    print('Done training.')
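Both train functions read all hyperparameters from a single config namespace. Below is a minimal argparse-based sketch for Example #2 (the palindrome task); it only defines the fields the loop above actually reads, and the default values are illustrative assumptions rather than the original hyperparameters.

import argparse


def parse_config():
    parser = argparse.ArgumentParser()
    # Fields read by train(config) in Example #2; defaults are illustrative only.
    parser.add_argument('--model_type', type=str, default='RNN',
                        choices=['RNN', 'LSTM'])
    parser.add_argument('--input_length', type=int, default=10)
    parser.add_argument('--input_dim', type=int, default=1)
    parser.add_argument('--num_classes', type=int, default=10)
    parser.add_argument('--num_hidden', type=int, default=128)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--train_steps', type=int, default=10000)
    parser.add_argument('--max_norm', type=float, default=10.0)
    parser.add_argument('--device', type=str, default='cpu')
    return parser.parse_args()


if __name__ == '__main__':
    train(parse_config())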