Example #1
    def train_epoch(self, corpus, train_data, criterion, optimizer, epoch, args):
        self.train()
        dataset_size = len(train_data.data())  # approximate dataset size, used only for the progress printout
        total_loss = 0
        total_tokens = 0
        batch_idx = 0
        for batch in train_data:
            # after a delay of kl_anneal_delay epochs, linearly anneal the KL weight up to 1
            if epoch > args.kl_anneal_delay:
                args.anneal = min(args.anneal + args.kl_anneal_rate, 1.)
            optimizer.zero_grad()
            data, targets = batch.text, batch.target
            logits, mean, logvar = self.forward(data, targets, args)
            elbo, NLL, KL, tokens = self.elbo(logits, data, criterion, mean, logvar, args, args.anneal)
            loss = elbo / tokens  # optimize the per-token ELBO
            loss.backward()
            # clip gradients to stabilize RNN training (clip_grad_norm_ in current PyTorch)
            torch.nn.utils.clip_grad_norm(self.parameters(), args.clip)
            optimizer.step()

            total_loss += elbo.detach()
            total_tokens += tokens

            # periodically print an in-epoch summary
            if batch_idx % args.log_interval == 0 and batch_idx > 0:
                print_in_epoch_summary(epoch, batch_idx, args.batch_size, dataset_size,
                                       loss.data[0], NLL.data[0] / tokens, {'normal': KL.data[0] / tokens},
                                       tokens, "anneal={:.2f}".format(args.anneal))
            batch_idx += 1  # the iterator exposes no batch index, so count manually
        return total_loss[0] / total_tokens
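
The self.elbo(...) call above bundles the usual VAE objective: a reconstruction NLL computed with criterion, plus the KL divergence between the approximate posterior N(mean, exp(logvar)) and a standard normal prior, scaled by the annealing weight. Below is a minimal sketch of such a helper; the function name, padding index, and tensor shapes are assumptions for illustration, and the repo's real elbo also receives data and args.

import torch

def elbo_sketch(logits, targets, criterion, mean, logvar, anneal, pad_idx=0):
    # reconstruction term: summed cross-entropy over every predicted token
    NLL = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
    # KL( N(mean, exp(logvar)) || N(0, I) ), summed over the batch
    KL = -0.5 * torch.sum(1. + logvar - mean.pow(2) - logvar.exp())
    # count non-padding tokens so the loss can be reported per token
    tokens = targets.ne(pad_idx).sum().item()
    elbo = NLL + anneal * KL  # annealed ELBO, summed over the batch
    return elbo, NLL, KL, tokens

Dividing this summed ELBO by tokens, as the loop does, gives a per-token objective whose scale is comparable across batches of different lengths.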
Example #2

        def train_loop(profile=False):
            total_loss = 0
            total_tokens = 0
            total_resamples = 0
            batch_idx = 0

            # for pretty printing the loss in each chunk
            last_chunk_loss = 0
            last_chunk_tokens = 0
            last_chunk_resamples = 0

            for batch in tqdm(train_data):
                if profile and batch_idx > 10:
                    print("breaking because profiling finished;")
                    break
                # after the delay, linearly anneal the KL weight up to 1
                if epoch > args.kl_anneal_delay:
                    args.anneal = min(args.anneal + args.kl_anneal_rate, 1.)
                optimizer.zero_grad()
                data, targets = batch.text, batch.target
                elbo, NLL, tokens, resamples = self.forward(
                    data, targets, args, num_importance_samples, criterion)
                loss = elbo / tokens
                loss.backward()
                torch.nn.utils.clip_grad_norm(self.parameters(), args.clip)
                optimizer.step()

                total_loss += elbo.detach()
                total_tokens += tokens
                total_resamples += resamples

                # periodically print an in-epoch summary with per-chunk statistics
                if batch_idx % args.log_interval == 0 and batch_idx > 0:
                    chunk_loss = total_loss.data[0] - last_chunk_loss
                    chunk_tokens = total_tokens - last_chunk_tokens
                    chunk_resamples = (total_resamples - last_chunk_resamples
                                       ) / args.log_interval
                    print(total_resamples)  # running resample count (debug output)
                    print_in_epoch_summary(
                        epoch, batch_idx, args.batch_size, dataset_size,
                        loss.data[0], NLL / tokens, {
                            'Chunk Loss': chunk_loss / chunk_tokens,
                            'resamples': chunk_resamples
                        }, tokens, "anneal={:.2f}".format(args.anneal))
                    last_chunk_loss = total_loss.data[0]
                    last_chunk_tokens = total_tokens
                    last_chunk_resamples = total_resamples
                batch_idx += 1  # the iterator exposes no batch index, so count manually
            return total_loss, total_tokens
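
train_loop is an inner function and the excerpt does not show how it is called; the profile flag and the early break after about ten batches suggest it can be run under a profiler. The following sketch shows how the enclosing method might dispatch on a hypothetical args.profile flag using the standard-library cProfile; this is an illustrative assumption, not code from the repo.

        import cProfile
        import pstats

        if getattr(args, 'profile', False):
            # profile roughly ten batches (train_loop breaks out early) and print hot spots
            profiler = cProfile.Profile()
            profiler.runcall(train_loop, profile=True)
            pstats.Stats(profiler).sort_stats('cumulative').print_stats(20)
        else:
            total_loss, total_tokens = train_loop()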