Ejemplo n.º 1
0
def plain_factory(data_fn,
                  lm,
                  tokenize_regime,
                  batch_size,
                  device,
                  target_seq_len,
                  corruptor_config=None):
    """Build an on-demand training data provider from a text file.

    The pipeline, in order:

    - Tokenize the file against ``lm.vocab`` without shuffling.
    - Wrap the ids in a ``CleanStreamsProvider``.
    - Optionally wrap that provider via ``corruptor_factory``.
    - Batch lazily and split into chunks of ``target_seq_len`` targets.
    - Transpose and serve through an ``OndemandDataProvider`` on ``device``.

    Args:
        data_fn: Path of the text file to read.
        lm: Language-model wrapper; ``lm.vocab`` and ``lm.model.in_len``
            are read.
        tokenize_regime: Regime forwarded to ``tokens_from_fn``
            (e.g. ``'words'``).
        batch_size: Number of parallel streams per batch.
        device: Device forwarded to ``OndemandDataProvider``.
        target_seq_len: Number of targets per temporal chunk.
        corruptor_config: Optional corruptor configuration. When truthy, the
            clean stream provider is wrapped by ``corruptor_factory``.

    Returns:
        Tuple ``(data_provider, nb_batches)`` with
        ``nb_batches == len(train_ids) // batch_size``.
        NOTE(review): that value looks like tokens per stream rather than
        the number of ``target_seq_len`` chunks. Confirm what callers
        expect.

    Raises:
        NotImplementedError: If ``lm.model.in_len != 1``.
    """
    # Fail fast: the pipeline below only handles in_len == 1, so reject
    # other models before paying for tokenization and corruptor setup.
    if lm.model.in_len != 1:
        raise NotImplementedError(
            "Current data pipeline only supports `in_len==1`.")

    train_ids = tokens_from_fn(data_fn,
                               lm.vocab,
                               randomize=False,
                               regime=tokenize_regime)
    nb_batches = len(train_ids) // batch_size

    streams_provider = CleanStreamsProvider(train_ids)
    if corruptor_config:
        streams_provider = corruptor_factory(corruptor_config, lm,
                                             streams_provider)

    batch_former = LazyBatcher(batch_size, streams_provider)
    train_data = TransposeWrapper(
        TemplSplitterClean(target_seq_len, batch_former))
    return OndemandDataProvider(train_data, device), nb_batches
Ejemplo n.º 2
0
    def __init__(self, lm, data_fn, batch_size, target_seq_len, logger=None, tokenize_regime='words'):
        """Tokenize ``data_fn`` and prepare batched, temporally split evaluation data.

        Args:
            lm: Language-model wrapper; ``lm.vocab``, ``lm.device`` and
                ``lm.model.in_len`` are read.
            data_fn: Path of the text file to evaluate on.
            batch_size: Number of parallel streams handed to ``batchify``.
            target_seq_len: Number of targets per temporal split.
            logger: Logger to report to. Defaults to the
                ``'EnblockEvaluator'`` logger when not given.
            tokenize_regime: Regime forwarded to ``tokens_from_fn``.

        Raises:
            ValueError: If ``data_fn`` yields no tokens (the OOV ratio is
                undefined for empty data).
        """
        self.logger = logger if logger else logging.getLogger('EnblockEvaluator')
        self.batch_size = batch_size
        self.lm = lm

        ids = tokens_from_fn(data_fn, lm.vocab, regime=tokenize_regime, randomize=False)
        nb_tokens = len(ids)
        if nb_tokens == 0:
            raise ValueError('No tokens read from {}'.format(data_fn))

        # Report how many tokens map to the unknown-word index. Escalate to a
        # warning when more than 5 % of the data is OOV.
        nb_oovs = (ids == lm.vocab.unk_ind).sum().item()
        oov_msg = 'Nb oovs: {} / {} ({:.2f} %)\n'.format(nb_oovs, nb_tokens, 100.0 * nb_oovs/nb_tokens)
        if nb_oovs / nb_tokens > 0.05:
            self.logger.warning(oov_msg)
        else:
            self.logger.info(oov_msg)

        # NOTE(review): the CUDA flag passed to batchify is True only for
        # exactly cuda:0. A model on any other GPU gets False. Confirm this
        # is intended.
        batched = batchify(ids, batch_size, lm.device == torch.device('cuda:0'))
        data_tb = TemporalSplits(
            batched,
            nb_inputs_necessary=lm.model.in_len,
            nb_targets_parallel=target_seq_len
        )
        self.data = TransposeWrapper(data_tb)
Ejemplo n.º 3
0
    def __init__(self, lm, data_fn, batch_size, target_seq_len, corruptor, nb_rounds, logger=None, tokenize_regime='words'):
        """Tokenize ``data_fn`` and prepare corrupted, batched evaluation data.

        Args:
            lm: Language-model wrapper; ``lm.vocab`` is read.
            data_fn: Path of the text file to evaluate on.
            batch_size: Number of parallel streams formed by ``LazyBatcher``.
            target_seq_len: Number of targets per temporal split.
            corruptor: Callable applied to the output of
                ``form_input_targets(ids)``. Its result is batched by
                ``LazyBatcher``.
            nb_rounds: Stored as ``self.nb_rounds``.
            logger: Logger to report to. Defaults to the
                ``'SubstitutionalEnblockEvaluator_v2'`` logger when not given.
            tokenize_regime: Regime forwarded to ``tokens_from_fn``.

        Raises:
            ValueError: If ``data_fn`` yields no tokens (the OOV ratio is
                undefined for empty data).
        """
        self.logger = logger if logger else logging.getLogger('SubstitutionalEnblockEvaluator_v2')
        self.batch_size = batch_size
        self.lm = lm
        self.nb_rounds = nb_rounds

        ids = tokens_from_fn(data_fn, lm.vocab, regime=tokenize_regime, randomize=False)
        nb_tokens = len(ids)
        if nb_tokens == 0:
            raise ValueError('No tokens read from {}'.format(data_fn))

        # Report how many tokens map to the unknown-word index. Escalate to a
        # warning when more than 5 % of the data is OOV.
        nb_oovs = (ids == lm.vocab.unk_ind).sum().item()
        oov_msg = 'Nb oovs: {} / {} ({:.2f} %)\n'.format(nb_oovs, nb_tokens, 100.0 * nb_oovs/nb_tokens)
        if nb_oovs / nb_tokens > 0.05:
            self.logger.warning(oov_msg)
        else:
            self.logger.info(oov_msg)

        # Build input/target streams and let the caller-supplied corruptor
        # wrap them. Then batch lazily and cut into target_seq_len chunks.
        streams = form_input_targets(ids)
        corrupted_provider = corruptor(streams)
        batch_former = LazyBatcher(batch_size, corrupted_provider)
        data_tb = TemplSplitterClean(target_seq_len, batch_former)

        self.data = CudaStream(TransposeWrapper(data_tb))
Ejemplo n.º 4
0
def main(args):
    """Print the leading input/target pairs of a corrupted token stream.

    The data comes from ``args.data``, tokenized with the vocabulary of the
    LM loaded from ``args.load``. It is corrupted in one of two ways:

    - With ``args.statistics``: confusion statistics via
      ``StatisticsCorruptor``.
    - Otherwise: the rates ``args.subs_rate``, ``args.del_rate`` and
      ``args.ins_rate`` via ``Corruptor``.

    ``'</s>'`` is passed as protected in both cases.

    Each printed line is ``<input word> <target word>``. With ``args.color``,
    an input that differs from the previous target is wrapped in
    ``RED_MARK``/``END_MARK``.

    At most ``args.nb_tokens`` pairs are printed, and never more than the
    corrupted stream contains. An oversized request therefore does not raise
    IndexError.
    """
    logging.basicConfig(level=logging.INFO,
                        format='[%(levelname)s::%(name)s] %(message)s')

    # The model is only needed for its vocabulary, so keep it on the CPU.
    lm = torch.load(args.load, map_location='cpu')

    tokenize_regime = 'words'
    train_ids = tokens_from_fn(args.data,
                               lm.vocab,
                               randomize=False,
                               regime=tokenize_regime)
    train_streams = form_input_targets(train_ids)
    if args.statistics:
        # SECURITY NOTE: pickle.load can execute arbitrary code. Only pass
        # statistics files from a trusted source.
        with open(args.statistics, 'rb') as f:
            summary = pickle.load(f)
        confuser = Confuser(summary.confusions, lm.vocab, mincount=5)
        corrupted_provider = StatisticsCorruptor(train_streams,
                                                 confuser,
                                                 args.ins_rate,
                                                 protected=[lm.vocab['</s>']])
    else:
        corrupted_provider = Corruptor(train_streams,
                                       args.subs_rate,
                                       len(lm.vocab),
                                       args.del_rate,
                                       args.ins_rate,
                                       protected=[lm.vocab['</s>']])

    inputs, targets = corrupted_provider.provide()

    # Never index past the end of the corrupted stream.
    nb_shown = min(args.nb_tokens, len(inputs), len(targets))
    for i in range(nb_shown):
        in_word = lm.vocab.i2w(inputs[i].item())
        target_word = lm.vocab.i2w(targets[i].item())

        # An input differing from the previous target is highlighted as an
        # error (presumably inputs[i] == targets[i - 1] in a clean stream).
        is_error = i > 0 and inputs[i] != targets[i - 1]
        if args.color and is_error:
            sys.stdout.write(f'{RED_MARK}{in_word}{END_MARK} {target_word}\n')
        else:
            sys.stdout.write(f'{in_word} {target_word}\n')
Ejemplo n.º 5
0
    # Seed the random generators from args.seed (args.cuda is forwarded too).
    init_seeds(args.seed, args.cuda)

    print("loading model...")
    # Load the pickled LM object and move it to the GPU when args.cuda is set.
    lm = torch.load(args.load)
    if args.cuda:
        lm.cuda()
    print(lm.model)

    print("preparing data...")
    # Word-level tokenization by default; args.characters switches to chars.
    tokenize_regime = 'words'
    if args.characters:
        tokenize_regime = 'chars'

    # Training data: tokenize without shuffling, batchify into
    # args.batch_size streams, then split temporally using the model's
    # in_len and args.target_seq_len.
    train_ids = tokens_from_fn(args.train,
                               lm.vocab,
                               randomize=False,
                               regime=tokenize_regime)
    train_batched = batchify(train_ids, args.batch_size, args.cuda)
    train_data_tb = TemporalSplits(train_batched,
                                   nb_inputs_necessary=lm.model.in_len,
                                   nb_targets_parallel=args.target_seq_len)
    train_data = TransposeWrapper(train_data_tb)

    # Validation data: same pipeline, but with a fixed batch size of 10.
    valid_ids = tokens_from_fn(args.valid,
                               lm.vocab,
                               randomize=False,
                               regime=tokenize_regime)
    valid_batched = batchify(valid_ids, 10, args.cuda)
    valid_data_tb = TemporalSplits(valid_batched,
                                   nb_inputs_necessary=lm.model.in_len,
                                   nb_targets_parallel=args.target_seq_len)
Ejemplo n.º 6
0
def main(args):
    """Train an LM with SGD on input-corrupted text.

    The script:

    - Loads a pickled LM from ``args.load`` and sets the decoder loss's
      label-smoothing amount to ``args.label_smoothing``.
    - Trains on ``args.train`` corrupted by ``InputTargetCorruptor``.
    - Validates on ``args.valid`` after every epoch.
    - Saves the model to ``args.save`` whenever the validation loss improves.
    - Attempts to halve the learning rate after more than ``args.patience``
      epochs without improvement (see the review note at the halving).
    """
    print(args)
    logging.basicConfig(level=logging.INFO, format='[%(levelname)s::%(name)s] %(message)s')

    init_seeds(args.seed, args.cuda)

    print("loading model...")
    lm = torch.load(args.load)
    if args.cuda:
        lm.cuda()
    # Label-smoothing strength for the decoder's core loss.
    lm.decoder.core_loss.amount = args.label_smoothing

    print(lm.model)
    print('Label smoothing power', lm.decoder.core_loss.amount)

    tokenize_regime = 'words'

    print("preparing training data...")
    # Corrupt inputs/targets with the substitution, target-substitution,
    # deletion and insertion rates from args; '</s>' is passed as protected.
    train_ids = tokens_from_fn(args.train, lm.vocab, randomize=False, regime=tokenize_regime)
    train_streams = form_input_targets(train_ids)
    corrupted_provider = InputTargetCorruptor(train_streams, args.subs_rate, args.target_subs_rate, len(lm.vocab), args.del_rate, args.ins_rate, protected=[lm.vocab['</s>']])
    batch_former = LazyBatcher(args.batch_size, corrupted_provider)
    train_data = TemplSplitterClean(args.target_seq_len, batch_former)
    train_data_stream = OndemandDataProvider(TransposeWrapper(train_data), args.cuda)

    print("preparing validation data...")
    # Validation runs on clean data with a fixed batch size of 10.
    evaluator = EnblockEvaluator(lm, args.valid, 10, args.target_seq_len)
    # Evaluation (de facto LR scheduling) with input corruption did not
    # help during the CHiMe-6 evaluation
    # evaluator = SubstitutionalEnblockEvaluator(
    #     lm, args.valid,
    #     batch_size=10, target_seq_len=args.target_seq_len,
    #     corruptor=lambda data: Corruptor(data, args.corruption_rate, len(lm.vocab)),
    #     nb_rounds=args.eval_rounds,
    # )

    def val_loss_fn():
        """Return the per-token validation loss from the evaluator."""
        return evaluator.evaluate().loss_per_token

    print("computing initial PPL...")
    initial_val_loss = val_loss_fn()
    print('Initial perplexity {:.2f}'.format(math.exp(initial_val_loss)))

    print("training...")
    lr = args.lr
    # NOTE(review): initial_val_loss is not used as the baseline, so the
    # model after epoch 1 is always saved even if it got worse.
    best_val_loss = None

    # Receives val_loss_fn, the initial loss, args.val_interval, args.workdir
    # and lm; presumably validates/checkpoints within epochs (confirm).
    val_watcher = ValidationWatcher(val_loss_fn, initial_val_loss, args.val_interval, args.workdir, lm)

    optim = torch.optim.SGD(lm.parameters(), lr, weight_decay=args.beta)
    for epoch in range(1, args.epochs + 1):
        # NOTE(review): list(train_data) iterates the whole batching pipeline
        # once per epoch just to count its items. The result is then divided
        # by target_seq_len again, which looks like it may undercount.
        # Confirm both.
        logger = ProgressLogger(epoch, args.log_interval, lr, len(list(train_data)) // args.target_seq_len)

        hidden = None
        for X, targets in train_data_stream:
            # Hidden state is created lazily on the first batch of the epoch.
            if hidden is None:
                hidden = lm.model.init_hidden(args.batch_size)

            # Presumably detaches the hidden state from the previous step's
            # graph (truncated BPTT).
            hidden = repackage_hidden(hidden)

            # Re-enter train mode every step, presumably because validation
            # inside val_watcher may switch the model to eval mode.
            lm.train()
            output, hidden = lm.model(X, hidden)
            loss, nb_words = lm.decoder.neg_log_prob(output, targets)
            # Normalize to a per-token loss.
            loss /= nb_words

            val_watcher.log_training_update(loss.data, nb_words)

            optim.zero_grad()
            loss.backward()
            # NOTE(review): torch.nn.utils.clip_grad_norm is deprecated in
            # favour of the in-place torch.nn.utils.clip_grad_norm_.
            torch.nn.utils.clip_grad_norm(lm.parameters(), args.clip)

            optim.step()
            logger.log(loss.data)

        val_loss = val_loss_fn()
        print(epoch_summary(epoch, logger.nb_updates(), logger.time_since_creation(), val_loss))

        # Save the model if the validation loss is the best we've seen so far.
        # NOTE(review): `not best_val_loss` also treats a loss of 0.0 as "no
        # best yet"; `best_val_loss is None` would be stricter. The first
        # epoch always takes this branch, which is what binds patience_ticks
        # before the else branch can use it.
        if not best_val_loss or val_loss < best_val_loss:
            torch.save(lm, args.save)
            best_val_loss = val_loss
            patience_ticks = 0
        else:
            patience_ticks += 1
            if patience_ticks > args.patience:
                # NOTE(review): this only rebinds the local `lr`. The SGD
                # optimizer keeps the rate it was constructed with, so the
                # halving reaches ProgressLogger but not the actual updates.
                # To decay for real, also set g['lr'] = lr for every g in
                # optim.param_groups.
                lr /= 2.0
                patience_ticks = 0