Example #1
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X_shards,
                         segments_X_shards, valid_lens_x_shards,
                         pred_positions_X_shards, mlm_weights_X_shards,
                         mlm_Y_shards, nsp_y_shards):
    mlm_ls, nsp_ls, ls = [], [], []
    for (tokens_X_shard, segments_X_shard, valid_lens_x_shard,
         pred_positions_X_shard, mlm_weights_X_shard, mlm_Y_shard,
         nsp_y_shard) in zip(tokens_X_shards, segments_X_shards,
                             valid_lens_x_shards, pred_positions_X_shards,
                             mlm_weights_X_shards, mlm_Y_shards, nsp_y_shards):
        # Forward pass
        _, mlm_Y_hat, nsp_Y_hat = net(tokens_X_shard, segments_X_shard,
                                      valid_lens_x_shard.reshape(-1),
                                      pred_positions_X_shard)
        # Compute masked language model loss
        mlm_l = loss(mlm_Y_hat.reshape((-1, vocab_size)),
                     mlm_Y_shard.reshape(-1),
                     mlm_weights_X_shard.reshape((-1, 1)))
        mlm_l = mlm_l.sum() / (mlm_weights_X_shard.sum() + 1e-8)
        # Compute next sentence prediction loss
        nsp_l = loss(nsp_Y_hat, nsp_y_shard)
        nsp_l = nsp_l.mean()
        mlm_ls.append(mlm_l)
        nsp_ls.append(nsp_l)
        ls.append(mlm_l + nsp_l)
        npx.waitall()
    return mlm_ls, nsp_ls, ls
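For context, a minimal sketch of how a loss helper like this could be driven in a multi-device training step. The loss object, the devices list, and the use of gluon.utils.split_and_load are illustrative assumptions, not taken from the example above:

from mxnet import autograd, gluon, npx

loss = gluon.loss.SoftmaxCrossEntropyLoss()
devices = [npx.gpu(i) for i in range(npx.num_gpus())] or [npx.cpu()]

def train_step(net, trainer, vocab_size, batch):
    # Split every field of the batch across the available devices
    shards = [gluon.utils.split_and_load(x, devices, even_split=False)
              for x in batch]
    with autograd.record():
        mlm_ls, nsp_ls, ls = _get_batch_loss_bert(net, loss, vocab_size,
                                                   *shards)
    for l in ls:         # one backward call per device shard
        l.backward()
    trainer.step(1)
    npx.waitall()        # block until all asynchronous work has finished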
Example #2
def train(X, contents_Y, styles_Y, ctx, lr, num_epochs, lr_decay_epoch):
    X, styles_Y_gram, trainer = get_inits(X, ctx, lr, styles_Y)
    for epoch in range(num_epochs):
        with autograd.record():
            contents_Y_hat, styles_Y_hat = extract_features(
                X, CONTENT_LAYERS, STYLE_LAYERS)
            contents_l, styles_l, tv_l, l = compute_loss(
                X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
        l.backward()
        trainer.step(1)
        npx.waitall()
        if epoch % lr_decay_epoch == 0:
            trainer.set_learning_rate(trainer.learning_rate * 0.3)
        if epoch % 100 == 0:
            msg = [
                f"Epoch: {epoch}",
                f"contents_l: {float(sum(contents_l)):0.3f}",
                f"style_l: {float(sum(styles_l)):0.3f}",
                f"tv_l: {float(tv_l):0.3f}", f"total_l: {float(l):0.3f}"
            ]
            msg = ", ".join(msg)
            print(msg)
            plt.imshow(postprocess(X).asnumpy())
            plt.show()
    return X
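MXNet executes operators asynchronously, so trainer.step returns before the update has actually run on the device; the npx.waitall() call above forces synchronization so that the per-epoch printout and any timing reflect completed work. A minimal, self-contained illustration of the same effect (array sizes are arbitrary):

import time
from mxnet import np, npx

a = np.random.uniform(size=(2000, 2000))
start = time.time()
b = np.dot(a, a)                 # only enqueues the operation
print(f'enqueue: {time.time() - start:.4f} sec')
npx.waitall()                    # blocks until the dot product has finished
print(f'finish:  {time.time() - start:.4f} sec')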
Example #3
 def train(self):
     self.net.collect_params().reset_ctx(self.mx_ctx)
     content_x, contents_y = self.get_contents()
     _, styles_y = self.get_styles()
     x, styles_y_gram, trainer = self.get_inits(content_x, styles_y)
     styles_y_gram = [StyleTransferGF.gram(Y) for Y in styles_y]
     for epoch in range(self.N_EPOCHS):
         with autograd.record():
             contents_y_hat, styles_y_hat = self.extract_features(x)
             contents_l, styles_l, tv_l, l = self.compute_loss(
                 x, contents_y_hat, styles_y_hat, contents_y, styles_y_gram)
         l.backward()
         trainer.step(1)
         npx.waitall()
         if epoch % self.LR_DECAY_EPOCH == 0:
             trainer.set_learning_rate(trainer.learning_rate * 0.3)
         if epoch % 100 == 0:
             msg = [
                 f"Size: {self.IMAGE_SIZE}", f"Epoch: {epoch}",
                 f"contents_l: {float(sum(contents_l)):0.3f}",
                 f"style_l: {float(sum(styles_l)):0.3f}",
                 f"tv_l: {float(tv_l):0.3f}", f"total_l: {float(l):0.3f}"
             ]
             msg = ", ".join(msg)
             print(msg)
             # plt.imshow(self.postprocess(x).asnumpy())
             # plt.show()
     out = self.postprocess(x).asnumpy()
     out = (out * 255).astype(numpy.uint8)
     if self.out_image_filepath is not None:
         cv.imwrite(self.out_image_filepath,
                    cv.cvtColor(out, cv.COLOR_RGB2BGR))
     return out
Example #4
 def batch_check(x, modes, params):
     state = np.random.normal(0, 1, (1, BAT, L_STA))
     for m, p in zip(modes, params):
         x.attach_grad()
         with mx.autograd.record():
             y = npx.rnn(data=x, parameters=p, mode=m, \
                 state=state, state_size=L_STA, num_layers=1)
         assert y.shape == (L_SEQ, BAT, L_STA)
         y.backward()
         npx.waitall()
Example #5
def get_conv_data_mxnet(oc, ic, n, k, p, s):
    mpx.random.seed(0)
    data = mp.random.normal(size=(1, ic, n, n))
    weight = mp.random.normal(size=(oc, ic, k, k))
    bias = mp.zeros((oc, ))
    on = conv_out_size(n, k, p, s)
    out = mp.empty((1, oc, on, on))
    # Wait until the data are generated so that later benchmarking is accurate
    mpx.waitall()
    return data, weight, bias, out
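conv_out_size is not shown in this snippet; a plausible definition, assuming the usual floor formula for the output size of a convolution with padding p and stride s, is:

def conv_out_size(n, k, p, s):
    """Output width/height of a convolution with input size n, kernel k,
    padding p and stride s (standard floor formula)."""
    return (n - k + 2 * p) // s + 1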
Example #6
def test_rnn_gru():
    L_SEQ, BAT, L_INP, L_STA = 2**20, 4, 2**10, 2
    data = np.random.uniform(-1, 1, (L_SEQ, BAT, L_INP))
    state = np.random.normal(0, 1, (1, BAT, L_STA))
    params = np.random.normal(0, 1, (6168, ))
    data.attach_grad()
    with mx.autograd.record():
        out = npx.rnn(data=data, parameters=params, mode='gru', \
            state=state, state_size=L_STA, num_layers=1)
    assert out.shape == (L_SEQ, BAT, L_STA)
    out.backward()
    npx.waitall()
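The parameter count 6168 is not arbitrary: assuming the fused single-layer GRU layout, the flat parameter vector holds input-to-hidden weights (3·H·I), hidden-to-hidden weights (3·H·H), and two bias vectors (each 3·H). A quick sanity check with the sizes used above:

def fused_gru_param_count(input_size, state_size):
    # 3 gates: W_x is (3*H, I), W_h is (3*H, H), biases b_x and b_h are (3*H,)
    h, i = state_size, input_size
    return 3 * h * i + 3 * h * h + 2 * 3 * h

assert fused_gru_param_count(2**10, 2) == 6168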
Example #7
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
    X, trainer = get_inits(X, device, lr)
    for epoch in range(1, num_epochs + 1):
        with autograd.record():
            contents_Y_hat, styles_Y_hat = get_features(
                X, content_layers, style_layers)
            contents_l, styles_l, l = compute_loss(X, contents_Y_hat,
                                                   styles_Y_hat, contents_Y,
                                                   styles_Y)
        l.backward()
        trainer.step(1)
        npx.waitall()
        if epoch % lr_decay_epoch == 0:
            trainer.set_learning_rate(trainer.learning_rate * 0.1)
        if epoch % 100 == 0:
            print('Total loss: ', l.item())
            print('Iteration: ', epoch)
            plt.imshow(postprocess(X).asnumpy())
            plt.axis("off")
            plt.show()
    return X
Example #8
 def _create_checkpoint(self, checkpoint_decoder: CheckpointDecoder,
                        time_cost: float,
                        train_iter: data_io.BaseParallelSampleIter,
                        validation_iter: data_io.BaseParallelSampleIter):
     """
     Creates a checkpoint, which will update self.state.converged/self.state.diverged, evaluate validation
     metrics and update the best known parameters accordingly.
     """
     self.state.checkpoint += 1
     # save parameters and evaluate on validation data
     self._save_params()
     train_metrics = [lf.metric for lf in self.loss_functions]
     logger.info(
         "Checkpoint [%d]\tUpdates=%d Epoch=%d Samples=%d Time-cost=%.3f Updates/sec=%.3f",
         self.state.checkpoint, self.state.updates, self.state.epoch,
         self.state.samples, time_cost,
         self.config.checkpoint_interval / time_cost)
     logger.info(
         'Checkpoint [%d]\t%s', self.state.checkpoint,
         "\t".join("Train-%s" % str(metric) for metric in train_metrics))
     val_metrics = self._evaluate(self.state.checkpoint, validation_iter,
                                  checkpoint_decoder)
     npx.waitall()
     has_improved = self._determine_improvement(val_metrics)
     self.state.converged = self._determine_convergence()
     self.state.diverged = self._determine_divergence(val_metrics)
     self._adjust_learning_rate(has_improved)
     if has_improved:
         self._update_best_params()
         self._save_trainer_states(self.best_optimizer_states_fname)
         self._save_lr_scheduler(self.best_lr_scheduler_fname)
     self._write_and_log_metrics(train_metrics=train_metrics,
                                 val_metrics=val_metrics)
     for metric in train_metrics:
         metric.reset()
     self._save_training_state(train_iter)
     if self.checkpoint_callback:
         self.checkpoint_callback(self.state.checkpoint)
Example #9
    def __init__(self,
                 file_patterns,
                 file_sampler,
                 dataset_fn=None,
                 batch_sampler_fn=None,
                 dataset_params=None,
                 batch_sampler_params=None,
                 batchify_fn=None,
                 num_dataset_workers=0,
                 num_batch_workers=0,
                 pin_memory=False,
                 circle_length=1,
                 dataset_prefetch=None,
                 batch_prefetch=None,
                 dataset_cached=False,
                 num_max_dataset_cached=0):
        assert num_dataset_workers >= 0, \
            'num_dataset_workers must be non-negative'
        assert num_batch_workers >= 0, \
            'num_batch_workers must be non-negative'
        if num_batch_workers > 0:
            assert num_dataset_workers > 0, \
                'num_dataset_workers must be positive when num_batch_workers > 0'
        else:
            if num_dataset_workers > 0:
                warnings.warn(
                    'The multi-processing functionalities for both dataset and'
                    ' batch sampling are disabled when num_batch_workers=0 though '
                    'num_dataset_workers={} > 0'.format(num_dataset_workers))
        assert circle_length >= 1, \
            'circle_length must be larger than or equal to 1'
        if dataset_cached:
            assert num_max_dataset_cached > 0, \
                'When dataset_cached is True, num_max_dataset_cached must be positive'

        self._dataset = _PathDataset(file_patterns)
        self._file_sampler = file_sampler

        assert dataset_fn is not None, 'dataset_fn is not given.'
        assert batch_sampler_fn is not None, 'batch_sampler_fn is not given.'
        if dataset_params is not None:
            self._dataset_fn = partial(dataset_fn, **dataset_params)
        else:
            self._dataset_fn = dataset_fn
        if batch_sampler_params is not None:
            self._batch_sampler_fn = partial(batch_sampler_fn,
                                             **batch_sampler_params)
        else:
            self._batch_sampler_fn = batch_sampler_fn

        self._num_dataset_workers = num_dataset_workers
        self._num_batch_workers = num_batch_workers
        self._dataset_prefetch = max(
            0,
            int(dataset_prefetch)
            if dataset_prefetch is not None else self._num_dataset_workers)
        self._batch_prefetch = max(
            0,
            int(batch_prefetch) if batch_prefetch is not None else 2 *
            self._num_batch_workers)

        self._pin_memory = pin_memory
        self._circle_length = circle_length
        self._dataset_cached = dataset_cached
        self._num_max_dataset_cached = num_max_dataset_cached

        self._manager = None
        self._dataset_worker_pool = None
        if self._num_dataset_workers > 0:
            npx.waitall()
            import gc
            gc.collect()
            npx.waitall()
            self._manager = multiprocessing.Manager()
            self._dataset_worker_pool = multiprocessing.Pool(
                self._num_dataset_workers,
                initializer=_initialize_dataset_worker,
                initargs=[self._manager])
        self._batch_worker_pool = None
        if self._num_batch_workers > 0:
            npx.waitall()
            import gc
            gc.collect()
            npx.waitall()
            self._batch_worker_pool = multiprocessing.Pool(
                self._num_batch_workers)
        if batchify_fn is None:
            if self._num_batch_workers > 0:
                self._batchify_fn = default_mp_batchify_fn
            else:
                self._batchify_fn = default_batchify_fn
        else:
            self._batchify_fn = batchify_fn
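The npx.waitall() / gc.collect() / npx.waitall() sequence before creating each worker pool is worth noting: it drains MXNet's asynchronous engine and releases garbage-collected arrays before the worker processes are forked. A hypothetical standalone helper capturing the same pattern:

import gc
import multiprocessing
from mxnet import npx

def make_worker_pool(num_workers):
    """Fork a pool only after pending MXNet work and frees have completed."""
    npx.waitall()   # flush operations still queued in the async engine
    gc.collect()    # drop arrays kept alive only by reference cycles
    npx.waitall()   # wait for any frees triggered by the collection
    return multiprocessing.Pool(num_workers)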
Example #10
def run_train_translate(train_params: str,
                        translate_params: str,
                        data: Dict[str, Any],
                        use_prepared_data: bool = False,
                        max_seq_len: int = 10,
                        seed: int = 13,
                        use_pytorch: bool = False) -> Dict[str, Any]:
    """
    Train a model and translate a test set. Returns the updated data dictionary containing paths to translation outputs
    and scores.

    :param train_params: Command line args for model training.
    :param translate_params: First command line args for translation.
    :param data: Dictionary containing test data.
    :param use_prepared_data: Whether to use the prepared data functionality.
    :param max_seq_len: The maximum sequence length.
    :param seed: The seed used for training.
    :param use_pytorch: Whether to use PyTorch.
    :return: Data dictionary, updated with translation outputs and scores
    """
    if use_pytorch:
        import sockeye.prepare_data_pt
        import sockeye.train_pt
        import sockeye.translate_pt
        prepare_data_mod = sockeye.prepare_data_pt
        train_mod = sockeye.train_pt
        translate_mod = sockeye.translate_pt
    else:
        import sockeye.prepare_data
        import sockeye.train
        import sockeye.translate
        prepare_data_mod = sockeye.prepare_data
        train_mod = sockeye.train
        translate_mod = sockeye.translate

    work_dir = os.path.join(data['work_dir'], 'train_translate')
    data['model'] = os.path.join(work_dir, "model")
    # Optionally create prepared data directory
    if use_prepared_data:
        data['train_prepared'] = os.path.join(work_dir, "prepared_data")
        prepare_params = "{} {}".format(
            prepare_data_mod.__file__,
            PREPARE_DATA_COMMON.format(train_source=data['train_source'],
                                       train_target=data['train_target'],
                                       output=data['train_prepared'],
                                       max_len=max_seq_len))
        if 'train_source_factors' in data:
            prepare_params += TRAIN_WITH_SOURCE_FACTORS_COMMON.format(
                source_factors=" ".join(data['train_source_factors']))
        if 'train_target_factors' in data:
            prepare_params += TRAIN_WITH_TARGET_FACTORS_COMMON.format(
                target_factors=" ".join(data['train_target_factors']))

        if '--weight-tying-type src_trg' in train_params:
            prepare_params += ' --shared-vocab'

        logger.info("Preparing data with parameters %s.", prepare_params)
        with patch.object(sys, "argv", prepare_params.split()):
            prepare_data_mod.main()
        # Train model
        params = "{} {} {}".format(
            train_mod.__file__,
            TRAIN_PARAMS_PREPARED_DATA_COMMON.format(
                prepared_data=data['train_prepared'],
                dev_source=data['dev_source'],
                dev_target=data['dev_target'],
                model=data['model'],
                max_len=max_seq_len), train_params)

        if 'dev_source_factors' in data:
            params += DEV_WITH_SOURCE_FACTORS_COMMON.format(
                dev_source_factors=" ".join(data['dev_source_factors']))
        if 'dev_target_factors' in data:
            params += DEV_WITH_TARGET_FACTORS_COMMON.format(
                dev_target_factors=" ".join(data['dev_target_factors']))

        logger.info("Starting training with parameters %s.", train_params)
        with patch.object(sys, "argv", params.split()):
            train_mod.main()
    else:
        # Train model
        params = "{} {} {}".format(
            train_mod.__file__,
            TRAIN_PARAMS_COMMON.format(train_source=data['train_source'],
                                       train_target=data['train_target'],
                                       dev_source=data['dev_source'],
                                       dev_target=data['dev_target'],
                                       model=data['model'],
                                       max_len=max_seq_len,
                                       seed=seed), train_params)

        if 'train_source_factors' in data:
            params += TRAIN_WITH_SOURCE_FACTORS_COMMON.format(
                source_factors=" ".join(data['train_source_factors']))
        if 'train_target_factors' in data:
            params += TRAIN_WITH_TARGET_FACTORS_COMMON.format(
                target_factors=" ".join(data['train_target_factors']))
        if 'dev_source_factors' in data:
            params += DEV_WITH_SOURCE_FACTORS_COMMON.format(
                dev_source_factors=" ".join(data['dev_source_factors']))
        if 'dev_target_factors' in data:
            params += DEV_WITH_TARGET_FACTORS_COMMON.format(
                dev_target_factors=" ".join(data['dev_target_factors']))

        logger.info("Starting training with parameters %s.", train_params)
        with patch.object(sys, "argv", params.split()):
            train_mod.main()

    # create Top-K lexicon from simple ttable mapping digit to digit
    ttable_path = os.path.join(data['work_dir'], "ttable")
    generate_fast_align_lex(ttable_path)
    lexicon_path = os.path.join(data['work_dir'], "lexicon")
    params = "{} {}".format(
        sockeye.lexicon.__file__,
        LEXICON_CREATE_PARAMS_COMMON.format(input=ttable_path,
                                            model=data['model'],
                                            topk=20,
                                            lexicon=lexicon_path))
    with patch.object(sys, "argv", params.split()):
        sockeye.lexicon.main()
    data['lexicon'] = lexicon_path

    # Translate corpus with the 1st params and scoring output handler to obtain scores
    data['test_output'] = os.path.join(work_dir, "test.out")
    params = "{} {} {}".format(
        translate_mod.__file__,
        TRANSLATE_PARAMS_COMMON.format(model=data['model'],
                                       input=data['test_source'],
                                       output=data['test_output']),
        translate_params)

    if 'test_source_factors' in data:
        params += TRANSLATE_WITH_FACTORS_COMMON.format(
            input_factors=" ".join(data['test_source_factors']))

    # Try to fix transient errors with mxnet tests where parameter file does not yet exist
    # TODO(migration): remove once mxnet is removed
    if not use_pytorch:
        try:
            from mxnet import npx
            npx.waitall()
            import time
            time.sleep(1)
        except:
            pass

    logger.info("Translating with params %s", params)
    with patch.object(sys, "argv", params.split()):
        translate_mod.main()

    # Collect test inputs
    with open(data['test_source']) as inputs:
        data['test_inputs'] = [line.strip() for line in inputs]

    # Collect test references
    with open(data['test_target'], "r") as ref:
        data['test_targets'] = [line.strip() for line in ref]

    # Collect test translate outputs and scores
    data['test_outputs'] = collect_translate_output_and_scores(
        data['test_output'])
    assert len(data['test_inputs']) == len(data['test_targets']) == len(
        data['test_outputs'])
    return data