def __init__(self, args):
    """Load a saved model and dump the requested variables as ``.npy`` files.

    Every entry of ``args.var_list`` is resolved on the restored model with
    ``operator.attrgetter`` and written to ``<args.save_to>/<entry>.npy``.

    :param args: parsed arguments; reads ``proto``, ``model_file``,
        ``var_list`` and ``save_to``.
    :raises ValueError: if the variable list is missing or empty, or if the
        model file is not given or does not exist.
    """
    super(Extractor, self).__init__()
    config = getattr(configurations, args.proto)()
    self.logger = ut.get_logger(config['log_file'])
    self.model_file = args.model_file

    var_list = args.var_list
    save_to = args.save_to

    # An empty list would otherwise surface later as an obscure TypeError
    # raised by operator.attrgetter().
    if not var_list:
        raise ValueError('Empty var list')

    if self.model_file is None or not os.path.exists(self.model_file):
        raise ValueError('Input file or model file does not exist')

    if not os.path.exists(save_to):
        os.makedirs(save_to)

    self.logger.info('Extracting these vars: {}'.format(', '.join(var_list)))

    model = Model(config)
    # The model is never moved off the CPU here, so remap checkpoints that
    # were saved from a GPU run instead of failing on CUDA-less machines.
    model.load_state_dict(torch.load(self.model_file, map_location='cpu'))

    var_values = operator.attrgetter(*var_list)(model)
    # attrgetter returns a bare object (not a tuple) for a single name
    if len(var_list) == 1:
        var_values = [var_values]

    for var, var_value in zip(var_list, var_values):
        var_path = os.path.join(save_to, var + '.npy')
        # detach(): Tensor.numpy() refuses tensors that require grad
        numpy.save(var_path, var_value.detach().cpu().numpy())
def create_train_model():
    """Build the training graph with its input pipeline and model.

    Reads the module-level ``hyper_parameters``.

    :return: tuple ``(graph, iterator, train)`` where ``train`` is the
        configured ``Train`` object.
    """
    graph = tf.Graph()
    train_mode = tf.contrib.learn.ModeKeys.TRAIN
    model = Model(train_mode, hyper_parameters)
    data_source = DatasetIterator(hyper_parameters)
    trainer = Train(train_mode, hyper_parameters)

    with graph.as_default():
        vocab_tables = data_source.get_tables(share_vocab=False)
        source_vocab_table, target_vocab_table = vocab_tables

        datasets = data_source.get_datasets()
        source_dataset, target_dataset = datasets

        # TODO: a skip_count placeholder is probably needed here
        # (skip_count=skip_count_place_holder).
        iterator = data_source.get_iterator(
            source_vocab_table=source_vocab_table,
            target_vocab_table=target_vocab_table,
            source_dataset=source_dataset,
            target_dataset=target_dataset,
            source_max_len=hyper_parameters["source_max_len_train"],
            target_max_len=hyper_parameters["target_max_len_train"],
        )

        outputs = model.build_model(iterator, target_vocab_table)
        logits, loss, final_context_state, sample_id = outputs

        trainer.configure_train_eval_infer(
            iterator=iterator,
            logits=logits,
            loss=loss,
            sample_id=sample_id,
            final_state=final_context_state)

    return graph, iterator, trainer
def __init__(self, args):
    """Set up configuration, data, model and optimizer for training.

    :param args: parsed command-line arguments; reads ``proto`` (name of a
        configuration factory in ``configurations``) and ``num_preload``.
    """
    super(Trainer, self).__init__()
    self.config = getattr(configurations, args.proto)()
    self.num_preload = args.num_preload
    self.logger = ut.get_logger(self.config['log_file'])
    self.device = torch.device(
        'cuda:0' if torch.cuda.is_available() else 'cpu')

    self.normalize_loss = self.config['normalize_loss']
    self.patience = self.config['patience']
    self.lr = self.config['lr']
    self.lr_decay = self.config['lr_decay']
    self.max_epochs = self.config['max_epochs']
    self.warmup_steps = self.config['warmup_steps']

    self.train_smooth_perps = []
    self.train_true_perps = []

    self.data_manager = DataManager(self.config)
    self.validator = Validator(self.config, self.data_manager)

    self.val_per_epoch = self.config['val_per_epoch']
    self.validate_freq = int(self.config['validate_freq'])
    self.logger.info('Evaluate every {} {}'.format(
        self.validate_freq, 'epochs' if self.val_per_epoch else 'batches'))

    # For logging
    self.log_freq = 100  # log train stat every this-many batches
    self.log_train_loss = 0.  # total train loss every log_freq batches
    self.log_nll_loss = 0.
    self.log_train_weights = 0.
    self.num_batches_done = 0  # number of batches done for the whole training
    self.epoch_batches_done = 0  # number of batches done for this epoch
    self.epoch_loss = 0.  # total train loss for whole epoch
    self.epoch_nll_loss = 0.  # total nll loss for whole epoch
    self.epoch_weights = 0.  # total train weights (# target words) for whole epoch
    self.epoch_time = 0.  # total exec time for whole epoch

    # get model
    self.model = Model(self.config).to(self.device)

    # numel() yields an exact Python int; numpy.prod over an empty size
    # (0-dim parameter) is the float 1.0 and would turn the whole count
    # into a float that is logged as e.g. "1,234.0".
    param_count = sum(p.numel() for p in self.model.parameters())
    self.logger.info('Model has {:,} parameters'.format(param_count))

    # get optimizer
    beta1 = self.config['beta1']
    beta2 = self.config['beta2']
    epsilon = self.config['epsilon']
    self.optimizer = torch.optim.Adam(self.model.parameters(),
                                      lr=self.lr,
                                      betas=(beta1, beta2),
                                      eps=epsilon)
def translate(self):
    """Beam-decode ``self.input_file`` and write the translations to disk.

    Two files are produced next to the input: ``<input>.best_trans`` and
    ``<input>.beam_trans``, holding the ``best_trans`` / ``beam_trans``
    strings returned by ``self.get_trans`` for every sentence, placed at
    the position given by the ``original_idxs`` yielded by the data manager.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = Model(self.config).to(device)
    self.logger.info('Restore model from {}'.format(self.model_file))
    # map_location keeps the CPU fallback above usable for checkpoints that
    # were saved from a GPU run.
    model.load_state_dict(torch.load(self.model_file, map_location=device))
    model.eval()

    best_trans_file = self.input_file + '.best_trans'
    beam_trans_file = self.input_file + '.beam_trans'
    # Create/truncate both outputs before the (long) decoding starts
    open(best_trans_file, 'w').close()
    open(beam_trans_file, 'w').close()

    # Only non-blank lines get a translation slot
    with open(self.input_file, 'r') as f:
        num_sents = sum(1 for line in f if line.strip())

    all_best_trans = [''] * num_sents
    all_beam_trans = [''] * num_sents

    with torch.no_grad():
        self.logger.info('Start translating {}'.format(self.input_file))
        start = time.time()
        count = 0
        for src_toks, original_idxs in self.data_manager.get_trans_input(
                self.input_file):
            src_toks_cuda = src_toks.to(device)
            rets = model.beam_decode(src_toks_cuda)

            for i, ret in enumerate(rets):
                probs = ret['probs'].cpu().detach().numpy().reshape([-1])
                scores = ret['scores'].cpu().detach().numpy().reshape([-1])
                symbols = ret['symbols'].cpu().detach().numpy()

                best_trans, _, beam_trans = self.get_trans(
                    probs, scores, symbols)
                all_best_trans[original_idxs[i]] = best_trans + '\n'
                all_beam_trans[original_idxs[i]] = beam_trans + '\n\n'

                count += 1
                if count % 100 == 0:
                    self.logger.info(
                        ' Translating line {}, average {} seconds/sent'.
                        format(count, (time.time() - start) / count))

    model.train()

    with open(best_trans_file, 'w') as ftrans, \
            open(beam_trans_file, 'w') as btrans:
        ftrans.write(''.join(all_best_trans))
        btrans.write(''.join(all_beam_trans))

    self.logger.info('Done translating {}, it takes {} minutes'.format(
        self.input_file, float(time.time() - start) / 60.0))
def get_model(self, mode):
    """Instantiate ``Model`` for ``mode`` inside the configured variable scope.

    Variables are reused for every mode except ``ac.TRAINING``; new variables
    are initialised uniformly in ``[-init_range, init_range]``.
    """
    share_variables = mode != ac.TRAINING
    init_range = self.config['init_range']
    uniform_init = tf.random_uniform_initializer(-init_range, init_range)

    scope = tf.variable_scope(self.config['model_name'],
                              reuse=share_variables,
                              initializer=uniform_init)
    with scope:
        model = Model(self.config, mode)
    return model
def create_eval_model():
    """Build the evaluation graph, fed from file-name placeholders.

    Reads the module-level ``hyper_parameters``.

    :return: tuple ``(eval_graph, evaluator, eval_iterator,
        source_file_placeholder, target_file_placeholder)`` where
        ``evaluator`` is the configured ``Train`` object.
    """
    eval_graph = tf.Graph()
    mode = tf.contrib.learn.ModeKeys.EVAL
    eval_model = Model(mode, hyper_parameters)
    dataset_iterator = DatasetIterator(hyper_parameters)
    # Deliberately not called `eval`: that name would shadow the builtin.
    evaluator = Train(mode, hyper_parameters)

    with eval_graph.as_default():
        target_vocab_file = hyper_parameters["target_vocab_file"]

        source_vocab_table, target_vocab_table = dataset_iterator.get_tables(
            share_vocab=False)
        # Maps predicted ids back to target-side strings
        reverse_target_vocab_table = \
            tf.contrib.lookup.index_to_string_table_from_file(
                target_vocab_file, default_value="UNK")

        # The files to evaluate on are supplied at session-run time
        source_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        target_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        source_dataset = tf.data.TextLineDataset(source_file_placeholder)
        target_dataset = tf.data.TextLineDataset(target_file_placeholder)

        eval_iterator = dataset_iterator.get_iterator(
            source_vocab_table=source_vocab_table,
            target_vocab_table=target_vocab_table,
            source_dataset=source_dataset,
            target_dataset=target_dataset,
            source_max_len=hyper_parameters["source_max_len_infer"],
            target_max_len=hyper_parameters["target_max_len_infer"])

        logits, loss, final_context_state, sample_id = eval_model.build_model(
            eval_iterator, target_vocab_table)

        evaluator.configure_train_eval_infer(
            iterator=eval_iterator,
            logits=logits,
            loss=loss,
            sample_id=sample_id,
            final_state=final_context_state,
            reverse_target_vocab_table=reverse_target_vocab_table)

    return (eval_graph, evaluator, eval_iterator, source_file_placeholder,
            target_file_placeholder)
def create_infer_model():
    """Build the inference graph, fed from an in-memory batch of sentences.

    Reads the module-level ``hyper_parameters``.

    :return: tuple ``(infer_graph, infer, infer_iterator, source_placeholder,
        batch_size_placeholder)`` where ``infer`` is the configured ``Train``
        object.
    """
    infer_graph = tf.Graph()
    mode = tf.contrib.learn.ModeKeys.INFER
    infer_model = Model(mode, hyper_parameters)
    dataset_iterator = DatasetIterator(hyper_parameters)
    infer = Train(mode, hyper_parameters)

    with infer_graph.as_default():
        target_vocab_file = hyper_parameters["target_vocab_file"]
        # Maps predicted ids back to target-side strings
        reverse_target_vocab_table = \
            tf.contrib.lookup.index_to_string_table_from_file(
                target_vocab_file, default_value="UNK")

        source_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        # NOTE(review): this placeholder is returned to the caller but is not
        # passed to get_infer_iterator() below - confirm whether the iterator
        # was meant to receive a batch size.
        batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)
        source_dataset = tf.data.Dataset.from_tensor_slices(source_placeholder)

        source_vocab_table, target_vocab_table = dataset_iterator.get_tables(
            share_vocab=False)

        infer_iterator = dataset_iterator.get_infer_iterator(
            source_dataset=source_dataset,
            source_vocab_table=source_vocab_table,
            source_max_len=hyper_parameters["source_max_len_infer"])

        logits, loss, final_context_state, sample_id = infer_model.build_model(
            infer_iterator, target_vocab_table)

        infer.configure_train_eval_infer(
            iterator=infer_iterator,
            logits=logits,
            loss=loss,
            sample_id=sample_id,
            final_state=final_context_state,
            reverse_target_vocab_table=reverse_target_vocab_table)

    return (infer_graph, infer, infer_iterator, source_placeholder,
            batch_size_placeholder)
def __init__(self, args):
    """Restore a trained model and immediately run ``self.translate()``.

    :param args: parsed arguments; reads ``proto``, ``config_overrides``,
        ``num_preload``, ``model_file`` and ``input_file``. When
        ``model_file`` is None it defaults to
        ``<save_to>/<model_name>.pth`` from the configuration.
    :raises FileNotFoundError: if the input file (when given) or the model
        file does not exist.
    """
    super(Translator, self).__init__()
    self.config = configurations.get_config(
        args.proto, getattr(configurations, args.proto),
        args.config_overrides)
    self.logger = ut.get_logger(self.config['log_file'])
    self.num_preload = args.num_preload

    self.model_file = args.model_file
    if self.model_file is None:
        self.model_file = os.path.join(self.config['save_to'],
                                       self.config['model_name'] + '.pth')

    self.input_file = args.input_file
    if self.input_file is not None and not os.path.exists(self.input_file):
        raise FileNotFoundError(
            f'Input file does not exist: {self.input_file}')
    if not os.path.exists(self.model_file):
        raise FileNotFoundError(
            f'Model file does not exist: {self.model_file}')

    self.logger.info(f'Restore model from {self.model_file}')
    self.model = Model(self.config,
                       load_from=self.model_file).to(ut.get_device())

    if self.input_file:
        save_fp = os.path.join(self.config['save_to'],
                               os.path.basename(self.input_file))
        # Swap a trailing source-language tag for the target-language one.
        # str.rstrip() must not be used for this: it strips any trailing
        # characters from the given *set* (e.g. 'token'.rstrip('en') is
        # 'tok'), not a suffix.
        src_lang = self.model.data_manager.src_lang
        if src_lang and save_fp.endswith(src_lang):
            save_fp = save_fp[:-len(src_lang)]
        save_fp = save_fp + self.model.data_manager.trg_lang

        self.best_output_fp = save_fp + '.best_trans'
        self.beam_output_fp = save_fp + '.beam_trans'
        # Create/truncate both outputs before translating
        open(self.best_output_fp, 'w').close()
        open(self.beam_output_fp, 'w').close()
    else:
        self.best_output_fp = self.beam_output_fp = None

    self.translate()
def __init__(self, args):
    """Restore a TensorFlow checkpoint and save selected variables as ``.npy``.

    Every entry of ``args.var_list`` is resolved on the rebuilt model with
    ``operator.attrgetter``, evaluated in a session and written to
    ``<args.save_to>/<entry>.npy``.

    :param args: parsed arguments; reads ``proto``, ``model_file``,
        ``var_list`` and ``save_to``.
    :raises ValueError: if ``var_list`` is None, or if ``model_file`` is
        None or ``<model_file>.meta`` does not exist.
    """
    super(Extractor, self).__init__()
    config = getattr(configurations, args.proto)()
    self.logger = ut.get_logger(config['log_file'])
    self.model_file = args.model_file

    names = args.var_list
    out_dir = args.save_to

    if names is None:
        raise ValueError('Empty var list')

    checkpoint_missing = (self.model_file is None
                          or not os.path.exists(self.model_file + '.meta'))
    if checkpoint_missing:
        raise ValueError('Input file or model file does not exist')

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    self.logger.info('Extracting these vars: {}'.format(', '.join(names)))

    with tf.Graph().as_default(), tf.Session() as sess:
        # Rebuild the training-mode model so the saver has variables to fill
        init_range = config['init_range']
        uniform_init = tf.random_uniform_initializer(-init_range, init_range)
        scope = tf.variable_scope(config['model_name'],
                                  reuse=False,
                                  initializer=uniform_init)
        with scope:
            model = Model(config, ac.TRAINING)

        saver = tf.train.Saver(var_list=tf.trainable_variables())
        saver.restore(sess, self.model_file)

        fetched = sess.run(operator.attrgetter(*names)(model))
        # attrgetter yields a bare object (not a tuple) for a single name
        if len(names) == 1:
            fetched = [fetched]

        for name, value in izip(names, fetched):
            numpy.save(os.path.join(out_dir, name + '.npy'), value)
def __init__(self, model: Model, config: dict) -> None:
    """
    Set up a TrainManager for ``model`` from the given configuration.

    :param model: torch module defining the model
    :param config: dictionary containing the training configurations
    """
    train_config = config["training"]

    # --- output locations and logging ---
    self.model_dir = make_model_dir(
        train_config["model_dir"],
        overwrite=train_config.get("overwrite", False))
    self.logger = make_logger(model_dir=self.model_dir)
    self.logging_freq = train_config.get("logging_freq", 100)
    self.valid_report_file = "{}/validations.txt".format(self.model_dir)
    self.tb_writer = SummaryWriter(log_dir=self.model_dir + "/tensorboard/")

    # --- model ---
    self.model = model
    self.pad_index = self.model.pad_index
    self.bos_index = self.model.bos_index
    self._log_parameters_list()

    # --- training objective ---
    self.label_smoothing = train_config.get("label_smoothing", 0.0)
    self.loss = XentLoss(pad_index=self.pad_index,
                         smoothing=self.label_smoothing)
    self.normalization = train_config.get("normalization", "batch")
    if self.normalization not in ("batch", "tokens"):
        raise ConfigurationError("Invalid normalization. "
                                 "Valid options: 'batch', 'tokens'.")

    # --- optimisation ---
    self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8)
    self.clip_grad_fun = build_gradient_clipper(config=train_config)
    self.optimizer = build_optimizer(config=train_config,
                                     parameters=model.parameters())

    # --- validation and early stopping ---
    self.validation_freq = train_config.get("validation_freq", 1000)
    self.log_valid_sents = train_config.get("print_valid_sents", [0, 1, 2])
    keep_last = train_config.get("keep_last_ckpts", 5)
    self.ckpt_queue = queue.Queue(maxsize=keep_last)

    self.eval_metric = train_config.get("eval_metric", "bleu")
    if self.eval_metric not in ("bleu", "chrf"):
        raise ConfigurationError("Invalid setting for 'eval_metric', "
                                 "valid options: 'bleu', 'chrf'.")

    self.early_stopping_metric = train_config.get("early_stopping_metric",
                                                  "eval_metric")
    # The early-stopping metric decides when checkpoints are written (on a
    # new best score) and whether "best" means lowest or highest.
    stopping_metric = self.early_stopping_metric
    if stopping_metric in ("ppl", "loss"):
        self.minimize_metric = True
    elif stopping_metric == "eval_metric":
        # BLEU and chrF are maximised; any other eval metric would have to
        # be minimised (none is implemented yet).
        self.minimize_metric = self.eval_metric not in ("bleu", "chrf")
    else:
        raise ConfigurationError(
            "Invalid setting for 'early_stopping_metric', "
            "valid options: 'loss', 'ppl', 'eval_metric'.")

    # --- learning rate scheduling ---
    if self.minimize_metric:
        scheduler_mode = "min"
    else:
        scheduler_mode = "max"
    self.scheduler, self.scheduler_step_at = build_scheduler(
        config=train_config,
        scheduler_mode=scheduler_mode,
        optimizer=self.optimizer,
        hidden_size=config["model"]["encoder"]["hidden_size"])

    # --- data and batching ---
    self.level = config["data"]["level"]
    if self.level not in ("word", "bpe", "char"):
        raise ConfigurationError("Invalid segmentation level. "
                                 "Valid options: 'word', 'bpe', 'char'.")
    self.shuffle = train_config.get("shuffle", True)
    self.epochs = train_config["epochs"]
    self.batch_size = train_config["batch_size"]
    self.batch_type = train_config.get("batch_type", "sentence")
    # Evaluation falls back to the training batch settings
    self.eval_batch_size = train_config.get("eval_batch_size",
                                            self.batch_size)
    self.eval_batch_type = train_config.get("eval_batch_type",
                                            self.batch_type)
    self.batch_multiplier = train_config.get("batch_multiplier", 1)

    # --- generation ---
    self.max_output_length = train_config.get("max_output_length", None)

    # --- CPU / GPU ---
    self.use_cuda = train_config["use_cuda"]
    if self.use_cuda:
        self.model.cuda()
        self.loss.cuda()

    # --- training statistics ---
    self.steps = 0
    # set to True once the learning rate minimum is reached
    self.stop = False
    self.total_tokens = 0
    self.best_ckpt_iteration = 0
    # start from the worst possible score for the chosen direction
    if self.minimize_metric:
        self.best_ckpt_score = np.inf
    else:
        self.best_ckpt_score = -np.inf

    def _is_best(score):
        # Compares against the current best at call time
        if self.minimize_metric:
            return score < self.best_ckpt_score
        return score > self.best_ckpt_score

    self.is_best = _is_best

    # --- optional warm start ---
    if "load_model" in train_config:
        model_load_path = train_config["load_model"]
        self.logger.info("Loading model from %s", model_load_path)
        self.init_from_checkpoint(model_load_path)
def __init__(self, args):
    """Build the configuration, model, validator and optimizer for training.

    :param args: parsed arguments; reads ``proto``, ``config_overrides`` and
        ``num_preload``.
    """
    super(Trainer, self).__init__()
    self.config = configurations.get_config(
        args.proto, getattr(configurations, args.proto),
        args.config_overrides)
    self.num_preload = args.num_preload
    self.lr = self.config['lr']
    ut.remove_files_in_dir(self.config['save_to'])
    self.logger = ut.get_logger(self.config['log_file'])

    # Per-epoch training perplexities collected over the whole run
    self.train_smooth_perps = []
    self.train_true_perps = []

    # Buffers holding per-batch statistics between two log lines
    self.log_freq = self.config['log_freq']  # log every this-many batches
    self.log_train_loss = []
    self.log_nll_loss = []
    self.log_train_weights = []
    self.log_grad_norms = []

    self.total_batches = 0  # batches done over the whole training
    self.epoch_loss = 0.  # total train loss for the current epoch
    self.epoch_nll_loss = 0.  # total nll loss for the current epoch
    self.epoch_weights = 0.  # total train weights (# target words) this epoch
    self.epoch_time = 0.  # total execution time for the current epoch

    # Model and validator
    self.model = Model(self.config).to(ut.get_device())
    self.validator = Validator(self.config, self.model)

    self.validate_freq = self.config['validate_freq']
    per_epoch = self.config['val_per_epoch']
    if self.validate_freq == 1:
        unit = 'epoch' if per_epoch else 'batch'
        self.logger.info('Evaluate every ' + unit)
    else:
        unit = 'epochs' if per_epoch else 'batches'
        self.logger.info(f'Evaluate every {self.validate_freq:,} ' + unit)

    # Rough number of batches per epoch, used for progress estimates
    largest_tok_count = max(self.model.data_manager.training_tok_counts)
    self.est_batches = largest_tok_count // self.config['batch_size']
    self.logger.info(
        f'Guessing around {self.est_batches:,} batches per epoch')

    param_count = sum(
        numpy.prod(p.size()) for p in self.model.parameters())
    self.logger.info(f'Model has {int(param_count):,} parameters')

    # One optimizer group per parameter so that the per-parameter options
    # registered in model.parameter_attrs (keyed by data pointer) apply.
    param_groups = []
    for param in self.model.parameters():
        group = {'params': [param]}
        key = param.data_ptr()
        if key in self.model.parameter_attrs:
            group.update(self.model.parameter_attrs[key])
        param_groups.append(group)

    adam_betas = (self.config['beta1'], self.config['beta2'])
    self.optimizer = torch.optim.Adam(param_groups,
                                      lr=self.lr,
                                      betas=adam_betas,
                                      eps=self.config['epsilon'])
class Trainer(object):
    """Drives training of a ``Model``: optimisation steps, periodic logging,
    learning-rate scheduling, validation and early stopping."""

    def __init__(self, args):
        """Build the configuration, model, validator and optimizer.

        :param args: parsed arguments; reads ``proto``, ``config_overrides``
            and ``num_preload``.
        """
        super(Trainer, self).__init__()
        self.config = configurations.get_config(
            args.proto, getattr(configurations, args.proto),
            args.config_overrides)
        self.num_preload = args.num_preload
        self.lr = self.config['lr']
        ut.remove_files_in_dir(self.config['save_to'])
        self.logger = ut.get_logger(self.config['log_file'])

        # Per-epoch training perplexities collected over the whole run
        self.train_smooth_perps = []
        self.train_true_perps = []

        # Per-batch statistics buffered between two log lines
        self.log_freq = self.config['log_freq']  # log every this-many batches
        self.log_train_loss = []
        self.log_nll_loss = []
        self.log_train_weights = []
        self.log_grad_norms = []

        self.total_batches = 0  # batches done over the whole training
        self.epoch_loss = 0.  # total train loss for the current epoch
        self.epoch_nll_loss = 0.  # total nll loss for the current epoch
        self.epoch_weights = 0.  # total train weights (# target words) this epoch
        self.epoch_time = 0.  # total execution time for the current epoch

        # get model
        device = ut.get_device()
        self.model = Model(self.config).to(device)
        self.validator = Validator(self.config, self.model)

        self.validate_freq = self.config['validate_freq']
        if self.validate_freq == 1:
            self.logger.info('Evaluate every ' + (
                'epoch' if self.config['val_per_epoch'] else 'batch'))
        else:
            self.logger.info(f'Evaluate every {self.validate_freq:,} ' + (
                'epochs' if self.config['val_per_epoch'] else 'batches'))

        # Estimated number of batches per epoch. Kept >= 1 because run_log()
        # divides by it when printing progress.
        self.est_batches = max(
            1,
            max(self.model.data_manager.training_tok_counts) //
            self.config['batch_size'])
        self.logger.info(
            f'Guessing around {self.est_batches:,} batches per epoch')

        param_count = sum(
            [numpy.prod(p.size()) for p in self.model.parameters()])
        self.logger.info(f'Model has {int(param_count):,} parameters')

        # One optimizer group per parameter so that the per-parameter options
        # registered in model.parameter_attrs (keyed by data pointer) apply.
        params = []
        for p in self.model.parameters():
            group = {'params': [p]}
            ptr = p.data_ptr()
            if ptr in self.model.parameter_attrs:
                group.update(self.model.parameter_attrs[ptr])
            params.append(group)

        self.optimizer = torch.optim.Adam(params,
                                          lr=self.lr,
                                          betas=(self.config['beta1'],
                                                 self.config['beta2']),
                                          eps=self.config['epsilon'])

    def _drain_log_buffers(self):
        """Sum the buffered per-batch stats, add them to the epoch totals and
        clear the loss/weight buffers.

        When ``grad_clamp`` is set, batches whose smoothed or nll loss is not
        finite are left out of the returned sums; their weights still count
        towards ``self.epoch_weights``.

        :return: ``(loss, nll_loss, weights)`` summed over the buffered
            batches.
        """
        log_train_loss = torch.tensor(0.0)
        log_nll_loss = torch.tensor(0.0)
        log_train_weights = torch.tensor(0.0)
        log_all_weights = torch.tensor(0.0)
        for smooth, nll, weight in zip(self.log_train_loss,
                                       self.log_nll_loss,
                                       self.log_train_weights):
            if not self.config['grad_clamp'] or (torch.isfinite(smooth)
                                                 and torch.isfinite(nll)):
                log_train_loss += smooth
                log_nll_loss += nll
                log_train_weights += weight
            log_all_weights += weight

        self.epoch_loss += log_train_loss
        self.epoch_nll_loss += log_nll_loss
        self.epoch_weights += log_all_weights

        self.log_train_loss = []
        self.log_nll_loss = []
        self.log_train_weights = []
        return log_train_loss, log_nll_loss, log_train_weights

    def report_epoch(self, epoch, batches):
        """Log the statistics of a finished epoch and reset the accumulators.

        :param epoch: number of the epoch that just finished.
        :param batches: number of batches that were run in this epoch.
        """
        # Batches run since the last periodic log line are still buffered;
        # without this they would be missing from the epoch statistics.
        if self.log_train_loss:
            self._drain_log_buffers()

        # An epoch without batches must not crash the report
        words_per_sec = (self.epoch_weights /
                         self.epoch_time if self.epoch_time else 0.)
        sec_per_batch = self.epoch_time / batches if batches else 0.

        self.logger.info(f'Finished epoch {epoch}')
        self.logger.info(f' Took {ut.format_time(self.epoch_time)}')
        self.logger.info(f' avg words/sec {words_per_sec:.2f}')
        self.logger.info(f' avg sec/batch {sec_per_batch:.2f}')
        self.logger.info(f' {batches} batches')

        if self.epoch_weights:
            train_smooth_perp = self.epoch_loss / self.epoch_weights
            train_true_perp = self.epoch_nll_loss / self.epoch_weights
        else:
            train_smooth_perp = float('inf')
            train_true_perp = float('inf')

        # The real batch count is the best estimate for the next epoch
        # (kept >= 1, see __init__).
        self.est_batches = max(1, batches)
        self.epoch_time = 0.
        self.epoch_nll_loss = 0.
        self.epoch_loss = 0.
        self.epoch_weights = 0.
        self.log_train_loss = []
        self.log_nll_loss = []
        self.log_train_weights = []
        self.log_grad_norms = []

        # Cap the exponent so that exp() cannot overflow
        train_smooth_perp = numpy.exp(
            train_smooth_perp) if train_smooth_perp < 300 else float('inf')
        self.train_smooth_perps.append(train_smooth_perp)
        train_true_perp = numpy.exp(
            train_true_perp) if train_true_perp < 300 else float('inf')
        self.train_true_perps.append(train_true_perp)

        self.logger.info(
            f' smooth, true perp: {float(train_smooth_perp):.2f}, {float(train_true_perp):.2f}'
        )

    def clip_grad_values(self):
        """
        Clamp every gradient to ``[-grad_clamp, grad_clamp]`` and zero NaNs.

        Adapted from
        https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html#clip_grad_value_
        This is the same as torch.nn.utils.clip_grad_value_, except it also
        sets nan gradients to 0.0
        """
        clip_value = float(self.config['grad_clamp'])
        for p in self.model.parameters():
            if p.grad is None:
                continue
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
            p.grad.data[torch.isnan(p.grad.data)] = 0.0

    def get_params(self, pe=False):
        """Yield the model parameters whose name is (``pe=True``) or is not
        (``pe=False``) listed in ``self.model.struct_params``."""
        for n, p in self.model.named_parameters():
            if (n in self.model.struct_params) == pe:
                yield p

    def run_log(self, batch, epoch, batch_data):
        """Run one optimisation step and, every ``log_freq`` batches, log a
        line of training statistics.

        :param batch: 1-based index of the batch within the current epoch.
        :param epoch: current epoch number.
        :param batch_data: 5-tuple whose last four items are ``src_toks``,
            ``src_structs``, ``trg_toks`` and ``targets``.
        """
        # Debugging tip: running the step inside
        # torch.autograd.detect_anomaly() pinpoints where a nan is produced.
        start = time.time()
        _, src_toks, src_structs, trg_toks, targets = batch_data

        # zero grad
        self.optimizer.zero_grad()

        # get loss
        ret = self.model(src_toks, src_structs, trg_toks, targets, batch,
                         epoch)
        loss = ret['loss']
        nll_loss = ret['nll_loss']
        if self.config['normalize_loss'] == ac.LOSS_TOK:
            opt_loss = loss / (targets != ac.PAD_ID).sum()
        elif self.config['normalize_loss'] == ac.LOSS_BATCH:
            opt_loss = loss / targets.size()[0]
        else:
            opt_loss = loss

        opt_loss.backward()

        # clip gradient
        if self.config['grad_clamp']:
            self.clip_grad_values()
        if self.config['grad_clip_pe']:
            pe_params = list(self.get_params(True))
            if pe_params:
                torch.nn.utils.clip_grad_norm_(pe_params,
                                               self.config['grad_clip_pe'])
        # NOTE(review): the original code also computed the list of
        # non-struct parameters here but never used it, so the global clip
        # below has always covered *all* parameters. Confirm whether it was
        # meant to exclude the struct parameters clipped above.
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.config['grad_clip']).detach()

        # update
        self.adjust_lr()
        self.optimizer.step()

        # update training stats
        num_words = (targets != ac.PAD_ID).detach().sum()

        loss = loss.detach()
        nll_loss = nll_loss.detach()
        self.total_batches += 1
        self.log_train_loss.append(loss)
        self.log_nll_loss.append(nll_loss)
        self.log_train_weights.append(num_words)
        self.log_grad_norms.append(grad_norm)
        self.epoch_time += time.time() - start

        if self.total_batches % self.log_freq == 0:
            (log_train_loss, log_nll_loss,
             log_train_weights) = self._drain_log_buffers()

            # Cap the exponent so that exp() cannot overflow
            avg_smooth_perp = log_train_loss / log_train_weights
            avg_smooth_perp = numpy.exp(
                avg_smooth_perp) if avg_smooth_perp < 300 else float('inf')
            avg_true_perp = log_nll_loss / log_train_weights
            avg_true_perp = numpy.exp(
                avg_true_perp) if avg_true_perp < 300 else float('inf')

            acc_speed_word = self.epoch_weights / self.epoch_time
            acc_speed_time = self.epoch_time / batch

            avg_grad_norm = sum(self.log_grad_norms) / len(
                self.log_grad_norms)
            self.log_grad_norms = []

            est_percent = int(100 * batch / self.est_batches)
            epoch_len = max(5, ut.get_num_digits(self.config['max_epochs']))
            batch_len = max(5, ut.get_num_digits(self.est_batches))
            if batch > self.est_batches:
                # Already past the estimate: no meaningful time left to show
                remaining = '?'
            else:
                remaining = ut.format_time(acc_speed_time *
                                           (self.est_batches - batch))

            cells = [
                f'{epoch:{epoch_len}}', f'{batch:{batch_len}}',
                f'{est_percent:3}%', f'{remaining:>9}',
                f'{acc_speed_word:#10.4g}', f'{acc_speed_time:#6.4g}s',
                f'{avg_smooth_perp:#11.4g}', f'{avg_true_perp:#9.4g}',
                f'{avg_grad_norm:#9.4g}'
            ]
            self.logger.info(' '.join(cells))

    def adjust_lr(self):
        """Set the optimizer learning rate for the upcoming step according to
        ``config['warmup_style']``; other styles leave it untouched."""
        if self.config['warmup_style'] == ac.ORG_WARMUP:
            step = self.total_batches + 1.0
            if step < self.config['warmup_steps']:
                lr = self.config['embed_dim']**(
                    -0.5) * step * self.config['warmup_steps']**(-1.5)
            else:
                lr = max(self.config['embed_dim']**(-0.5) * step**(-0.5),
                         self.config['min_lr'])
            for p in self.optimizer.param_groups:
                p['lr'] = lr
        elif self.config['warmup_style'] == ac.FIXED_WARMUP:
            warmup_steps = self.config['warmup_steps']
            step = self.total_batches + 1.0
            start_lr = self.config['start_lr']
            peak_lr = self.config['lr']
            min_lr = self.config['min_lr']
            if step < warmup_steps:
                # linear ramp from start_lr to peak_lr
                lr = start_lr + (peak_lr - start_lr) * step / warmup_steps
            else:
                # inverse square-root decay, floored at min_lr
                lr = max(min_lr, peak_lr * warmup_steps**(0.5) * step**(-0.5))
            for p in self.optimizer.param_groups:
                p['lr'] = lr
        elif self.config['warmup_style'] == ac.UPFLAT_WARMUP:
            warmup_steps = self.config['warmup_steps']
            step = self.total_batches + 1.0
            start_lr = self.config['start_lr']
            peak_lr = self.config['lr']
            # Only the ramp is scheduled here; afterwards the rate is left
            # to the annealing in maybe_validate().
            if step < warmup_steps:
                lr = start_lr + (peak_lr - start_lr) * step / warmup_steps
                for p in self.optimizer.param_groups:
                    p['lr'] = lr

    def train(self):
        """Run the full training loop, then save the model and the training
        perplexities and, if a test file exists, translate test and dev sets
        with the best checkpoint."""
        self.model.train()
        stop_early = False
        early_stop_msg_num = self.config[
            'early_stop_patience'] * self.validate_freq
        # validate_freq counts epochs when validating per epoch and batches
        # otherwise, so the message unit follows val_per_epoch.
        early_stop_msg_metric = 'epochs' if self.config[
            'val_per_epoch'] else 'batches'
        early_stop_msg = f'No improvement for last {early_stop_msg_num} {early_stop_msg_metric}; stopping early!'
        for epoch in range(1, self.config['max_epochs'] + 1):
            batch = 0
            for batch_data in self.model.data_manager.get_batches(
                    mode=ac.TRAINING, num_preload=self.num_preload):
                if batch == 0:
                    # Header of the per-batch statistics table
                    self.logger.info(f'Begin epoch {epoch}')
                    epoch_str = ' ' * max(
                        0,
                        ut.get_num_digits(self.config['max_epochs']) -
                        5) + 'epoch'
                    batch_str = ' ' * max(
                        0,
                        ut.get_num_digits(self.est_batches) - 5) + 'batch'
                    self.logger.info(' '.join([
                        epoch_str, batch_str, 'est%', 'remaining',
                        'trg word/s', 's/batch', 'smooth perp', 'true perp',
                        'grad norm'
                    ]))
                batch += 1
                self.run_log(batch, epoch, batch_data)
                if not self.config['val_per_epoch']:
                    stop_early = self.maybe_validate()
                    if stop_early:
                        self.logger.info(early_stop_msg)
                        break
            if stop_early:
                break
            self.report_epoch(epoch, batch)
            if self.config[
                    'val_per_epoch'] and epoch % self.validate_freq == 0:
                stop_early = self.maybe_validate(just_validate=True)
                if stop_early:
                    self.logger.info(early_stop_msg)
                    break

        if not self.config['val_by_bleu'] and not stop_early:
            # validate 1 last time
            self.maybe_validate(just_validate=True)

        self.logger.info('Training finished')
        self.logger.info('Train smooth perps:')
        self.logger.info(', '.join(
            [f'{x:.2f}' for x in self.train_smooth_perps]))
        self.logger.info('Train true perps:')
        self.logger.info(', '.join(
            [f'{x:.2f}' for x in self.train_true_perps]))
        numpy.save(
            os.path.join(self.config['save_to'], 'train_smooth_perps.npy'),
            self.train_smooth_perps)
        numpy.save(
            os.path.join(self.config['save_to'], 'train_true_perps.npy'),
            self.train_true_perps)

        self.model.save()

        # Evaluate test
        test_file = self.model.data_manager.data_files[ac.TESTING][
            self.model.data_manager.src_lang]
        dev_file = self.model.data_manager.data_files[ac.VALIDATING][
            self.model.data_manager.src_lang]
        if os.path.exists(test_file):
            self.logger.info('Evaluate test')
            self.restart_to_best_checkpoint()
            self.model.save()
            self.validator.translate(test_file, to_ids=True)
            self.logger.info('Translate dev set')
            self.validator.translate(dev_file, to_ids=True)

    def restart_to_best_checkpoint(self):
        """Load the checkpoint with the best validation score (highest BLEU
        or lowest perplexity, depending on ``val_by_bleu``) into the model."""
        if self.config['val_by_bleu']:
            best_bleu = numpy.max(self.validator.best_bleus)
            best_cpkt_path = self.validator.get_cpkt_path(best_bleu)
        else:
            best_perp = numpy.min(self.validator.best_perps)
            best_cpkt_path = self.validator.get_cpkt_path(best_perp)

        self.logger.info(f'Restore best cpkt from {best_cpkt_path}')
        self.model.load_state_dict(torch.load(best_cpkt_path))

    def is_patience_exhausted(self, patience, if_worst=False):
        '''
        if_worst=False (default) -> check if last patience epochs have failed to improve dev score
        if_worst=True -> check if last epoch was WORSE than the patience epochs before it

        Returns a falsy value when ``patience`` is 0/None or the validation
        curve is not yet longer than ``patience``.
        '''
        by_bleu = bool(self.config['val_by_bleu'])
        curve = (self.validator.bleu_curve
                 if by_bleu else self.validator.perp_curve)
        # Compare by value, not identity: config values may be truthy
        # non-bool objects (e.g. 1), for which `is not` gives wrong answers.
        best_worse = max if by_bleu != bool(if_worst) else min
        return patience and len(curve) > patience and curve[
            -1 if if_worst else -1 - patience] == best_worse(
                curve[-1 - patience:])

    def maybe_validate(self, just_validate=False):
        """Validate when due (every ``validate_freq`` batches, or always with
        ``just_validate``), anneal the learning rate if applicable, and tell
        whether training should stop early.

        :return: truthy if the early-stopping patience is exhausted; falsy
            otherwise, including when no validation was due.
        """
        if not (self.total_batches % self.validate_freq == 0
                or just_validate):
            return False

        self.model.save()
        self.validator.validate_and_save()

        # if doing annealing
        step = self.total_batches + 1.0
        warmup_steps = self.config['warmup_steps']
        warmup_style = self.config['warmup_style']
        # Annealing applies without warmup, or once an up-flat warmup is
        # over. The lr_decay check must guard both cases; the parentheses
        # matter because `and` binds tighter than `or`.
        annealing_phase = (warmup_style == ac.NO_WARMUP
                           or (warmup_style == ac.UPFLAT_WARMUP
                               and step >= warmup_steps))
        if annealing_phase and self.config['lr_decay'] > 0:
            if self.is_patience_exhausted(self.config['lr_decay_patience'],
                                          if_worst=True):
                if self.config['val_by_bleu']:
                    metric = 'bleu'
                    scores = self.validator.bleu_curve
                else:
                    metric = 'perp'
                    scores = self.validator.perp_curve
                scores = ', '.join([
                    str(x)
                    for x in scores[-1 - self.config['lr_decay_patience']:]
                ])

                self.logger.info(f'Past {metric} scores are {scores}')
                # decay lr if dev did not improve, but never below min_lr
                if self.lr * self.config['lr_decay'] >= self.config[
                        'min_lr']:
                    new_lr = self.lr * self.config['lr_decay']
                    self.logger.info(
                        f'Anneal the learning rate from {self.lr} to {new_lr}'
                    )
                    self.lr = new_lr
                    for p in self.optimizer.param_groups:
                        p['lr'] = self.lr

        return self.is_patience_exhausted(self.config['early_stop_patience'])
class Trainer(object):
    """Drive training of a translation ``Model``.

    The constructor builds the data manager, validator, model and Adam
    optimizer from the configuration named by ``args.proto``.  ``train`` then
    runs ``max_epochs`` epochs, logging running statistics, validating,
    saving checkpoints and optionally annealing the learning rate.
    """

    # Average per-word losses at or above this value are reported as an
    # infinite perplexity instead of being exponentiated.
    _MAX_LOG_PERP = 300

    def __init__(self, args):
        """Set up configuration, data, model and optimizer.

        Args:
            args: parsed arguments; ``args.proto`` names a configuration
                factory in ``configurations`` and ``args.num_preload`` is
                forwarded to ``DataManager.get_batch``.
        """
        super(Trainer, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.num_preload = args.num_preload
        self.logger = ut.get_logger(self.config['log_file'])
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')

        self.normalize_loss = self.config['normalize_loss']
        self.patience = self.config['patience']
        self.lr = self.config['lr']
        self.lr_decay = self.config['lr_decay']
        self.max_epochs = self.config['max_epochs']
        self.warmup_steps = self.config['warmup_steps']

        # One perplexity per finished epoch, saved to disk at the end of
        # training.
        self.train_smooth_perps = []
        self.train_true_perps = []

        self.data_manager = DataManager(self.config)
        self.validator = Validator(self.config, self.data_manager)

        # validate_freq counts epochs when val_per_epoch is set, batches
        # otherwise.
        self.val_per_epoch = self.config['val_per_epoch']
        self.validate_freq = int(self.config['validate_freq'])
        self.logger.info('Evaluate every {} {}'.format(
            self.validate_freq,
            'epochs' if self.val_per_epoch else 'batches'))

        # Running statistics, reset every ``log_freq`` batches.
        self.log_freq = 100
        self.log_train_loss = 0.
        self.log_nll_loss = 0.
        self.log_train_weights = 0.

        # Running statistics over the whole training / the current epoch.
        self.num_batches_done = 0    # batches done since training started
        self.epoch_batches_done = 0  # batches done in the current epoch
        self.epoch_loss = 0.         # summed ``loss`` for the epoch
        self.epoch_nll_loss = 0.     # summed ``nll_loss`` for the epoch
        self.epoch_weights = 0.      # number of non-pad target tokens
        self.epoch_time = 0.         # seconds spent in ``run_log``

        # Model
        self.model = Model(self.config).to(self.device)
        param_count = sum(p.numel() for p in self.model.parameters())
        self.logger.info('Model has {:,} parameters'.format(param_count))

        # Optimizer
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.lr,
            betas=(self.config['beta1'], self.config['beta2']),
            eps=self.config['epsilon'])

    @staticmethod
    def _to_perplexity(avg_loss):
        """Return ``exp(avg_loss)``, or ``inf`` when it is not below the cap."""
        if avg_loss < Trainer._MAX_LOG_PERP:
            return numpy.exp(avg_loss)
        return float('inf')

    @staticmethod
    def _loss_per_token(loss, targets, pad_id):
        """Divide ``loss`` by the number of entries in ``targets`` != pad_id."""
        return loss / (targets != pad_id).type(loss.type()).sum()

    @staticmethod
    def _loss_per_sequence(loss, targets):
        """Divide ``loss`` by the size of the first dimension of ``targets``.

        ``Tensor.size(0)`` is a plain ``int``; dividing a tensor by it needs
        no dtype conversion (calling ``.type`` on it raises AttributeError).
        """
        return loss / targets.size(0)

    def report_epoch(self, e):
        """Log the statistics of a finished epoch and reset the epoch counters.

        Args:
            e: epoch number used in the log message.
        """
        self.logger.info('Finish epoch {}'.format(e))
        self.logger.info(' It takes {}'.format(
            ut.format_seconds(self.epoch_time)))
        self.logger.info(' Avergage # words/second {}'.format(
            self.epoch_weights / self.epoch_time))
        self.logger.info(' Average seconds/batch {}'.format(
            self.epoch_time / self.epoch_batches_done))

        # Compute the averages before the counters are cleared.
        avg_smooth_loss = self.epoch_loss / self.epoch_weights
        avg_true_loss = self.epoch_nll_loss / self.epoch_weights

        self.epoch_batches_done = 0
        self.epoch_time = 0.
        self.epoch_nll_loss = 0.
        self.epoch_loss = 0.
        self.epoch_weights = 0.

        train_smooth_perp = self._to_perplexity(avg_smooth_loss)
        self.train_smooth_perps.append(train_smooth_perp)
        train_true_perp = self._to_perplexity(avg_true_loss)
        self.train_true_perps.append(train_true_perp)

        self.logger.info(
            ' smoothed train perplexity: {}'.format(train_smooth_perp))
        self.logger.info(
            ' true train perplexity: {}'.format(train_true_perp))

    def run_log(self, b, e, batch_data):
        """Run one optimization step on a batch and update/log statistics.

        Args:
            b: batch number within the epoch (only used for logging).
            e: zero-based epoch index (logged as ``e + 1``).
            batch_data: ``(src_toks, trg_toks, targets)`` tensors.
        """
        start = time.time()
        src_toks, trg_toks, targets = batch_data
        src_toks_cuda = src_toks.to(self.device)
        trg_toks_cuda = trg_toks.to(self.device)
        targets_cuda = targets.to(self.device)

        self.optimizer.zero_grad()

        ret = self.model(src_toks_cuda, trg_toks_cuda, targets_cuda)
        loss = ret['loss']
        nll_loss = ret['nll_loss']

        # Only the loss that is back-propagated is normalized; the logged
        # statistics below keep using the un-normalized ``loss``.
        if self.normalize_loss == ac.LOSS_TOK:
            opt_loss = self._loss_per_token(loss, targets_cuda, ac.PAD_ID)
        elif self.normalize_loss == ac.LOSS_BATCH:
            opt_loss = self._loss_per_sequence(loss, targets_cuda)
        else:
            opt_loss = loss

        opt_loss.backward()
        global_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.config['grad_clip'])

        # The learning rate must be set before the optimizer step uses it.
        self.adjust_lr()
        self.optimizer.step()

        # ``.numpy()`` needs a CPU tensor, hence the original ``targets``
        # rather than ``targets_cuda``.
        num_words = (targets != ac.PAD_ID).detach().numpy().sum()
        loss = loss.cpu().detach().numpy()
        nll_loss = nll_loss.cpu().detach().numpy()

        self.num_batches_done += 1
        self.log_train_loss += loss
        self.log_nll_loss += nll_loss
        self.log_train_weights += num_words

        self.epoch_batches_done += 1
        self.epoch_loss += loss
        self.epoch_nll_loss += nll_loss
        self.epoch_weights += num_words
        self.epoch_time += time.time() - start

        if self.num_batches_done % self.log_freq == 0:
            acc_speed_word = self.epoch_weights / self.epoch_time
            acc_speed_time = self.epoch_time / self.epoch_batches_done

            avg_smooth_perp = self._to_perplexity(
                self.log_train_loss / self.log_train_weights)
            avg_true_perp = self._to_perplexity(
                self.log_nll_loss / self.log_train_weights)

            self.log_train_loss = 0.
            self.log_nll_loss = 0.
            self.log_train_weights = 0.

            self.logger.info('Batch {}, epoch {}/{}:'.format(
                b, e + 1, self.max_epochs))
            self.logger.info(
                ' avg smooth perp: {0:.2f}'.format(avg_smooth_perp))
            self.logger.info(
                ' avg true perp: {0:.2f}'.format(avg_true_perp))
            self.logger.info(' acc trg words/s: {}'.format(
                int(acc_speed_word)))
            self.logger.info(
                ' acc sec/batch: {0:.2f}'.format(acc_speed_time))
            self.logger.info(' global norm: {0:.2f}'.format(global_norm))

    def adjust_lr(self):
        """Apply the warmup schedule when ``warmup_style`` is ``ORG_WARMUP``.

        With ``step = num_batches_done + 1`` the rate is
        ``embed_dim**-0.5 * step * warmup_steps**-1.5`` while
        ``step < warmup_steps`` and
        ``max(embed_dim**-0.5 * step**-0.5, min_lr)`` afterwards.  Only the
        optimizer's parameter groups are updated; ``self.lr`` is untouched.
        For any other warmup style this is a no-op.
        """
        if self.config['warmup_style'] != ac.ORG_WARMUP:
            return

        step = self.num_batches_done + 1.0
        scale = self.config['embed_dim'] ** (-0.5)
        if step < self.config['warmup_steps']:
            lr = scale * step * self.config['warmup_steps'] ** (-1.5)
        else:
            lr = max(scale * step ** (-0.5), self.config['min_lr'])

        for group in self.optimizer.param_groups:
            group['lr'] = lr

    def train(self):
        """Train for ``max_epochs`` epochs, then translate dev/test files."""
        self.model.train()
        train_ids_file = self.data_manager.data_files['ids']

        for e in range(self.max_epochs):
            batches = self.data_manager.get_batch(
                ids_file=train_ids_file,
                shuffle=True,
                num_preload=self.num_preload)
            for b, batch_data in enumerate(batches, start=1):
                self.run_log(b, e, batch_data)
                if not self.val_per_epoch:
                    self.maybe_validate()

            self.report_epoch(e + 1)
            if self.val_per_epoch and (e + 1) % self.validate_freq == 0:
                self.maybe_validate(just_validate=True)

        # validate 1 last time
        if not self.val_per_epoch:
            self.maybe_validate(just_validate=True)

        self.logger.info('It is finally done, mate!')
        self.logger.info('Train smoothed perps:')
        self.logger.info(', '.join(map(str, self.train_smooth_perps)))
        self.logger.info('Train true perps:')
        self.logger.info(', '.join(map(str, self.train_true_perps)))
        numpy.save(join(self.config['save_to'], 'train_smooth_perps.npy'),
                   self.train_smooth_perps)
        numpy.save(join(self.config['save_to'], 'train_true_perps.npy'),
                   self.train_true_perps)

        self.logger.info('Save final checkpoint')
        self.save_checkpoint()

        # Evaluate on test
        src_lang = self.data_manager.src_lang
        for checkpoint in self.data_manager.checkpoints:
            self.logger.info('Translate for {}'.format(checkpoint))
            dev_file = self.data_manager.dev_files[checkpoint][src_lang]
            test_file = self.data_manager.test_files[checkpoint][src_lang]
            if exists(test_file):
                self.logger.info(' Evaluate on test')
                self.restart_to_best_checkpoint(checkpoint)
                self.validator.translate(self.model, test_file)
                self.logger.info(' Also translate dev')
                self.validator.translate(self.model, dev_file)

    def save_checkpoint(self):
        """Save the model's state dict to ``<save_to>/<model_name>.pth``."""
        cpkt_path = join(self.config['save_to'],
                         '{}.pth'.format(self.config['model_name']))
        torch.save(self.model.state_dict(), cpkt_path)

    def restart_to_best_checkpoint(self, checkpoint):
        """Reload the weights with the lowest dev perplexity for ``checkpoint``.

        Args:
            checkpoint: key into ``validator.best_perps``.
        """
        best_perp = numpy.min(self.validator.best_perps[checkpoint])
        best_cpkt_path = self.validator.get_cpkt_path(checkpoint, best_perp)
        self.logger.info('Restore best cpkt from {}'.format(best_cpkt_path))
        # map_location keeps the load working when the tensors were saved
        # from a device that differs from the one currently in use.
        self.model.load_state_dict(
            torch.load(best_cpkt_path, map_location=self.device))

    def maybe_validate(self, just_validate=False):
        """Checkpoint and validate when due; anneal the lr if dev got worse.

        Validation runs when ``num_batches_done`` is a multiple of
        ``validate_freq`` or when ``just_validate`` is true.  Annealing only
        applies with ``warmup_style == NO_WARMUP`` and ``lr_decay > 0``: when
        the newest dev perplexity is worse than all of the ``patience``
        values before it, the rate is multiplied by ``lr_decay`` unless that
        would drop it below ``min_lr``.

        Args:
            just_validate: force validation regardless of the batch count.
        """
        due = self.num_batches_done % self.validate_freq == 0
        if not (due or just_validate):
            return

        self.save_checkpoint()
        self.validator.validate_and_save(self.model)

        annealing = (self.config['warmup_style'] == ac.NO_WARMUP
                     and self.lr_decay > 0)
        if not annealing:
            return

        curve = self.validator.perp_curve
        # patience == 0 would make the comparison window empty and max()
        # raise ValueError; treat it as "annealing disabled".
        dev_got_worse = (
            self.patience > 0
            and len(curve) > self.patience
            and curve[-1] > max(curve[-1 - self.patience:-1]))
        if not dev_got_worse:
            return

        metric = 'perp'
        scores = ', '.join(map(str, list(curve[-1 - self.patience:])))
        self.logger.info('Past {} are {}'.format(metric, scores))

        # when don't use warmup, decay lr if dev not improve
        new_lr = self.lr * self.lr_decay
        if new_lr >= self.config['min_lr']:
            self.logger.info(
                'Anneal the learning rate from {} to {}'.format(
                    self.lr, new_lr))
            self.lr = new_lr
            for group in self.optimizer.param_groups:
                group['lr'] = self.lr
class Translator(object):
    """Restore a trained model and translate an input file or stdin.

    Constructing a ``Translator`` performs the translation: ``__init__``
    ends by calling ``translate``.
    """

    def __init__(self, args):
        """Load configuration and model, prepare output files, translate.

        Args:
            args: parsed arguments providing ``proto``, ``config_overrides``,
                ``num_preload``, ``model_file`` (optional; defaults to
                ``<save_to>/<model_name>.pth``) and ``input_file`` (optional;
                stdin is read when it is not given).

        Raises:
            FileNotFoundError: if the input file (when given) or the model
                file does not exist.
        """
        super(Translator, self).__init__()
        self.config = configurations.get_config(
            args.proto, getattr(configurations, args.proto),
            args.config_overrides)
        self.logger = ut.get_logger(self.config['log_file'])
        self.num_preload = args.num_preload

        self.model_file = args.model_file
        if self.model_file is None:
            self.model_file = os.path.join(
                self.config['save_to'], self.config['model_name'] + '.pth')

        self.input_file = args.input_file
        if self.input_file is not None and not os.path.exists(
                self.input_file):
            raise FileNotFoundError(
                f'Input file does not exist: {self.input_file}')
        if not os.path.exists(self.model_file):
            raise FileNotFoundError(
                f'Model file does not exist: {self.model_file}')

        self.logger.info(f'Restore model from {self.model_file}')
        self.model = Model(self.config,
                           load_from=self.model_file).to(ut.get_device())

        if self.input_file:
            save_fp = os.path.join(self.config['save_to'],
                                   os.path.basename(self.input_file))
            save_fp = self._swap_lang_suffix(
                save_fp,
                self.model.data_manager.src_lang,
                self.model.data_manager.trg_lang)
            self.best_output_fp = save_fp + '.best_trans'
            self.beam_output_fp = save_fp + '.beam_trans'
            # Truncate both outputs; translate() reopens them in append mode.
            for path in (self.best_output_fp, self.beam_output_fp):
                with open(path, 'w'):
                    pass
        else:
            self.best_output_fp = self.beam_output_fp = None

        self.translate()

    @staticmethod
    def _swap_lang_suffix(path, src_lang, trg_lang):
        """Replace a trailing ``src_lang`` in ``path`` with ``trg_lang``.

        Only an exact suffix is removed.  ``str.rstrip`` is not used because
        it strips any trailing run of the given *characters* and would eat
        part of the file name (``'green'.rstrip('en')`` is ``'gr'``).  When
        ``path`` does not end with ``src_lang``, ``trg_lang`` is appended.
        """
        if src_lang and path.endswith(src_lang):
            path = path[:-len(src_lang)]
        return path + trg_lang

    def translate(self):
        """Translate the input, writing the best and beam outputs.

        Without output paths the best translations go to ``sys.stdout`` and
        no beam stream is passed.  Opened files are closed even when
        ``model.translate`` raises; ``sys.stdout`` is never closed.
        """
        import contextlib

        with contextlib.ExitStack() as stack:
            if self.best_output_fp:
                best_stream = stack.enter_context(
                    open(self.best_output_fp, 'a'))
            else:
                best_stream = sys.stdout

            if self.beam_output_fp:
                beam_stream = stack.enter_context(
                    open(self.beam_output_fp, 'a'))
            else:
                beam_stream = None

            self.model.translate(self.input_file or sys.stdin,
                                 best_stream,
                                 beam_stream,
                                 to_ids=True,
                                 num_preload=self.num_preload)

    def plot_head_map(self, mma, target_labels, target_ids, source_labels,
                      source_ids, filename):
        """Save a heat map of ``mma`` with source/target tokens as labels.

        Adapted from
        https://github.com/EdinburghNLP/nematus/blob/master/utils/plot_heatmap.py

        Change the font in the ``family`` params below.  If the system font
        is not used, delete the matplotlib font cache, see
        https://github.com/matplotlib/matplotlib/issues/3590

        Args:
            mma: 2-D array; rows are labelled with ``target_labels`` and
                columns with ``source_labels``.
            target_labels: row tick labels.
            target_ids: ids matching ``target_labels``; labels whose id is
                ``ac.UNK_ID`` are drawn in blue.
            source_labels: column tick labels.
            source_ids: ids matching ``source_labels``; same colouring rule.
            filename: path the figure is saved to.
        """
        import matplotlib
        matplotlib.use('Agg')  # select the backend before importing pyplot
        import matplotlib.pyplot as plt

        _, ax = plt.subplots()
        ax.pcolor(mma, cmap=plt.cm.Blues)

        # put the major ticks at the middle of each cell
        ax.set_xticks(numpy.arange(mma.shape[1]) + 0.5, minor=False)
        ax.set_yticks(numpy.arange(mma.shape[0]) + 0.5, minor=False)

        # without this I get some extra columns rows
        # http://stackoverflow.com/questions/31601351/why-does-this-matplotlib-heatmap-have-an-extra-blank-column
        ax.set_xlim(0, int(mma.shape[1]))
        ax.set_ylim(0, int(mma.shape[0]))

        # want a more natural, table-like display
        ax.invert_yaxis()
        ax.xaxis.tick_top()

        # source words -> column labels
        ax.set_xticklabels(source_labels, minor=False,
                           family='Source Code Pro')
        for xtick, idx in zip(ax.get_xticklabels(), source_ids):
            if idx == ac.UNK_ID:
                xtick.set_color('b')

        # target words -> row labels
        ax.set_yticklabels(target_labels, minor=False,
                           family='Source Code Pro')
        for ytick, idx in zip(ax.get_yticklabels(), target_ids):
            if idx == ac.UNK_ID:
                ytick.set_color('b')

        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.savefig(filename)
        plt.close('all')
def validate_on_data(model: Model, data: Dataset,
                     batch_size: int,
                     use_cuda: bool, max_output_length: int,
                     level: str, eval_metric: Optional[str],
                     loss_function: torch.nn.Module = None,
                     beam_size: int = 0, beam_alpha: int = -1,
                     batch_type: str = "sentence"
                     ) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.array]):
    """
    Generate translations for the given data.
    If `loss_function` is not None and references are given,
    also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric (case-insensitive), one of "bleu",
        "chrf", "token_accuracy", "sequence_accuracy". If None, no score is
        computed and -1 is returned for it; any other name yields a score
        of 0.
    :param loss_function: loss function that computes a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If 0 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)

    :return:
        - current_valid_score: current validation score [eval_metric],
          -1 if there are no references or no metric was given,
        - valid_loss: validation loss, -1 if it was not computed,
        - valid_ppl:, validation perplexity, -1 if it was not computed,
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation_hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
    """
    valid_iter = make_data_iter(
        dataset=data, batch_size=batch_size, batch_type=batch_type,
        shuffle=False, train=False)
    valid_sources_raw = list(data.src)
    pad_index = model.src_vocab.stoi[PAD_TOKEN]

    # disable dropout
    model.eval()

    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = []
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0

        for valid_batch in valid_iter:
            # run as during training to get validation loss (e.g. xent)
            batch = Batch(valid_batch, pad_index, use_cuda=use_cuda)
            # sort batch now by src length and keep track of order
            sort_reverse_index = batch.sort_by_src_lengths()

            # run as during training with teacher forcing
            if loss_function is not None and batch.trg is not None:
                batch_loss = model.get_loss_for_batch(
                    batch, loss_function=loss_function)
                total_loss += batch_loss
                total_ntokens += batch.ntokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores = model.run_batch(
                batch=batch, beam_size=beam_size, beam_alpha=beam_alpha,
                max_output_length=max_output_length)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
            if attention_scores is not None:
                valid_attention_scores.extend(
                    attention_scores[sort_reverse_index])

        assert len(all_outputs) == len(data)

        if loss_function is not None and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # exponent of token-level negative log prob
            valid_ppl = torch.exp(total_loss / total_ntokens)
        else:
            valid_loss = -1
            valid_ppl = -1

        # decode back to symbols
        decoded_valid = model.trg_vocab.arrays_to_sentences(
            arrays=all_outputs, cut_at_eos=True)

        # evaluate with metric on full dataset
        join_char = " " if level in ["word", "bpe"] else ""
        valid_sources = [join_char.join(s) for s in data.src]
        valid_references = [join_char.join(t) for t in data.trg]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid]

        # post-process
        if level == "bpe":
            valid_sources = [bpe_postprocess(s) for s in valid_sources]
            valid_references = [bpe_postprocess(v)
                                for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v)
                                for v in valid_hypotheses]

        # if references are given, evaluate against them
        if valid_references:
            assert len(valid_hypotheses) == len(valid_references)

            # eval_metric is Optional: only lower-case it when it is set
            metric = eval_metric.lower() if eval_metric is not None else None
            if metric is None:
                current_valid_score = -1
            elif metric == 'bleu':
                # this version does not use any tokenization
                current_valid_score = bleu(valid_hypotheses, valid_references)
            elif metric == 'chrf':
                current_valid_score = chrf(valid_hypotheses, valid_references)
            elif metric == 'token_accuracy':
                current_valid_score = token_accuracy(
                    valid_hypotheses, valid_references, level=level)
            elif metric == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)
            else:
                # unrecognised metric names keep the default score of 0
                current_valid_score = 0
        else:
            current_valid_score = -1

    return current_valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores