def _get_batch_loss_bert(net, loss, vocab_size, tokens_X_shards,
                         segments_X_shards, valid_lens_x_shards,
                         pred_positions_X_shards, mlm_weights_X_shards,
                         mlm_Y_shards, nsp_y_shards):
    mlm_ls, nsp_ls, ls = [], [], []
    for (tokens_X_shard, segments_X_shard, valid_lens_x_shard,
         pred_positions_X_shard, mlm_weights_X_shard, mlm_Y_shard,
         nsp_y_shard) in zip(tokens_X_shards, segments_X_shards,
                             valid_lens_x_shards, pred_positions_X_shards,
                             mlm_weights_X_shards, mlm_Y_shards,
                             nsp_y_shards):
        # Forward pass
        _, mlm_Y_hat, nsp_Y_hat = net(
            tokens_X_shard, segments_X_shard,
            valid_lens_x_shard.reshape(-1), pred_positions_X_shard)
        # Compute masked language model loss
        mlm_l = loss(mlm_Y_hat.reshape((-1, vocab_size)),
                     mlm_Y_shard.reshape(-1),
                     mlm_weights_X_shard.reshape((-1, 1)))
        mlm_l = mlm_l.sum() / (mlm_weights_X_shard.sum() + 1e-8)
        # Compute next sentence prediction loss
        nsp_l = loss(nsp_Y_hat, nsp_y_shard)
        nsp_l = nsp_l.mean()
        mlm_ls.append(mlm_l)
        nsp_ls.append(nsp_l)
        ls.append(mlm_l + nsp_l)
        npx.waitall()
    return mlm_ls, nsp_ls, ls
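A minimal sketch (not from the source) of how _get_batch_loss_bert is typically driven: each element of a pretraining batch is split across devices with gluon.utils.split_and_load, and the per-shard losses are backpropagated together. The names net, loss, vocab_size, devices, and trainer are assumed to be defined elsewhere.

from mxnet import autograd, gluon

def train_step(batch, net, loss, vocab_size, devices, trainer):
    # Split every batch element across the available devices.
    shards = [gluon.utils.split_and_load(elem, devices, even_split=False)
              for elem in batch]
    with autograd.record():
        mlm_ls, nsp_ls, ls = _get_batch_loss_bert(
            net, loss, vocab_size, *shards)
    for l in ls:  # one combined loss per device shard
        l.backward()
    trainer.step(1)
    return mlm_ls, nsp_ls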
def train(X, contents_Y, styles_Y, ctx, lr, num_epochs, lr_decay_epoch):
    X, styles_Y_gram, trainer = get_inits(X, ctx, lr, styles_Y)
    for epoch in range(num_epochs):
        with autograd.record():
            contents_Y_hat, styles_Y_hat = extract_features(
                X, CONTENT_LAYERS, STYLE_LAYERS)
            contents_l, styles_l, tv_l, l = compute_loss(
                X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
        l.backward()
        trainer.step(1)
        npx.waitall()
        if epoch % lr_decay_epoch == 0:
            trainer.set_learning_rate(trainer.learning_rate * 0.3)
        if epoch % 100 == 0:
            msg = [
                f"Epoch: {epoch}",
                f"contents_l: {float(sum(contents_l)):0.3f}",
                f"style_l: {float(sum(styles_l)):0.3f}",
                f"tv_l: {float(tv_l):0.3f}",
                f"total_l: {float(l):0.3f}"
            ]
            msg = ", ".join(msg)
            print(msg)
            plt.imshow(postprocess(X).asnumpy())
            plt.show()
    return X
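The loop above assumes compute_loss returns a total-variation term tv_l alongside the content and style terms. For reference, a minimal total-variation loss sketch; the exact form and weighting in the source may differ.

from mxnet import np

def tv_loss(Y_hat):
    # Penalize differences between neighboring pixels to smooth the image.
    return 0.5 * (np.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
                  np.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())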
def train(self):
    self.net.collect_params().reset_ctx(self.mx_ctx)
    content_x, contents_y = self.get_contents()
    _, styles_y = self.get_styles()
    x, styles_y_gram, trainer = self.get_inits(content_x, styles_y)
    styles_y_gram = [StyleTransferGF.gram(Y) for Y in styles_y]
    for epoch in range(self.N_EPOCHS):
        with autograd.record():
            contents_y_hat, styles_y_hat = self.extract_features(x)
            contents_l, styles_l, tv_l, l = self.compute_loss(
                x, contents_y_hat, styles_y_hat, contents_y, styles_y_gram)
        l.backward()
        trainer.step(1)
        npx.waitall()
        if epoch % self.LR_DECAY_EPOCH == 0:
            trainer.set_learning_rate(trainer.learning_rate * 0.3)
        if epoch % 100 == 0:
            msg = [
                f"Size: {self.IMAGE_SIZE}",
                f"Epoch: {epoch}",
                f"contents_l: {float(sum(contents_l)):0.3f}",
                f"style_l: {float(sum(styles_l)):0.3f}",
                f"tv_l: {float(tv_l):0.3f}",
                f"total_l: {float(l):0.3f}"
            ]
            msg = ", ".join(msg)
            print(msg)
            # plt.imshow(self.postprocess(x).asnumpy())
            # plt.show()
    out = self.postprocess(x).asnumpy()
    out = (out * 255).astype(numpy.uint8)
    if self.out_image_filepath is not None:
        cv.imwrite(self.out_image_filepath,
                   cv.cvtColor(out, cv.COLOR_RGB2BGR))
    return out
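StyleTransferGF.gram above turns the style features into Gram matrices. A minimal sketch of such a computation, normalized by channel count and feature-map size (an assumption, not taken from the source; on the class it would be a staticmethod):

from mxnet import np

def gram(X):
    # (batch, channels, h, w) -> (channels, h*w), then channel covariance.
    num_channels, n = X.shape[1], X.size // X.shape[1]
    X = X.reshape((num_channels, n))
    return np.dot(X, X.T) / (num_channels * n)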
def batch_check(x, modes, params):
    state = np.random.normal(0, 1, (1, BAT, L_STA))
    for m, p in zip(modes, params):
        x.attach_grad()
        with mx.autograd.record():
            y = npx.rnn(data=x, parameters=p, mode=m,
                        state=state, state_size=L_STA, num_layers=1)
        assert y.shape == (L_SEQ, BAT, L_STA)
        y.backward()
        npx.waitall()
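A hedged usage sketch for batch_check; the constants mirror the GRU test below and are assumptions. For a single-layer vanilla RNN the fused parameter vector holds L_INP*L_STA + L_STA*L_STA + 2*L_STA values.

L_SEQ, BAT, L_INP, L_STA = 2**20, 4, 2**10, 2
n_params = L_INP * L_STA + L_STA * L_STA + 2 * L_STA  # 2048 + 4 + 4 = 2056
data = np.random.uniform(-1, 1, (L_SEQ, BAT, L_INP))
modes = ['rnn_relu', 'rnn_tanh']
params = [np.random.normal(0, 1, (n_params,)) for _ in modes]
batch_check(data, modes, params)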
def get_conv_data_mxnet(oc, ic, n, k, p, s):
    mpx.random.seed(0)
    data = mp.random.normal(size=(1, ic, n, n))
    weight = mp.random.normal(size=(oc, ic, k, k))
    bias = mp.zeros((oc,))
    on = conv_out_size(n, k, p, s)
    out = mp.empty((1, oc, on, on))
    # Wait until the data are generated so that later benchmarking is accurate
    mpx.waitall()
    return data, weight, bias, out
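The helper above relies on conv_out_size, which is not shown. A minimal sketch using the standard convolution output-size formula floor((n - k + 2p) / s) + 1:

def conv_out_size(n, k, p, s):
    # Output width of an n-wide input with kernel k, padding p, stride s.
    return (n - k + 2 * p) // s + 1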
def test_rnn_gru():
    L_SEQ, BAT, L_INP, L_STA = 2**20, 4, 2**10, 2
    data = np.random.uniform(-1, 1, (L_SEQ, BAT, L_INP))
    state = np.random.normal(0, 1, (1, BAT, L_STA))
    params = np.random.normal(0, 1, (6168,))
    data.attach_grad()
    with mx.autograd.record():
        out = npx.rnn(data=data, parameters=params, mode='gru',
                      state=state, state_size=L_STA, num_layers=1)
    assert out.shape == (L_SEQ, BAT, L_STA)
    out.backward()
    npx.waitall()
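Why 6168 parameters? A back-of-the-envelope count, assuming the fused kernel's cuDNN-style layout with separate input and recurrent biases: a GRU has three gates, and per gate the flattened vector packs L_INP*L_STA input weights, L_STA*L_STA recurrent weights, and two bias vectors of length L_STA.

L_INP, L_STA, n_gates = 2**10, 2, 3
n_params = n_gates * (L_INP * L_STA    # input-to-hidden: 3 * 2048
                      + L_STA * L_STA  # hidden-to-hidden: 3 * 4
                      + 2 * L_STA)     # two bias vectors:  3 * 4
assert n_params == 6168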
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
    X, trainer = get_inits(X, device, lr)
    for epoch in range(1, num_epochs + 1):
        with autograd.record():
            contents_Y_hat, styles_Y_hat = get_features(
                X, content_layers, style_layers)
            contents_l, styles_l, l = compute_loss(
                X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y)
        l.backward()
        trainer.step(1)
        npx.waitall()
        if epoch % lr_decay_epoch == 0:
            trainer.set_learning_rate(trainer.learning_rate * 0.1)
        if epoch % 100 == 0:
            print('Total loss: ', l.item())
            print('Iteration: ', epoch)  # loop already starts at 1
            plt.imshow(postprocess(X).asnumpy())
            plt.axis("off")
            plt.show()
    return X
def _create_checkpoint(self, checkpoint_decoder: CheckpointDecoder,
                       time_cost: float,
                       train_iter: data_io.BaseParallelSampleIter,
                       validation_iter: data_io.BaseParallelSampleIter):
    """
    Creates a checkpoint, which will update self.state.converged/
    self.state.diverged, evaluate validation metrics and update the best
    known parameters accordingly.
    """
    self.state.checkpoint += 1
    # save parameters and evaluate on validation data
    self._save_params()
    train_metrics = [lf.metric for lf in self.loss_functions]
    logger.info(
        "Checkpoint [%d]\tUpdates=%d Epoch=%d Samples=%d Time-cost=%.3f "
        "Updates/sec=%.3f",
        self.state.checkpoint, self.state.updates, self.state.epoch,
        self.state.samples, time_cost,
        self.config.checkpoint_interval / time_cost)
    logger.info('Checkpoint [%d]\t%s', self.state.checkpoint,
                "\t".join("Train-%s" % str(metric)
                          for metric in train_metrics))
    val_metrics = self._evaluate(self.state.checkpoint, validation_iter,
                                 checkpoint_decoder)
    npx.waitall()
    has_improved = self._determine_improvement(val_metrics)
    self.state.converged = self._determine_convergence()
    self.state.diverged = self._determine_divergence(val_metrics)
    self._adjust_learning_rate(has_improved)
    if has_improved:
        self._update_best_params()
        self._save_trainer_states(self.best_optimizer_states_fname)
        self._save_lr_scheduler(self.best_lr_scheduler_fname)
    self._write_and_log_metrics(train_metrics=train_metrics,
                                val_metrics=val_metrics)
    for metric in train_metrics:
        metric.reset()
    self._save_training_state(train_iter)
    if self.checkpoint_callback:
        self.checkpoint_callback(self.state.checkpoint)
def __init__(self, file_patterns, file_sampler,
             dataset_fn=None, batch_sampler_fn=None,
             dataset_params=None, batch_sampler_params=None,
             batchify_fn=None,
             num_dataset_workers=0, num_batch_workers=0,
             pin_memory=False, circle_length=1,
             dataset_prefetch=None, batch_prefetch=None,
             dataset_cached=False, num_max_dataset_cached=0):
    assert num_dataset_workers >= 0, \
        'num_dataset_workers must be non-negative'
    assert num_batch_workers >= 0, \
        'num_batch_workers must be non-negative'
    if num_batch_workers > 0:
        assert num_dataset_workers > 0, \
            'num_dataset_workers must be positive when num_batch_workers > 0'
    else:
        if num_dataset_workers > 0:
            warnings.warn(
                'The multi-processing functionalities for both dataset and'
                ' batch sampling are disabled when num_batch_workers=0 though '
                'num_dataset_workers={} > 0'.format(num_dataset_workers))
    assert circle_length >= 1, \
        'circle_length must be larger than or equal to 1'
    if dataset_cached:
        assert num_max_dataset_cached > 0, \
            'When dataset_cached is True, num_max_dataset_cached must be positive'

    self._dataset = _PathDataset(file_patterns)
    self._file_sampler = file_sampler

    assert dataset_fn is not None, 'dataset_fn is not given.'
    assert batch_sampler_fn is not None, 'batch_sampler_fn is not given.'
    if dataset_params is not None:
        self._dataset_fn = partial(dataset_fn, **dataset_params)
    else:
        self._dataset_fn = dataset_fn
    if batch_sampler_params is not None:
        self._batch_sampler_fn = partial(batch_sampler_fn,
                                         **batch_sampler_params)
    else:
        self._batch_sampler_fn = batch_sampler_fn

    self._num_dataset_workers = num_dataset_workers
    self._num_batch_workers = num_batch_workers
    self._dataset_prefetch = max(
        0, int(dataset_prefetch) if dataset_prefetch is not None
        else self._num_dataset_workers)
    self._batch_prefetch = max(
        0, int(batch_prefetch) if batch_prefetch is not None
        else 2 * self._num_batch_workers)

    self._pin_memory = pin_memory
    self._circle_length = circle_length
    self._dataset_cached = dataset_cached
    self._num_max_dataset_cached = num_max_dataset_cached

    self._manager = None
    self._dataset_worker_pool = None
    if self._num_dataset_workers > 0:
        # Drain pending asynchronous operations and collect garbage
        # before forking worker processes
        npx.waitall()
        import gc
        gc.collect()
        npx.waitall()
        self._manager = multiprocessing.Manager()
        self._dataset_worker_pool = multiprocessing.Pool(
            self._num_dataset_workers,
            initializer=_initialize_dataset_worker,
            initargs=[self._manager])
    self._batch_worker_pool = None
    if self._num_batch_workers > 0:
        # Same synchronization before forking the batch workers
        npx.waitall()
        import gc
        gc.collect()
        npx.waitall()
        self._batch_worker_pool = multiprocessing.Pool(
            self._num_batch_workers)
    if batchify_fn is None:
        if self._num_batch_workers > 0:
            self._batchify_fn = default_mp_batchify_fn
        else:
            self._batchify_fn = default_batchify_fn
    else:
        self._batchify_fn = batchify_fn
def run_train_translate(train_params: str,
                        translate_params: str,
                        data: Dict[str, Any],
                        use_prepared_data: bool = False,
                        max_seq_len: int = 10,
                        seed: int = 13,
                        use_pytorch: bool = False) -> Dict[str, Any]:
    """
    Train a model and translate a test set. Returns the updated data
    dictionary containing paths to translation outputs and scores.

    :param train_params: Command line args for model training.
    :param translate_params: First command line args for translation.
    :param data: Dictionary containing test data.
    :param use_prepared_data: Whether to use the prepared data functionality.
    :param max_seq_len: The maximum sequence length.
    :param seed: The seed used for training.
    :param use_pytorch: Whether to use PyTorch.
    :return: Data dictionary, updated with translation outputs and scores.
    """
    if use_pytorch:
        import sockeye.prepare_data_pt
        import sockeye.train_pt
        import sockeye.translate_pt
        prepare_data_mod = sockeye.prepare_data_pt
        train_mod = sockeye.train_pt
        translate_mod = sockeye.translate_pt
    else:
        import sockeye.prepare_data
        import sockeye.train
        import sockeye.translate
        prepare_data_mod = sockeye.prepare_data
        train_mod = sockeye.train
        translate_mod = sockeye.translate

    work_dir = os.path.join(data['work_dir'], 'train_translate')
    data['model'] = os.path.join(work_dir, "model")

    # Optionally create prepared data directory
    if use_prepared_data:
        data['train_prepared'] = os.path.join(work_dir, "prepared_data")
        prepare_params = "{} {}".format(
            prepare_data_mod.__file__,
            PREPARE_DATA_COMMON.format(train_source=data['train_source'],
                                       train_target=data['train_target'],
                                       output=data['train_prepared'],
                                       max_len=max_seq_len))
        if 'train_source_factors' in data:
            prepare_params += TRAIN_WITH_SOURCE_FACTORS_COMMON.format(
                source_factors=" ".join(data['train_source_factors']))
        if 'train_target_factors' in data:
            prepare_params += TRAIN_WITH_TARGET_FACTORS_COMMON.format(
                target_factors=" ".join(data['train_target_factors']))
        if '--weight-tying-type src_trg' in train_params:
            prepare_params += ' --shared-vocab'
        logger.info("Preparing data with parameters %s.", prepare_params)
        with patch.object(sys, "argv", prepare_params.split()):
            prepare_data_mod.main()
        # Train model
        params = "{} {} {}".format(
            train_mod.__file__,
            TRAIN_PARAMS_PREPARED_DATA_COMMON.format(
                prepared_data=data['train_prepared'],
                dev_source=data['dev_source'],
                dev_target=data['dev_target'],
                model=data['model'],
                max_len=max_seq_len),
            train_params)
        if 'dev_source_factors' in data:
            params += DEV_WITH_SOURCE_FACTORS_COMMON.format(
                dev_source_factors=" ".join(data['dev_source_factors']))
        if 'dev_target_factors' in data:
            params += DEV_WITH_TARGET_FACTORS_COMMON.format(
                dev_target_factors=" ".join(data['dev_target_factors']))
        logger.info("Starting training with parameters %s.", train_params)
        with patch.object(sys, "argv", params.split()):
            train_mod.main()
    else:
        # Train model
        params = "{} {} {}".format(
            train_mod.__file__,
            TRAIN_PARAMS_COMMON.format(train_source=data['train_source'],
                                       train_target=data['train_target'],
                                       dev_source=data['dev_source'],
                                       dev_target=data['dev_target'],
                                       model=data['model'],
                                       max_len=max_seq_len,
                                       seed=seed),
            train_params)
        if 'train_source_factors' in data:
            params += TRAIN_WITH_SOURCE_FACTORS_COMMON.format(
                source_factors=" ".join(data['train_source_factors']))
        if 'train_target_factors' in data:
            params += TRAIN_WITH_TARGET_FACTORS_COMMON.format(
                target_factors=" ".join(data['train_target_factors']))
        if 'dev_source_factors' in data:
            params += DEV_WITH_SOURCE_FACTORS_COMMON.format(
                dev_source_factors=" ".join(data['dev_source_factors']))
        if 'dev_target_factors' in data:
            params += DEV_WITH_TARGET_FACTORS_COMMON.format(
                dev_target_factors=" ".join(data['dev_target_factors']))
        logger.info("Starting training with parameters %s.", train_params)
        with patch.object(sys, "argv", params.split()):
            train_mod.main()

    # Create top-k lexicon from simple ttable mapping digit to digit
    ttable_path = os.path.join(data['work_dir'], "ttable")
    generate_fast_align_lex(ttable_path)
    lexicon_path = os.path.join(data['work_dir'], "lexicon")
    params = "{} {}".format(
        sockeye.lexicon.__file__,
        LEXICON_CREATE_PARAMS_COMMON.format(input=ttable_path,
                                            model=data['model'],
                                            topk=20,
                                            lexicon=lexicon_path))
    with patch.object(sys, "argv", params.split()):
        sockeye.lexicon.main()
    data['lexicon'] = lexicon_path

    # Translate corpus with the 1st params and scoring output handler
    # to obtain scores
    data['test_output'] = os.path.join(work_dir, "test.out")
    params = "{} {} {}".format(
        translate_mod.__file__,
        TRANSLATE_PARAMS_COMMON.format(model=data['model'],
                                       input=data['test_source'],
                                       output=data['test_output']),
        translate_params)
    if 'test_source_factors' in data:
        params += TRANSLATE_WITH_FACTORS_COMMON.format(
            input_factors=" ".join(data['test_source_factors']))

    # Try to fix transient errors with mxnet tests where the parameter file
    # does not yet exist
    # TODO(migration): remove once mxnet is removed
    if not use_pytorch:
        try:
            from mxnet import npx
            npx.waitall()
            import time
            time.sleep(1)
        except:
            pass

    logger.info("Translating with params %s", params)
    with patch.object(sys, "argv", params.split()):
        translate_mod.main()

    # Collect test inputs
    with open(data['test_source']) as inputs:
        data['test_inputs'] = [line.strip() for line in inputs]

    # Collect test references
    with open(data['test_target'], "r") as ref:
        data['test_targets'] = [line.strip() for line in ref]

    # Collect test translation outputs and scores
    data['test_outputs'] = collect_translate_output_and_scores(
        data['test_output'])
    assert len(data['test_inputs']) == len(data['test_targets']) == len(
        data['test_outputs'])
    return data