Example #1
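
# The example is shown without its file header, so the imports below are a
# reconstruction and the module paths are assumptions. PalindromeDataset,
# VanillaRNN, LSTM, write_csv, and the `flags` list of config keys are defined
# elsewhere in the accompanying assignment code.
import time
from datetime import datetime
from operator import itemgetter

import torch
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter  # or torch.utils.tensorboard
from trixi.experiment import PytorchExperiment

from dataset import PalindromeDataset
from vanilla_rnn import VanillaRNN
from lstm import LSTM
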
def run(model_type, input_length, input_dim, num_classes, num_hidden,
        batch_size, learning_rate, train_steps, max_norm, device):
    assert model_type in ('RNN', 'LSTM')

    # Initialize the device on which to run the model
    device = torch.device(device)

    # Initialize the model that we are going to use
    model_pars = [
        input_length, input_dim, num_hidden, num_classes, batch_size, device
    ]
    model = (LSTM(*model_pars) if model_type == 'LSTM'
             else VanillaRNN(*model_pars))
    model.to(device)

    # Initialize the dataset and data loader (note the +1: the palindrome's
    # final digit is the prediction target)
    dataset = PalindromeDataset(input_length + 1)
    data_loader = DataLoader(dataset, batch_size, num_workers=1)
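    # The dataset presumably generates palindromes on the fly, so the loader can
    # be iterated indefinitely; training stops at the `step == train_steps` break below.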

    # Set up the loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()
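    # Note: CrossEntropyLoss expects raw logits of shape (batch_size, num_classes)
    # and integer class-index targets; it applies log-softmax internally, so the
    # model's output layer should not apply a softmax itself.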
    optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)

    for step, (batch_inputs, batch_targets) in enumerate(data_loader):

        # Only for time measurement of a step through the network
        t1 = time.time()

        # Move the batch to the target device
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        optimizer.zero_grad()
        ys = model(batch_inputs)

        predictions = ys.argmax(dim=-1)
        loss = criterion(ys, batch_targets)
        loss.backward()

        ############################################################################
        # QUESTION: what happens here and why?
        # ANSWER: the gradients are jointly rescaled so that their total L2 norm
        # is at most max_norm. RNNs trained on long sequences are prone to
        # exploding gradients, and clipping keeps one oversized gradient step
        # from overshooting the minimum. It must run after loss.backward()
        # (before that there are no gradients to clip) and before optimizer.step().
        ############################################################################
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
        ############################################################################

        optimizer.step()
        accuracy = (batch_targets == predictions).float().mean()

        # Just for time measurement
        t2 = time.time()
        examples_per_second = batch_size / (t2 - t1)

        if step % 10 == 0:
            print(
                "[{}] Train Step {:04d}/{:04d}, Batch Size = {}, Examples/Sec = {:.2f}, "
                "Accuracy = {:.2f}, Loss = {:.3f}".format(
                    datetime.now().strftime("%Y-%m-%d %H:%M"), step,
                    train_steps, batch_size, examples_per_second,
                    accuracy.item(), loss.item()))

        if step == train_steps:
            # If you receive a PyTorch data-loader error, check this bug report:
            # https://github.com/pytorch/pytorch/pull/9655
            break

    print('Done training.')
    return accuracy.item()
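
# For intuition, a minimal sketch of what clip_grad_norm_ computes; this is
# illustrative only, not the library implementation, and clip_gradients_sketch
# is a hypothetical helper. If the total L2 norm of all parameter gradients
# exceeds max_norm, every gradient is rescaled by max_norm / total_norm.
def clip_gradients_sketch(parameters, max_norm):
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.)
    # Joint norm over all gradients, as if they were one flattened vector
    total_norm = torch.norm(torch.stack([g.detach().norm(2) for g in grads]))
    if total_norm > max_norm:
        scale = max_norm / (total_norm + 1e-6)
        for g in grads:
            g.detach().mul_(scale)
    return total_norm
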
class PalindromeExperiment(PytorchExperiment):
    def setup(self):
        self.save_checkpoint(name='setup')
        (model_type, input_length, input_dim, num_classes, num_hidden,
         batch_size, learning_rate, train_steps, max_norm,
         wanted_device) = itemgetter(*flags)(vars(self.config))

        assert model_type in ('RNN', 'LSTM')

        # Initialize the device on which to run the model
        # TODO: debug CUDA issues
        device = torch.device(wanted_device)
        # device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # Initialize the model that we are going to use
        model_pars = [input_length, input_dim, num_hidden, num_classes, batch_size, device]
        self.model = (LSTM(*model_pars) if model_type == 'LSTM'
                      else VanillaRNN(*model_pars))
        self.model.to(device)

        # Initialize the dataset and data loader (note the +1: the palindrome's
        # final digit is the prediction target)
        dataset = PalindromeDataset(input_length + 1)
        self.data_loader = DataLoader(dataset, batch_size, num_workers=1)

        # Set up the loss and optimizer
        self.criterion = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.RMSprop(self.model.parameters(), lr=learning_rate)
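
    # trixi's experiment loop calls setup() once, then train() and validate()
    # once per epoch (per the PytorchExperiment lifecycle); setup() therefore
    # stashes everything train() needs on self.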

    # TODO: plot accuracy over input_length
    # TODO: increase learning_rate over input_length
    # TODO: compare result with RNN
    def train(self, epoch):
        (model_type, input_length, input_dim, num_classes, num_hidden,
         batch_size, learning_rate, train_steps, max_norm,
         device) = itemgetter(*flags)(vars(self.config))

        with SummaryWriter('part1/train') as w:
            results = []
            for step, (batch_inputs, batch_targets) in enumerate(self.data_loader):

                # Only for time measurement of a step through the network
                t1 = time.time()

                self.optimizer.zero_grad()

                # Move the batch to the target device; .to() handles dtype and
                # device in one call (torch.tensor() on an existing tensor would
                # copy it and raise a warning)
                inputs = batch_inputs.to(device=device, dtype=torch.float)
                targets = batch_targets.to(device=device, dtype=torch.long)

                ys = self.model(inputs)

                predictions = ys.argmax(dim=-1)
                loss = self.criterion(ys, targets)
                loss.backward()

                # Clip the gradients (after backward(), before step()) so gradient
                # explosion won't let us overshoot the minimum
                # https://www.quora.com/What-is-gradient-clipping-and-why-is-it-necessary
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=max_norm)

                self.optimizer.step()
                accuracy = (targets == predictions).float().mean()

                # Just for time measurement
                t2 = time.time()
                examples_per_second = batch_size / (t2 - t1)

                stats = {'loss': loss, 'accuracy': accuracy}
                results.append({'step': step, **{k: v.item() for k, v in stats.items()}})

                if step % 100 == 0:
                    # add_scalars wants plain floats; .item() also detaches the values
                    w.add_scalars('metrics', {k: v.item() for k, v in stats.items()}, step)

                    # # TODO: check why this is slow!
                    # for k, v in stats.items():
                    #     self.add_result(value=v.item(), name=f'train_{k}', counter=step / train_steps, label=k)

                    self.elog.print("elog [{}] Train Step {:04d}/{:04d}, Batch Size = {}, Examples/Sec = {:.2f}, "
                        "Accuracy = {:.2f}, Loss = {:.3f}".format(
                            datetime.now().strftime("%Y-%m-%d %H:%M"), step,
                            train_steps, batch_size, examples_per_second,
                            accuracy.item(), loss.item()
                    ))

                    self.save_checkpoint(name='train', n_iter=step)
                    results = write_csv(results, self.config)

                if step == train_steps:
                    # If you receive a PyTorch data-loader error, check this bug report:
                    # https://github.com/pytorch/pytorch/pull/9655
                    break

        print('Done training.')
        results = write_csv(results, self.config)

    def validate(self, epoch):
        pass
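
# A minimal sketch of how these entry points might be launched; the
# hyperparameter values below are illustrative assumptions, not the
# assignment's defaults.
if __name__ == '__main__':
    run(model_type='LSTM', input_length=10, input_dim=1, num_classes=10,
        num_hidden=128, batch_size=128, learning_rate=0.001,
        train_steps=10000, max_norm=10.0, device='cpu')

    # Or via the trixi wrapper (constructor arguments per trixi's docs, also an
    # assumption here):
    # exp = PalindromeExperiment(config={...}, name='palindrome', n_epochs=1)
    # exp.run()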