    def evaluate(self, dataset, step, partition='valid'):
        # Evaluate the model on the given dataset partition (default: validation)
        self.model_lm.eval()
        predictions = []
        labels = []
        with torch.no_grad():
            for validation_batch in dataset:
                validation_batch = batch_filter_distances(validation_batch, self.relative_distances)
                validation_batch = batch_to_device(validation_batch, self.device)
                output = self.model_lm.forward_batch(validation_batch).cpu()
                validation_batch = batch_to_device(validation_batch, "cpu")

                predictions.extend(output.logits.argmax(-1))
                labels.extend(validation_batch.labels)

                evaluation = self._evaluate_predictions(output.logits, validation_batch.labels,
                                                        loss=output.loss, partition=partition)
                self.logger.log_sub_batch_metrics(evaluation)

                self.logger.log_text("predicted words/validation",
                                     str(self._decode_predicted_words(output, validation_batch)))
                self.logger.log_text("labels/validation", str(self._decode_labels(validation_batch)))
        self.model_lm.train()
        if f"micro_f1_score/{partition}" in self.logger.sub_batch_metrics:
            score = mean(self.logger.sub_batch_metrics[f"micro_f1_score/{partition}"])
        else:
            score = mean(self.logger.sub_batch_metrics[f"f1_score/{partition}"])
        self.logger.flush_batch_metrics(step=step)

        return score
    def _train_step(self, batch, num_simulated_batches):
        batch = batch_to_device(batch, self.device)
        output_gpu = self.model_lm.forward_batch(batch)
        # Gradient accumulation: every batch contributes only a part of the total gradient
        (output_gpu.loss / num_simulated_batches).backward()
        output_cpu = output_gpu.cpu()
        del output_gpu
        del batch

        return output_cpu
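
The division by num_simulated_batches in _train_step above is what makes gradient accumulation equivalent to training with the larger simulated batch: every micro-batch backward() adds a proportionally scaled gradient, and the optimizer only steps once all of them have been accumulated. A minimal, self-contained sketch of the same pattern on a toy model (all names are illustrative, not taken from the project):

import torch
from torch import nn, optim

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
loss_fct = nn.CrossEntropyLoss()

num_simulated_batches = 4  # e.g. simulated_batch_size=32 with batch_size=8
micro_batches = [(torch.randn(8, 10), torch.randint(0, 2, (8,)))
                 for _ in range(num_simulated_batches)]

optimizer.zero_grad()
for inputs, targets in micro_batches:
    loss = loss_fct(model(inputs), targets)
    # Scaling the loss makes the summed gradients match the gradient of the
    # mean loss over one simulated batch of 32 samples.
    (loss / num_simulated_batches).backward()
optimizer.step()
optimizer.zero_grad()
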
Example #3
    pad_id = word_vocab[PAD_TOKEN]
    unk_id = word_vocab[UNKNOWN_TOKEN]

    f1_scores = []
    precisions = []
    recalls = []
    predictions = []
    best_non_unk_predictions = []
    labels = []
    losses = []
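    # Evaluation loop: accumulate loss, F1 / precision / recall and (best non-UNK)
    # predictions over the whole dataloader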
    progress = tqdm(enumerate(dataloader), total=int(data_manager.approximate_total_samples() / BATCH_SIZE))
    for i, batch in progress:
        batch = batch_filter_distances(batch, relative_distances)
        if not args.no_gpu:
            batch = batch_to_device(batch)

        label = batch.labels.detach().cpu()

        with torch.no_grad():
            output = model.forward_batch(batch).cpu()
        losses.append(output.loss.item())
        f1, prec, rec = f1_score(output.logits, label, pad_id=pad_id, unk_id=unk_id,
                                 output_precision_recall=True)
        f1_scores.append(f1)
        precisions.append(prec)
        recalls.append(rec)

        batch_logits = output.logits.detach().cpu()
        best_non_unk_predictions.extend(get_best_non_unk_predictions(output.logits, unk_id=unk_id))
        predictions.extend(batch_logits.argmax(-1).squeeze(1))
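
Example #3 relies on get_best_non_unk_predictions, whose implementation is not part of this snippet. A plausible minimal version masks the UNK logit before taking the argmax so the best non-UNK token is always returned (a sketch under that assumption, not necessarily the project's implementation):

import torch

def best_non_unk_predictions(logits: torch.Tensor, unk_id: int) -> torch.Tensor:
    # Set the UNK logit to -inf so argmax can never select it.
    masked_logits = logits.clone()
    masked_logits[..., unk_id] = float("-inf")
    return masked_logits.argmax(-1)

# Toy usage: vocabulary of size 5 with unk_id=0; the UNK column never wins.
print(best_non_unk_predictions(torch.randn(2, 3, 5), unk_id=0))
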
Example #4
    def test_mini_dataset(self):
        def evaluate_predictions(logits, labels, loss=None):
            correct = logits.argmax(-1) == labels
            all_correct = correct.prod(-1)
            correct_tokens = all_correct.float().mean().cpu().item()
            ret = dict(correct_tokens=correct_tokens)
            if loss is not None:
                ret['loss'] = loss.detach().cpu().item()
            return ret
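
        # Worked example for evaluate_predictions (illustrative): if `correct` comes out
        # as [[True, True, False], [True, True, True]], then all_correct = [0, 1] and
        # correct_tokens = 0.5, i.e. the fraction of samples whose predictions match the
        # labels at every position.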

        BATCH_SIZE = 13
        NUM_PREDICT = 5

        dataloader = self.setup_mini_dataset()

        config = CodeTransformerCoreConfig(
            encoder_layer=CodeTransformerLayerConfig(d_model=16,
                                                     nhead=8,
                                                     dim_feedforward=32,
                                                     activation="gelu",
                                                     num_relative_distances=4,
                                                     use_token_distances=True,
                                                     use_content_content=True,
                                                     use_content_pos=True,
                                                     use_pos_content=True,
                                                     use_pos_pos=True),
            num_layers=4,
        )

        language_model_config = TransformerLMDecoderConfig(
            lm_encoder=TransformerLMEncoderConfig(
                config,
                vocab_size=len(self.word_vocab.vocabulary),
                num_node_types=len(self.node_type_vocab.vocabulary),
                num_token_types=len(self.token_type_vocab.vocabulary)),
            sos_id=-1)
        transformer_lm = TransformerLanguageModel(
            transformer_lm_encoder=language_model_config['lm_encoder'],
            output_nonlinearity=language_model_config['output_nonlinearity'],
            loss_fct=language_model_config['loss_fct'])
        batch: CTBatch = next(iter(dataloader))

        cuda = torch.cuda.is_available() and RUN_TESTS_ON_GPU
        if cuda:
            transformer_lm = transformer_lm.cuda()

        opt = optim.Adam(transformer_lm.parameters(), lr=1e-4)
        tq = tqdm(range(500))

        if cuda:
            with self.assertRaises(RuntimeError):
                # CPU input on CUDA model should fail
                output = transformer_lm.forward_batch(batch)
            batch = batch_to_device(batch, "cuda")

        assert not (batch.labels == self.word_vocab['</s>']).any().item()
        for _ in tq:
            output = transformer_lm.forward_batch(batch)
            output.loss.backward()
            opt.step()
            opt.zero_grad()
            evaluation = evaluate_predictions(output.logits, batch.labels)
            acc = evaluation['correct_tokens']
            tq.set_postfix(loss=output.loss.cpu().item(), acc=acc)

            predicted_tokens = output.logits.argmax(-1)
            generated_text = batch_decode(self.word_vocab, predicted_tokens)
            generated_text2 = [
                " ".join([
                    "_".join([
                        self.word_vocab.reverse_lookup(subtoken.item())
                        for subtoken in token
                    ]) for token in sample
                ]) for sample in predicted_tokens
            ]
            assert list(generated_text) == generated_text2
        assert acc > 0.98
    def train(self, batch_size, simulated_batch_size, random_seed, metrics,
              validate_every=None,
              persistent_snapshot_every=None, simulated_batch_size_valid=None, early_stopping_patience=10,
              max_validation_samples=10000, accumulate_tokens_batch=False):

        if self.with_cuda:
            self.model_lm = self.model_lm.cuda()
            self.device = "cuda"
        else:
            self.device = "cpu"

        run_id = self.model_manager.generate_run_name()

        self.logger = ExperimentLogger("experiment",
                                       TensorboardLogger(f"{LOGS_PATH}/{self.model_manager.model_type}/{run_id}"))
        self.logger.info(f"===============================================")
        self.logger.info(f"Starting run {run_id}")
        self.logger.info(f"===============================================")

        self.model_manager.save_config(run_id, self.config)
        early_stopping = EarlyStopping(self.model_manager, run_id, early_stopping_patience)

        num_params = sum(p.numel() for p in self.model_lm.parameters())
        self.logger.info(f"Start training model with {num_params} parameters")
        self.logger.info(f"Model setup: {self.model_lm}")

        self._init_metrics(metrics)

        torch.manual_seed(random_seed)
        random.seed(random_seed)

        # Simulated batches
        simulated_batch_size = batch_size if simulated_batch_size is None else simulated_batch_size
        assert simulated_batch_size % batch_size == 0, "simulated_batch_size must be a multiple of batch_size"
        num_simulated_batches = simulated_batch_size // batch_size
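        # Example (illustrative): batch_size=8 and simulated_batch_size=32 give
        # num_simulated_batches=4, i.e. the optimizer steps only after the gradients of
        # 4 consecutive micro-batches have been accumulated.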

        # Main train loop
        train_step = 0
        dataloader = DataLoader(self.dataset_train, batch_size=batch_size, collate_fn=self.dataset_train.collate_fn)

        if self.use_validation:
            if simulated_batch_size_valid is None:
                simulated_batch_size_valid = simulated_batch_size
            num_simulated_batches_valid = simulated_batch_size_valid // batch_size
            dataloader_validation = iter(DataLoader(self.dataset_validation, batch_size=batch_size,
                                                    collate_fn=self.dataset_validation.collate_fn))

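        # Optionally accumulate gradients until a token budget is reached instead of a
        # fixed number of micro-batches (simulated_batch_size then acts as the budget)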
        n_tokens_accumulate_batch = None
        if accumulate_tokens_batch:
            n_tokens_accumulate_batch = 0

        epoch = 1
        progress_bar = tqdm(total=int(self.data_manager.approximate_total_samples() / batch_size))
        progress_bar.set_description(f"Epoch {epoch}")

        # Ensure graceful shutdown when training is interrupted
        signal.signal(signal.SIGINT, self._handle_shutdown)

        with Timing() as t:
            for it, batch in enumerate(dataloader):
                self.logger.log_time(t.measure() / batch_size, "dataloader_seconds/sample",
                                     train_step * simulated_batch_size + (it % num_simulated_batches) * batch_size)
                # Calculate gradients
                batch = batch_filter_distances(batch, self.relative_distances)
                model_out = self._train_step(batch, num_simulated_batches)
                self.logger.log_time(t.measure() / batch_size, "model_seconds/sample",
                                     train_step * simulated_batch_size + (it % num_simulated_batches) * batch_size)

                # Log actual predicted words and labels
                self.logger.log_text("input/train",
                                     str([[self.word_vocab.reverse_lookup(st.item()) for st in token
                                           if st.item() != self.word_vocab[PAD_TOKEN]
                                           and st.item() != self.word_vocab[EOS_TOKEN]]
                                          for token in batch.tokens[0]]))
                self.logger.log_text("predicted words/train", str(self._decode_predicted_words(model_out, batch)))
                self.logger.log_text("labels/train", str(self._decode_labels(batch)))

                # Calculate metrics
                evaluation = self._evaluate_predictions(model_out.logits, batch.labels, loss=model_out.loss)
                self.logger.log_sub_batch_metrics(evaluation)

                if accumulate_tokens_batch:
                    n_tokens_accumulate_batch += batch.sequence_lengths.sum().item()

                # Gradient accumulation: only update gradients every num_simulated_batches step
                if (not accumulate_tokens_batch and it % num_simulated_batches == num_simulated_batches - 1) \
                        or (accumulate_tokens_batch and n_tokens_accumulate_batch > simulated_batch_size):
                    if accumulate_tokens_batch:
                        n_tokens_accumulate_batch = 0
                    train_step += 1

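                    # Compute the global L2 norm over all parameter gradients
                    # (monitors gradient magnitude right before the optimizer step)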
                    total_norm = 0
                    for p in self.model_lm.parameters():
                        if p.grad is not None:
                            param_norm = p.grad.data.norm(2)
                            total_norm += param_norm.item() ** 2
                    total_norm = total_norm ** (1. / 2)
                    self.logger.log_metrics({'gradient_norm': total_norm}, train_step * simulated_batch_size)

                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    if self.scheduler:
                        if not hasattr(self.scheduler,
                                       "total_steps") or train_step < self.scheduler.total_steps - 1:
                            self.scheduler.step()
                        self.logger.log_metrics({'lr': self.scheduler.get_lr()[0]},
                                                train_step * simulated_batch_size)

                    # Send train metrics to observers
                    self.logger.flush_batch_metrics(train_step * simulated_batch_size)

                    # Evaluate on validation set
                    if self.use_validation and validate_every and train_step % validate_every == 0:
                        t.measure()
                        self.model_lm.eval()
                        with torch.no_grad():
                            for validation_batch in islice(dataloader_validation, num_simulated_batches_valid):
                                validation_batch = batch_filter_distances(validation_batch, self.relative_distances)
                                validation_batch = batch_to_device(validation_batch, self.device)
                                output = self.model_lm.forward_batch(validation_batch).cpu()
                                validation_batch = batch_to_device(validation_batch, "cpu")

                                evaluation = self._evaluate_predictions(output.logits, validation_batch.labels,
                                                                        loss=output.loss, partition='valid')
                                self.logger.log_sub_batch_metrics(evaluation)

                                self.logger.log_text("predicted words/validation",
                                                     str(self._decode_predicted_words(output, validation_batch)))
                                self.logger.log_text("labels/validation",
                                                     str(self._decode_labels(validation_batch)))
                        self.model_lm.train()
                        self.logger.flush_batch_metrics(step=train_step * simulated_batch_size)
                        self.logger.log_time(t.measure() / simulated_batch_size_valid, "valid_seconds/sample",
                                             train_step * simulated_batch_size)

                if persistent_snapshot_every and (it + 1) % persistent_snapshot_every == 0:
                    snapshot_iteration = it + 1
                    self.logger.info(f"Storing model params into snapshot-{snapshot_iteration}")
                    self.model_manager.save_snapshot(run_id, self.model_lm.state_dict(), snapshot_iteration)
                    dataset = self.dataset_validation_creator(False)
                    score = self.evaluate(islice(dataset.to_dataloader(), int(max_validation_samples / batch_size)),
                                          train_step * simulated_batch_size, 'valid_full')
                    if f"micro_f1_score/valid_full" in self.logger.sub_batch_metrics:
                        score_name = 'micro-F1'
                    else:
                        score_name = 'F1'
                    self.logger.info(f"Full evaluation yielded {score} {score_name}")
                    if not early_stopping.evaluate(score, snapshot_iteration):
                        self.logger.info(f"Last {early_stopping_patience} evaluations did not improve performance. "
                                         f"Stopping run")

                        break

                progress_bar.update()
                if progress_bar.n >= progress_bar.total:
                    progress_bar = tqdm(total=int(self.data_manager.approximate_total_samples() / batch_size))
                    epoch += 1
                    progress_bar.set_description(f"Epoch {epoch}")

            t.measure()

        self._handle_shutdown()
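
The snapshot block above stops training once early_stopping.evaluate(score, snapshot_iteration) signals that the last early_stopping_patience evaluations brought no improvement. A minimal stand-in for that bookkeeping (illustrative only; the project's EarlyStopping presumably also persists the best snapshot via the model manager and run id it is constructed with, which is omitted here):

class PatienceEarlyStopping:
    """Illustrative stand-in: keep the best score, give up after `patience` misses."""

    def __init__(self, patience: int):
        self.patience = patience
        self.best_score = float("-inf")
        self.best_iteration = None
        self.evaluations_since_improvement = 0

    def evaluate(self, score: float, iteration: int) -> bool:
        # Returns False once `patience` consecutive evaluations failed to improve.
        if score > self.best_score:
            self.best_score = score
            self.best_iteration = iteration
            self.evaluations_since_improvement = 0
            return True
        self.evaluations_since_improvement += 1
        return self.evaluations_since_improvement < self.patience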