def predict(
        input: Ignore[str],
        output: Ignore[str],
        log_prefix: Ignore[str],
        model: EncoderDecoder = None,
        batch_size_limit: int = 400,
        batch_limit_by_tokens: bool = True,
):
    logger = get_logger()
    (src_vocab, _), (src_field, tgt_field) = get_vocabularies()
    dataset = Corpora([src_field])

    logger.info(f'{log_prefix}: Loading input file ...')
    with open(input) as src_stream:
        for src_sentence in src_stream:
            if src_sentence.strip():
                dataset.append([src_sentence])
    logger.info(f'{log_prefix}: Loading done.')

    if model is None:
        # No model was passed in: restore the best checkpoint found so far.
        best_model_path = find_best_model()
        if best_model_path is None:
            raise RuntimeError(
                'Model has not been trained yet. Train the model first.')
        model = build_model(src_field.vocabulary, tgt_field.vocabulary)
        state_dict = torch.load(best_model_path)
        model.load_state_dict(state_dict['model_state'])
        model.to(get_device())
    model.eval()  # disable dropout during inference

    with open(output, 'w') as output_stream, torch.no_grad():
        for batch in dataset.iterate(get_device(),
                                     batch_size_limit,
                                     batch_limit_by_tokens,
                                     sort_by_length=False,
                                     shuffle=False):
            # Source padding mask, broadcastable over attention positions.
            x_mask = batch[0] != src_vocab.pad_index
            x_mask = x_mask.unsqueeze(1)
            x_e = model.encode(batch[0], x_mask)
            y_hat, _ = beam_search(x_e, x_mask, model,
                                   get_scores=short_sent_penalty)

            sentence = src_field.to_sentence_str(batch[0][-1].tolist())
            generated = tgt_field.to_sentence_str(y_hat[-1].tolist())
            logger.info('SENTENCE:\n ---- {}'.format(sentence))
            logger.info('GENERATED:\n ---- {}'.format(generated))

            # Hypotheses are target-side token ids, so decode them with the
            # target field (the original used src_field here by mistake).
            for generated in (tgt_field.to_sentence_str(s)
                              for s in y_hat.tolist()):
                output_stream.write(f'{generated}\n')
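# Illustrative sketch, not part of the original pipeline: shows how the
# `(batch != pad_index).unsqueeze(1)` masks used above behave on a toy batch.
# The pad id (0) and the toy token ids are assumptions for the demo only.
def _demo_padding_mask():
    """Build a broadcastable padding mask for a tiny fake source batch."""
    import torch
    pad_index = 0  # assumed pad id; the real value comes from the vocabulary
    toy_batch = torch.tensor([[5, 7, 2, pad_index, pad_index],
                              [3, 9, 4, 6, pad_index]])
    mask = (toy_batch != pad_index).unsqueeze(1)  # shape: (batch, 1, seq_len)
    # The singleton middle dimension lets the mask broadcast across the
    # attention's query positions inside model.encode / model.decode.
    return mask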
def __init__(self,
             pad_index: Ignore[int],
             teacher_config_path: str = '/DOES_NOT_EXIST'):
    teacher_config_path = relative_to_config_path(teacher_config_path)
    assert os.path.exists(teacher_config_path), \
        "Teacher model config does not exist."
    nn.Module.__init__(self)

    teacher_model_config = get_configuration().clone()
    with open(teacher_config_path) as f:
        teacher_model_config.load(json.load(f))
    push_configuration(teacher_model_config)
    update_and_ensure_model_output_path('test', None)

    best_model_path = find_best_model()
    if best_model_path is None:
        raise ValueError('Could not find the teacher model.')
    (src_vocab, tgt_vocab), _ = get_vocabularies()
    self.teacher_model = build_model(src_vocab, tgt_vocab)
    state_dict = torch.load(best_model_path)
    self.teacher_model.load_state_dict(state_dict['model_state'])
    self.teacher_model.to(get_device())
    self.teacher_model.eval()
    self.src_pad_index = src_vocab.pad_index
    self.tgt_pad_index = tgt_vocab.pad_index
    pop_configuration()

    self.pad_index = pad_index
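# Illustrative sketch, not part of the original module: the generic PyTorch
# pattern the constructor above relies on for restoring an inference-only
# teacher. The Linear stand-in and the checkpoint path are assumptions for
# the demo; the checkpoint layout {'model_state': ...} mirrors what train()
# saves. The explicit requires_grad_(False) freeze is an extra common
# safeguard, not taken from the original code.
def _demo_restore_frozen_model(checkpoint_path='teacher.pt'):
    """Load weights into a model and freeze it for inference-only use."""
    import torch
    import torch.nn as nn
    model = nn.Linear(4, 4)  # stand-in for build_model(src_vocab, tgt_vocab)
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict['model_state'])
    model.eval()  # disable dropout / batch-norm updates, as above
    for p in model.parameters():
        p.requires_grad_(False)  # teacher parameters receive no gradients
    return model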
def init_teacher_(model_emb, teacher_emb):
    assert model_emb.size(0) == teacher_emb.size(0), \
        "Vocabulary sizes are different."
    if model_emb.size(1) == teacher_emb.size(1):
        # Same embedding width: copy the teacher embeddings directly.
        model_emb.copy_(teacher_emb.to(model_emb.device))
    elif model_emb.size(1) > teacher_emb.size(1):
        # Student is wider: fill the first teacher_emb.size(1) columns with
        # the teacher embeddings and leave the rest as initialized.
        model_emb.narrow(1, 0, teacher_emb.size(1)).copy_(
            teacher_emb.to(model_emb.device))
    else:
        # Student is narrower: project the teacher embeddings down using the
        # eigenvectors of their Gram matrix.
        teacher_emb = teacher_emb.to(get_device())
        _, phi = torch.symeig(torch.matmul(teacher_emb.t(), teacher_emb),
                              eigenvectors=True)
        model_emb.copy_(
            torch.matmul(teacher_emb,
                         phi[:, :model_emb.size(1)]).to(model_emb.device))
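# Illustrative sketch, not part of the original module: the projection used
# in the final branch above, rewritten with torch.linalg.eigh (the modern
# replacement for the deprecated torch.symeig). The tensor sizes are
# assumptions for the demo. Both routines return eigenvalues in ascending
# order, so phi[:, :student_dim] selects the same columns as the code above.
def _demo_project_teacher_embeddings(vocab_size=10, teacher_dim=8,
                                     student_dim=4):
    """Project a wide teacher embedding matrix down to the student width."""
    import torch
    teacher_emb = torch.randn(vocab_size, teacher_dim)
    gram = teacher_emb.t() @ teacher_emb              # (teacher_dim, teacher_dim)
    _, phi = torch.linalg.eigh(gram)                  # orthonormal eigenvectors
    student_emb = teacher_emb @ phi[:, :student_dim]  # (vocab_size, student_dim)
    return student_emb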
def train(max_steps: int = 100,
          batch_size_limit: int = 400,
          batch_limit_by_tokens: bool = True,
          report_interval_steps: int = 10,
          validation_interval_steps: int = 100,
          lr_scheduler_at: str = 'every_step',
          n_ckpts_to_keep: int = 3,
          teacher_forcing: bool = True,
          random_seed: int = 42):
    set_random_seeds(random_seed)
    logger = get_logger()

    train_dataset = get_train_dataset()
    assert len(train_dataset.fields) >= 2, \
        "Train dataset must have at least two fields (source and target)."
    validation_dataset = get_validation_dataset()
    assert len(validation_dataset.fields) >= 2, \
        "Validation dataset must have at least two fields (source and target)."

    loss_function = get_loss_function(
        train_dataset.fields[1].vocabulary.pad_index)
    model = build_model(train_dataset.fields[0].vocabulary,
                        train_dataset.fields[1].vocabulary)
    model.to(get_device())
    loss_function.to(get_device())

    optimizer = build_optimizer(model.parameters())
    scheduler = build_scheduler(optimizer)
    initialize(model)

    def noop():
        return None

    def step_lr_scheduler():
        return scheduler.step()

    # Decide where in the loop the learning-rate scheduler is advanced.
    run_scheduler_at_step = noop
    run_scheduler_at_validation = noop
    run_scheduler_at_epoch = noop
    if scheduler is not None:
        if lr_scheduler_at == 'every_step':
            run_scheduler_at_step = step_lr_scheduler
        elif lr_scheduler_at == 'every_validation':
            run_scheduler_at_validation = step_lr_scheduler
        elif lr_scheduler_at == 'every_epoch':
            run_scheduler_at_epoch = step_lr_scheduler

    step = 0
    epoch = 0
    kept_checkpoint_path_score_map = {}
    best_checkpoint_specs = {"score": -math.inf, "step": -1}

    @configured('model')
    def maybe_save_checkpoint(score: Ignore[float], output_path: str):
        # Keep at most n_ckpts_to_keep checkpoints, evicting the worst one
        # when a better-scoring checkpoint arrives.
        if len(kept_checkpoint_path_score_map) < n_ckpts_to_keep or \
                any(score > s for s in kept_checkpoint_path_score_map.values()):
            if len(kept_checkpoint_path_score_map) >= n_ckpts_to_keep:
                worst_checkpoint_path = min(
                    kept_checkpoint_path_score_map,
                    key=kept_checkpoint_path_score_map.get)
                kept_checkpoint_path_score_map.pop(worst_checkpoint_path)
                try:
                    os.unlink(worst_checkpoint_path)
                except OSError:
                    logger.warning(
                        'Could not unlink {}.'.format(worst_checkpoint_path))
            if score > best_checkpoint_specs["score"]:
                logger.info(
                    'New `best model` found with score {:.3f} at step {}.'.format(
                        score, step))
                best_checkpoint_specs["score"] = score
                best_checkpoint_specs["step"] = step
            state_dict = {
                "step": step,
                "best_checkpoint_specs": best_checkpoint_specs,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict()
                                   if scheduler is not None else None
            }
            checkpoint_path = '{}/step_{}_score_{:.3f}.pt'.format(
                output_path, step, score)
            torch.save(state_dict, checkpoint_path)
            kept_checkpoint_path_score_map[checkpoint_path] = score

    model.train()
    validation_done_already = False
    while step < max_steps:
        start_time = time.time()
        total_tokens_processed = 0
        for batch in train_dataset.iterate(get_device(), batch_size_limit,
                                           batch_limit_by_tokens):
            step += 1
            if step >= max_steps:
                break

            # Padding masks for the source and target batches.
            x_mask = batch[0] != model.src_vocab.pad_index
            x_mask = x_mask.unsqueeze(1)
            y_mask = batch[1] != model.tgt_vocab.pad_index
            y_mask = y_mask.unsqueeze(1)

            # Teacher-forced decoding: feed y[:, :-1], predict y[:, 1:].
            x_e = model.encode(batch[0], x_mask)
            log_probs = model.decode(batch[1][:, :-1], x_e,
                                     y_mask[:, :, :-1], x_mask,
                                     teacher_forcing=teacher_forcing)

            token_count = y_mask[:, :, 1:].sum().item()
            loss = loss_function(log_probs, batch[1][:, 1:],
                                 model.get_target_embeddings()) / token_count
            loss.backward()
            optimizer.step()
            mark_optimization_step()
            optimizer.zero_grad()
            run_scheduler_at_step()
            total_tokens_processed += token_count

            if step > 0 and step % report_interval_steps == 0:
                elapsed_time = time.time() - start_time
                baseline_loss = loss_function.uniform_baseline_loss(
                    log_probs, batch[1][:, 1:])
                logger.info(
                    'Epoch_{} Step_{}: loss={:.3f}(vs {:.3f} uniform), '
                    'tokens/s={:.1f}, lr={}'.format(
                        epoch, step, loss.item(), baseline_loss,
                        total_tokens_processed / elapsed_time,
                        optimizer.param_groups[0]['lr']))
                start_time = time.time()
                total_tokens_processed = 0

            if step > 0 and step % validation_interval_steps == 0:
                log_prefix = 'Epoch_{} Step_{}'.format(epoch, step)
                score = evaluate(validation_dataset, log_prefix, model,
                                 loss_function)
                maybe_save_checkpoint(score)
                model.train()
                run_scheduler_at_validation()
                start_time = time.time()
                total_tokens_processed = 0
                validation_done_already = True
            else:
                validation_done_already = False

        epoch += 1
        logger.info('Epoch {} finished.'.format(epoch))
        run_scheduler_at_epoch()

    if not validation_done_already:
        log_prefix = 'Final (epoch={} ~ step={})'.format(epoch, step)
        score = evaluate(validation_dataset, log_prefix, model, loss_function)
        maybe_save_checkpoint(score)

    logger.info('Best validation score was {:.3f} at step {}.'.format(
        best_checkpoint_specs["score"], best_checkpoint_specs["step"]))
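# Illustrative sketch, not part of the original module: the checkpoint
# retention policy implemented by maybe_save_checkpoint above, reduced to a
# pure-Python function over a {path: score} map. File deletion and
# torch.save are left out; the names are assumptions for the demo.
def _demo_prune_checkpoints(kept, new_path, new_score, n_to_keep=3):
    """Insert a new checkpoint, evicting the worst one if over capacity."""
    if len(kept) < n_to_keep or any(new_score > s for s in kept.values()):
        if len(kept) >= n_to_keep:
            worst_path = min(kept, key=kept.get)  # lowest-scoring checkpoint
            kept.pop(worst_path)                  # the real code also unlinks the file
        kept[new_path] = new_score
    return kept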
def evaluate(validation_dataset: Corpora,
             log_prefix: Ignore[str],
             model: EncoderDecoder = None,
             loss_function: Callable = None,
             batch_size_limit: int = 400,
             batch_limit_by_tokens: bool = True,
             teacher_forcing: bool = True,
             metrics: Tuple[Metric] = None):
    assert len(validation_dataset.fields) >= 2, \
        "Validation dataset must have at least two fields (source and target)."
    logger = get_logger()

    if loss_function is None:
        loss_function = get_loss_function(
            validation_dataset.fields[1].vocabulary.pad_index)
        loss_function.to(get_device())
    if model is None:
        best_model_path = find_best_model()
        if best_model_path is None:
            raise RuntimeError(
                'Model has not been trained yet. Train the model first.')
        model = build_model(validation_dataset.fields[0].vocabulary,
                            validation_dataset.fields[1].vocabulary)
        state_dict = torch.load(best_model_path)
        model.load_state_dict(state_dict['model_state'])
        model.to(get_device())

    pad_index = model.tgt_vocab.pad_index
    total_item_count = 0
    total_validation_loss = 0
    model.eval()
    printed_samples = 0

    # BLEU is always computed first; any extra metrics are appended after it.
    if metrics is None:
        metrics = (BleuMetric(), )
    else:
        metrics = (BleuMetric(), ) + tuple(
            m for m in metrics if not isinstance(m, BleuMetric))

    with torch.no_grad():
        start_time = time.time()
        for validation_batch in validation_dataset.iterate(
                get_device(), batch_size_limit, batch_limit_by_tokens,
                sort_by_length=False, shuffle=False):
            x_mask = validation_batch[0] != model.src_vocab.pad_index
            x_mask = x_mask.unsqueeze(1)
            y_mask = validation_batch[1] != model.tgt_vocab.pad_index
            y_mask = y_mask.unsqueeze(1)

            x_e = model.encode(validation_batch[0], x_mask)
            log_probs = model.decode(validation_batch[1][:, :-1], x_e,
                                     y_mask[:, :, :-1], x_mask,
                                     teacher_forcing=teacher_forcing)
            loss = loss_function(log_probs, validation_batch[1][:, 1:],
                                 model.get_target_embeddings())
            total_item_count += y_mask[:, :, 1:].sum().item()
            total_validation_loss += loss.item()

            y_hat, _ = beam_search(x_e, x_mask, model,
                                   get_scores=short_sent_penalty)
            if printed_samples < 4:
                # Log a few sample translations for qualitative inspection.
                sentence = validation_dataset.fields[0].to_sentence_str(
                    validation_batch[0][-1].tolist())
                reference = validation_dataset.fields[1].to_sentence_str(
                    validation_batch[1][-1].tolist())
                generated = validation_dataset.fields[1].to_sentence_str(
                    y_hat[-1].tolist())
                logger.info('SENTENCE:\n ---- {}'.format(sentence))
                logger.info('REFERENCE:\n ---- {}'.format(reference))
                logger.info('GENERATED:\n ---- {}'.format(generated))
                printed_samples += 1

            update_metric_params(y_hat, validation_batch[1], pad_index, metrics)

        elapsed_time = time.time() - start_time
        logger.info(
            f'{log_prefix}: '
            f'evaluation_loss={total_validation_loss / total_item_count:.3f}, '
            f'elapsed_time={int(elapsed_time + 0.5)}s')
        for metric_repr in (str(m) for m in metrics):
            logger.info(f'{log_prefix}: evaluation {metric_repr}')

    return metrics[0].get_score()
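# Illustrative sketch, not part of the original module: the one-position
# shift used above (and in train) when calling model.decode and the loss
# function. The toy target ids and the pad id are assumptions for the demo.
def _demo_teacher_forcing_shift():
    """Show how a target batch splits into decoder inputs and gold labels."""
    import torch
    pad_index = 0  # assumed pad id
    # Rows stand for <bos> w1 w2 ... <eos>, padded to a common length
    # (1 = <bos>, 2 = <eos> are assumed ids for the demo).
    y = torch.tensor([[1, 5, 7, 9, 2, pad_index],
                      [1, 6, 8, 2, pad_index, pad_index]])
    decoder_input = y[:, :-1]   # fed to model.decode under teacher forcing
    gold_labels = y[:, 1:]      # compared against the predicted distributions
    token_count = (gold_labels != pad_index).sum().item()  # loss normalizer
    return decoder_input, gold_labels, token_count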