def plain_factory(data_fn, lm, tokenize_regime, batch_size, device, target_seq_len, corruptor_config=None):
    """Build the training data pipeline for `lm` from the file `data_fn`.

    Tokenizes the data, optionally wraps the token streams with a
    corruptor (when `corruptor_config` is given), batches them lazily and
    splits them into `target_seq_len`-long training examples.

    Returns:
        (OndemandDataProvider over `device`, number of batches)

    Raises:
        NotImplementedError: if the model consumes more than one input
        token per step (`lm.model.in_len != 1`).
    """
    train_ids = tokens_from_fn(data_fn, lm.vocab, randomize=False, regime=tokenize_regime)
    nb_batches = len(train_ids) // batch_size

    provider = CleanStreamsProvider(train_ids)
    if corruptor_config:
        # Inject input corruption between the clean streams and the batcher.
        provider = corruptor_factory(corruptor_config, lm, provider)
    batcher = LazyBatcher(batch_size, provider)

    # Guard clause: only single-token-input models are supported here.
    if lm.model.in_len != 1:
        raise NotImplementedError("Current data pipeline only supports `in_len==1`.")
    split_data = TemplSplitterClean(target_seq_len, batcher)

    return OndemandDataProvider(TransposeWrapper(split_data), device), nb_batches
def __init__(self, lm, data_fn, batch_size, target_seq_len, logger=None, tokenize_regime='words'):
    """Prepare en-block (whole-corpus) evaluation data for `lm`.

    Tokenizes `data_fn`, reports the OOV rate (warning above 5 %), and
    batches the ids into temporally-split (input, target) pairs exposed
    as `self.data`.
    """
    self.logger = logger if logger else logging.getLogger('EnblockEvaluator')
    self.batch_size = batch_size
    self.lm = lm

    ids = tokens_from_fn(data_fn, lm.vocab, regime=tokenize_regime, randomize=False)

    # OOV accounting: compare every id against the vocabulary's <unk> index.
    nb_oovs = (ids == lm.vocab.unk_ind).sum().item()
    nb_tokens = len(ids)
    oov_msg = 'Nb oovs: {} / {} ({:.2f} %)\n'.format(nb_oovs, nb_tokens, 100.0 * nb_oovs/nb_tokens)
    # More than 5 % OOVs is suspicious enough to warn about.
    log_fn = self.logger.warning if nb_oovs / nb_tokens > 0.05 else self.logger.info
    log_fn(oov_msg)

    on_cuda = lm.device == torch.device('cuda:0')
    batched = batchify(ids, batch_size, on_cuda)
    splits = TemporalSplits(
        batched,
        nb_inputs_necessary=lm.model.in_len,
        nb_targets_parallel=target_seq_len
    )
    self.data = TransposeWrapper(splits)
def __init__(self, lm, data_fn, batch_size, target_seq_len, corruptor, nb_rounds, logger=None, tokenize_regime='words'):
    """Prepare corrupted en-block evaluation data for `lm`.

    Like the clean en-block evaluator, but passes the (input, target)
    streams through `corruptor` before batching; `nb_rounds` controls how
    many corruption rounds the evaluation averages over.
    """
    self.logger = logger if logger else logging.getLogger('SubstitutionalEnblockEvaluator_v2')
    self.batch_size = batch_size
    self.lm = lm
    self.nb_rounds = nb_rounds

    ids = tokens_from_fn(data_fn, lm.vocab, regime=tokenize_regime, randomize=False)

    # OOV accounting: compare every id against the vocabulary's <unk> index.
    nb_oovs = (ids == lm.vocab.unk_ind).sum().item()
    nb_tokens = len(ids)
    oov_msg = 'Nb oovs: {} / {} ({:.2f} %)\n'.format(nb_oovs, nb_tokens, 100.0 * nb_oovs/nb_tokens)
    # More than 5 % OOVs is suspicious enough to warn about.
    log_fn = self.logger.warning if nb_oovs / nb_tokens > 0.05 else self.logger.info
    log_fn(oov_msg)

    # Pipeline: streams -> corruptor -> lazy batching -> sequence splits -> GPU stream.
    streams = form_input_targets(ids)
    corrupted = corruptor(streams)
    batcher = LazyBatcher(batch_size, corrupted)
    splits = TemplSplitterClean(target_seq_len, batcher)
    self.data = CudaStream(TransposeWrapper(splits))
def main(args):
    """Print `args.nb_tokens` (input, target) word pairs after corruption.

    Loads the LM on CPU, corrupts the training streams either with a
    statistics-driven confusion corruptor (when `args.statistics` is
    given) or with the plain rate-based corruptor, and writes one
    "input target" pair per line; with `--color`, corrupted inputs are
    highlighted.
    """
    logging.basicConfig(level=logging.INFO, format='[%(levelname)s::%(name)s] %(message)s')

    lm = torch.load(args.load, map_location='cpu')
    tokenize_regime = 'words'
    train_ids = tokens_from_fn(args.data, lm.vocab, randomize=False, regime=tokenize_regime)
    streams = form_input_targets(train_ids)

    if args.statistics:
        # Corruption driven by observed confusion statistics.
        with open(args.statistics, 'rb') as f:
            summary = pickle.load(f)
        confuser = Confuser(summary.confusions, lm.vocab, mincount=5)
        provider = StatisticsCorruptor(streams, confuser, args.ins_rate, protected=[lm.vocab['</s>']])
    else:
        # Uniform substitution/deletion/insertion corruption.
        provider = Corruptor(streams, args.subs_rate, len(lm.vocab), args.del_rate, args.ins_rate, protected=[lm.vocab['</s>']])

    inputs, targets = provider.provide()
    for i in range(args.nb_tokens):
        in_word = lm.vocab.i2w(inputs[i].item())
        target_word = lm.vocab.i2w(targets[i].item())
        # Targets are the inputs shifted by one, so a corrupted input at
        # position i disagrees with the target at position i-1.
        mismatch = i > 0 and inputs[i] != targets[i - 1]
        if args.color and mismatch:
            sys.stdout.write(f'{RED_MARK}{in_word}{END_MARK} {target_word}\n')
        else:
            sys.stdout.write(f'{in_word} {target_word}\n')
# NOTE(review): this span starts mid-function — its `def` line is outside the
# visible source. The statements below assume an argparse-style `args` object
# in scope; documented only, left byte-identical.
init_seeds(args.seed, args.cuda)

print("loading model...")
# presumably a full pickled language-model object (torch.save'd) — confirm
lm = torch.load(args.load)
if args.cuda:
    lm.cuda()
print(lm.model)

print("preparing data...")
tokenize_regime = 'words'
if args.characters:
    tokenize_regime = 'chars'

# Training pipeline: tokenize -> batchify -> temporal splits -> transpose.
train_ids = tokens_from_fn(args.train, lm.vocab, randomize=False, regime=tokenize_regime)
train_batched = batchify(train_ids, args.batch_size, args.cuda)
train_data_tb = TemporalSplits(train_batched, nb_inputs_necessary=lm.model.in_len, nb_targets_parallel=args.target_seq_len)
train_data = TransposeWrapper(train_data_tb)

# Validation uses a fixed batch size of 10, unlike training.
valid_ids = tokens_from_fn(args.valid, lm.vocab, randomize=False, regime=tokenize_regime)
valid_batched = batchify(valid_ids, 10, args.cuda)
valid_data_tb = TemporalSplits(valid_batched, nb_inputs_necessary=lm.model.in_len, nb_targets_parallel=args.target_seq_len)
def main(args):
    """Train a language model on corrupted input streams.

    Loads a pickled LM, applies label smoothing, corrupts the training
    token streams (substitutions / deletions / insertions), trains with
    plain SGD, evaluates on clean validation data after each epoch,
    saves the best model to `args.save`, and halves the learning rate
    after `args.patience` non-improving validations.
    """
    print(args)
    logging.basicConfig(level=logging.INFO, format='[%(levelname)s::%(name)s] %(message)s')
    init_seeds(args.seed, args.cuda)

    print("loading model...")
    lm = torch.load(args.load)
    if args.cuda:
        lm.cuda()
    lm.decoder.core_loss.amount = args.label_smoothing
    print(lm.model)
    print('Label smoothing power', lm.decoder.core_loss.amount)

    tokenize_regime = 'words'

    print("preparing training data...")
    train_ids = tokens_from_fn(args.train, lm.vocab, randomize=False, regime=tokenize_regime)
    train_streams = form_input_targets(train_ids)
    # '</s>' is protected so sentence boundaries survive corruption.
    corrupted_provider = InputTargetCorruptor(train_streams, args.subs_rate, args.target_subs_rate, len(lm.vocab), args.del_rate, args.ins_rate, protected=[lm.vocab['</s>']])
    batch_former = LazyBatcher(args.batch_size, corrupted_provider)
    train_data = TemplSplitterClean(args.target_seq_len, batch_former)
    train_data_stream = OndemandDataProvider(TransposeWrapper(train_data), args.cuda)

    print("preparing validation data...")
    evaluator = EnblockEvaluator(lm, args.valid, 10, args.target_seq_len)
    # Evaluation (de facto LR scheduling) with input corruption did not
    # help during the CHiMe-6 evaluation
    # evaluator = SubstitutionalEnblockEvaluator(
    #     lm, args.valid,
    #     batch_size=10, target_seq_len=args.target_seq_len,
    #     corruptor=lambda data: Corruptor(data, args.corruption_rate, len(lm.vocab)),
    #     nb_rounds=args.eval_rounds,
    # )

    def val_loss_fn():
        return evaluator.evaluate().loss_per_token

    print("computing initial PPL...")
    initial_val_loss = val_loss_fn()
    print('Initial perplexity {:.2f}'.format(math.exp(initial_val_loss)))

    print("training...")
    lr = args.lr
    best_val_loss = None
    # fix: initialize patience before the loop instead of relying on the
    # first epoch always taking the "improved" branch.
    patience_ticks = 0

    val_watcher = ValidationWatcher(val_loss_fn, initial_val_loss, args.val_interval, args.workdir, lm)
    optim = torch.optim.SGD(lm.parameters(), lr, weight_decay=args.beta)

    for epoch in range(1, args.epochs + 1):
        logger = ProgressLogger(epoch, args.log_interval, lr, len(list(train_data)) // args.target_seq_len)
        hidden = None
        for X, targets in train_data_stream:
            if hidden is None:
                hidden = lm.model.init_hidden(args.batch_size)
            # Detach the hidden state so BPTT does not reach across batches.
            hidden = repackage_hidden(hidden)

            lm.train()
            output, hidden = lm.model(X, hidden)
            loss, nb_words = lm.decoder.neg_log_prob(output, targets)
            loss /= nb_words

            val_watcher.log_training_update(loss.data, nb_words)

            optim.zero_grad()
            loss.backward()
            # fix: clip_grad_norm was deprecated and removed in torch>=2.0;
            # the in-place clip_grad_norm_ is the supported API.
            torch.nn.utils.clip_grad_norm_(lm.parameters(), args.clip)
            optim.step()

            logger.log(loss.data)

        val_loss = val_loss_fn()
        print(epoch_summary(epoch, logger.nb_updates(), logger.time_since_creation(), val_loss))

        # Save the model if the validation loss is the best we've seen so far.
        if best_val_loss is None or val_loss < best_val_loss:
            torch.save(lm, args.save)
            best_val_loss = val_loss
            patience_ticks = 0
        else:
            patience_ticks += 1
            if patience_ticks > args.patience:
                lr /= 2.0
                # fix: the original only halved the local `lr`; the optimizer
                # kept its construction-time learning rate, so the decay had
                # no effect. Propagate it to the param groups.
                for group in optim.param_groups:
                    group['lr'] = lr
                patience_ticks = 0