def evaluate(self, dataset, step, partition='valid'):
    # Evaluate on validation set
    self.model_lm.eval()
    predictions = []
    labels = []
    with torch.no_grad():
        for validation_batch in dataset:
            validation_batch = batch_filter_distances(validation_batch, self.relative_distances)
            validation_batch = batch_to_device(validation_batch, self.device)
            output = self.model_lm.forward_batch(validation_batch).cpu()
            validation_batch = batch_to_device(validation_batch, "cpu")
            predictions.extend(output.logits.argmax(-1))
            labels.extend(validation_batch.labels)

            evaluation = self._evaluate_predictions(output.logits, validation_batch.labels,
                                                    loss=output.loss, partition=partition)
            self.logger.log_sub_batch_metrics(evaluation)
            self.logger.log_text("predicted words/validation",
                                 str(self._decode_predicted_words(output, validation_batch)))
            self.logger.log_text("labels/validation",
                                 str(self._decode_labels(validation_batch)))
    self.model_lm.train()

    if f"micro_f1_score/{partition}" in self.logger.sub_batch_metrics:
        score = mean(self.logger.sub_batch_metrics[f"micro_f1_score/{partition}"])
    else:
        score = mean(self.logger.sub_batch_metrics[f"f1_score/{partition}"])
    self.logger.flush_batch_metrics(step=step)
    return score
def _train_step(self, batch, num_simulated_batches):
    batch = batch_to_device(batch, self.device)
    output_gpu = self.model_lm.forward_batch(batch)
    # Gradient accumulation: every batch contributes only a part of the total gradient
    (output_gpu.loss / num_simulated_batches).backward()
    output_cpu = output_gpu.cpu()
    del output_gpu
    del batch
    return output_cpu
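# For context: gradient accumulation as used in _train_step scales each batch loss by the
# number of accumulated batches, so that stepping the optimizer once per simulated batch
# approximates a single large-batch update. A minimal, self-contained sketch; the
# model/optimizer/loss_fct names below are placeholders, not part of this codebase:
def accumulate_gradients_sketch(model, optimizer, loss_fct, batches, num_simulated_batches):
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(batches):
        loss = loss_fct(model(inputs), targets)
        # Each micro-batch contributes 1/num_simulated_batches of the total gradient
        (loss / num_simulated_batches).backward()
        if (i + 1) % num_simulated_batches == 0:
            optimizer.step()       # apply the accumulated gradient
            optimizer.zero_grad()  # reset before the next simulated batch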
pad_id = word_vocab[PAD_TOKEN]
unk_id = word_vocab[UNKNOWN_TOKEN]

f1_scores = []
precisions = []
recalls = []
predictions = []
best_non_unk_predictions = []
labels = []
losses = []

progress = tqdm(enumerate(dataloader), total=int(data_manager.approximate_total_samples() / BATCH_SIZE))
for i, batch in progress:
    batch = batch_filter_distances(batch, relative_distances)
    if not args.no_gpu:
        batch = batch_to_device(batch)
    label = batch.labels.detach().cpu()
    with torch.no_grad():
        output = model.forward_batch(batch).cpu()
    losses.append(output.loss.item())
    f1, prec, rec = f1_score(output.logits, label, pad_id=pad_id, unk_id=unk_id,
                             output_precision_recall=True)
    f1_scores.append(f1)
    precisions.append(prec)
    recalls.append(rec)
    batch_logits = output.logits.detach().cpu()
    best_non_unk_predictions.extend(get_best_non_unk_predictions(output.logits, unk_id=unk_id))
    predictions.extend(batch_logits.argmax(-1).squeeze(1))
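# get_best_non_unk_predictions is called above but not defined in this excerpt. A plausible
# sketch, assuming it simply falls back to the second-best token whenever the top prediction
# is the unknown token; this is an assumption about its behaviour, not the actual helper:
import torch

def best_non_unk_predictions_sketch(logits, unk_id):
    top2 = logits.topk(2, dim=-1).indices      # two highest-scoring token ids per position
    best, second = top2[..., 0], top2[..., 1]
    return torch.where(best == unk_id, second, best)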
def test_mini_dataset(self):
    def evaluate_predictions(logits, labels, loss=None):
        correct = logits.argmax(-1) == labels
        all_correct = correct.prod(-1)
        correct_tokens = all_correct.float().mean().cpu().item()
        ret = dict(correct_tokens=correct_tokens)
        if loss is not None:
            ret['loss'] = loss.detach().cpu().item()
        return ret

    BATCH_SIZE = 13
    NUM_PREDICT = 5

    dataloader = self.setup_mini_dataset()

    config = CodeTransformerCoreConfig(
        encoder_layer=CodeTransformerLayerConfig(d_model=16, nhead=8, dim_feedforward=32,
                                                 activation="gelu", num_relative_distances=4,
                                                 use_token_distances=True, use_content_content=True,
                                                 use_content_pos=True, use_pos_content=True,
                                                 use_pos_pos=True),
        num_layers=4,
    )
    language_model_config = TransformerLMDecoderConfig(
        lm_encoder=TransformerLMEncoderConfig(config,
                                              vocab_size=len(self.word_vocab.vocabulary),
                                              num_node_types=len(self.node_type_vocab.vocabulary),
                                              num_token_types=len(self.token_type_vocab.vocabulary)),
        sos_id=-1)
    transformer_lm = TransformerLanguageModel(
        transformer_lm_encoder=language_model_config['lm_encoder'],
        output_nonlinearity=language_model_config['output_nonlinearity'],
        loss_fct=language_model_config['loss_fct'])

    batch: CTBatch = next(iter(dataloader))
    cuda = torch.cuda.is_available() and RUN_TESTS_ON_GPU
    if cuda:
        transformer_lm = transformer_lm.cuda()

    opt = optim.Adam(transformer_lm.parameters(), lr=1e-4)
    tq = tqdm(range(500))

    if RUN_TESTS_ON_GPU:
        with self.assertRaises(RuntimeError):
            # CPU input on CUDA model should fail
            output = transformer_lm.forward_batch(batch)
        batch = batch_to_device(batch, "cuda")

    assert not (batch.labels == self.word_vocab['</s>']).any().item()

    for _ in tq:
        output = transformer_lm.forward_batch(batch)
        output.loss.backward()
        opt.step()
        opt.zero_grad()
        evaluation = evaluate_predictions(output.logits, batch.labels)
        acc = evaluation['correct_tokens']
        tq.set_postfix(loss=output.loss.cpu().item(), acc=acc)

    predicted_tokens = output.logits.argmax(-1)
    generated_text = batch_decode(self.word_vocab, predicted_tokens)
    generated_text2 = [
        " ".join([
            "_".join([
                self.word_vocab.reverse_lookup(subtoken.item())
                for subtoken in token
            ])
            for token in sample
        ])
        for sample in predicted_tokens
    ]
    assert list(generated_text) == generated_text2
    assert acc > 0.98
def train(self, batch_size, simulated_batch_size, random_seed, metrics, validate_every=None,
          persistent_snapshot_every=None, simulated_batch_size_valid=None,
          early_stopping_patience=10, max_validation_samples=10000, accumulate_tokens_batch=False):
    if self.with_cuda:
        self.model_lm = self.model_lm.cuda()
        self.device = "cuda"
    else:
        self.device = "cpu"

    run_id = self.model_manager.generate_run_name()
    self.logger = ExperimentLogger("experiment",
                                   TensorboardLogger(f"{LOGS_PATH}/{self.model_manager.model_type}/{run_id}"))
    self.logger.info(f"===============================================")
    self.logger.info(f"Starting run {run_id}")
    self.logger.info(f"===============================================")
    self.model_manager.save_config(run_id, self.config)
    early_stopping = EarlyStopping(self.model_manager, run_id, early_stopping_patience)

    num_params = sum([len(params.view(-1)) for params in self.model_lm.parameters()])
    self.logger.info(f"Start training model with {num_params} parameters")
    self.logger.info(f"Model setup: {self.model_lm}")

    self._init_metrics(metrics)
    torch.manual_seed(random_seed)
    random.seed(random_seed)

    # Simulated batches
    simulated_batch_size = batch_size if simulated_batch_size is None else simulated_batch_size
    assert simulated_batch_size % batch_size == 0, "simulated_batch_size must be a multiple of batch_size"
    num_simulated_batches = simulated_batch_size // batch_size

    # Main train loop
    train_step = 0
    dataloader = DataLoader(self.dataset_train, batch_size=batch_size,
                            collate_fn=self.dataset_train.collate_fn)
    if self.use_validation:
        if simulated_batch_size_valid is None:
            simulated_batch_size_valid = simulated_batch_size
        num_simulated_batches_valid = simulated_batch_size_valid // batch_size
        dataloader_validation = iter(DataLoader(self.dataset_validation, batch_size=batch_size,
                                                collate_fn=self.dataset_validation.collate_fn))

    n_tokens_accumulate_batch = None
    if accumulate_tokens_batch:
        n_tokens_accumulate_batch = 0
    epoch = 1
    progress_bar = tqdm(total=int(self.data_manager.approximate_total_samples() / batch_size))
    progress_bar.set_description(f"Epoch {epoch}")

    # Ensure graceful shutdown when training is interrupted
    signal.signal(signal.SIGINT, self._handle_shutdown)

    with Timing() as t:
        for it, batch in enumerate(dataloader):
            self.logger.log_time(t.measure() / batch_size, "dataloader_seconds/sample",
                                 train_step * simulated_batch_size + (it % num_simulated_batches) * batch_size)

            # Calculate gradients
            batch = batch_filter_distances(batch, self.relative_distances)
            model_out = self._train_step(batch, num_simulated_batches)

            self.logger.log_time(t.measure() / batch_size, "model_seconds/sample",
                                 train_step * simulated_batch_size + (it % num_simulated_batches) * batch_size)

            # Log actual predicted words and labels
            self.logger.log_text("input/train",
                                 str([[self.word_vocab.reverse_lookup(st.item()) for st in token
                                       if st.item() != self.word_vocab[PAD_TOKEN]
                                       and st.item() != self.word_vocab[EOS_TOKEN]]
                                      for token in batch.tokens[0]]))
            self.logger.log_text("predicted words/train", str(self._decode_predicted_words(model_out, batch)))
            self.logger.log_text("labels/train", str(self._decode_labels(batch)))

            # Calculate metrics
            evaluation = self._evaluate_predictions(model_out.logits, batch.labels, loss=model_out.loss)
            self.logger.log_sub_batch_metrics(evaluation)

            if accumulate_tokens_batch:
                n_tokens_accumulate_batch += batch.sequence_lengths.sum().item()

            # Gradient accumulation: only update gradients every num_simulated_batches step
            if not accumulate_tokens_batch and it % num_simulated_batches == (num_simulated_batches - 1) \
                    or accumulate_tokens_batch and n_tokens_accumulate_batch > simulated_batch_size:
                if accumulate_tokens_batch:
                    n_tokens_accumulate_batch = 0
                train_step += 1

                total_norm = 0
                for p in self.model_lm.parameters():
                    if p.grad is not None:
                        param_norm = p.grad.data.norm(2)
                        total_norm += param_norm.item() ** 2
                total_norm = total_norm ** (1. / 2)
                self.logger.log_metrics({'gradient_norm': total_norm}, train_step * simulated_batch_size)

                self.optimizer.step()
                self.optimizer.zero_grad()
                if self.scheduler:
                    if not hasattr(self.scheduler, "total_steps") or train_step < self.scheduler.total_steps - 1:
                        self.scheduler.step()
                    self.logger.log_metrics({'lr': self.scheduler.get_lr()[0]}, train_step * simulated_batch_size)

                # Send train metrics to observers
                self.logger.flush_batch_metrics(train_step * simulated_batch_size)

                # Evaluate on validation set
                if self.use_validation and validate_every and train_step % validate_every == 0:
                    t.measure()
                    self.model_lm.eval()
                    with torch.no_grad():
                        for validation_batch in islice(dataloader_validation, num_simulated_batches_valid):
                            validation_batch = batch_filter_distances(validation_batch, self.relative_distances)
                            validation_batch = batch_to_device(validation_batch, self.device)
                            output = self.model_lm.forward_batch(validation_batch).cpu()
                            validation_batch = batch_to_device(validation_batch, "cpu")

                            evaluation = self._evaluate_predictions(output.logits, validation_batch.labels,
                                                                    loss=output.loss, partition='valid')
                            self.logger.log_sub_batch_metrics(evaluation)
                            self.logger.log_text("predicted words/validation",
                                                 str(self._decode_predicted_words(output, validation_batch)))
                            self.logger.log_text("labels/validation",
                                                 str(self._decode_labels(validation_batch)))
                    self.model_lm.train()

                    self.logger.flush_batch_metrics(step=train_step * simulated_batch_size)
                    self.logger.log_time(t.measure() / simulated_batch_size_valid, "valid_seconds/sample",
                                         train_step * simulated_batch_size)

            if persistent_snapshot_every and (it + 1) % persistent_snapshot_every == 0:
                snapshot_iteration = it + 1
                self.logger.info(f"Storing model params into snapshot-{snapshot_iteration}")
                self.model_manager.save_snapshot(run_id, self.model_lm.state_dict(), snapshot_iteration)
                dataset = self.dataset_validation_creator(False)
                score = self.evaluate(islice(dataset.to_dataloader(), int(max_validation_samples / batch_size)),
                                      train_step * simulated_batch_size, 'valid_full')
                if f"micro_f1_score/valid_full" in self.logger.sub_batch_metrics:
                    score_name = 'micro-F1'
                else:
                    score_name = 'F1'
                self.logger.info(f"Full evaluation yielded {score} {score_name}")
                if not early_stopping.evaluate(score, snapshot_iteration):
                    self.logger.info(f"Last {early_stopping_patience} evaluations did not improve performance. "
                                     f"Stopping run")
                    break

            progress_bar.update()
            if progress_bar.n >= progress_bar.total:
                progress_bar = tqdm(total=int(self.data_manager.approximate_total_samples() / batch_size))
                epoch += 1
                progress_bar.set_description(f"Epoch {epoch}")

            t.measure()

    self._handle_shutdown()
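# EarlyStopping is used above but not defined in this excerpt. A minimal sketch that is
# consistent with the call sites (EarlyStopping(model_manager, run_id, patience) and
# evaluate(score, snapshot_iteration) returning False once `patience` consecutive
# evaluations fail to improve the best score); an illustrative assumption, not the actual helper:
class EarlyStoppingSketch:
    def __init__(self, model_manager, run_id, patience):
        self.model_manager = model_manager
        self.run_id = run_id
        self.patience = patience
        self.best_score = float("-inf")
        self.bad_evaluations = 0
        self.best_snapshot = None

    def evaluate(self, score, snapshot_iteration):
        if score > self.best_score:
            self.best_score = score
            self.best_snapshot = snapshot_iteration  # remember the best-performing snapshot
            self.bad_evaluations = 0
        else:
            self.bad_evaluations += 1
        return self.bad_evaluations < self.patience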