def _train(self, loader, steps=0, **kwargs):
    """Train an epoch of data using either the input loader or using `tf.dataset`

    In non-`tf.dataset` mode, we cycle the loader data feed, and pull a batch and feed it to the feed dict
    When we use `tf.dataset`s under the hood, this function simply uses the loader to know how many steps to train.
    The `TRAIN_FLAG` is set directly via `SET_TRAIN_FLAG` in either case

    :param loader: A data feed
    :param kwargs: See below

    :Keyword Arguments:
     * *dataset* (`bool`) Set to `True` if using `tf.dataset`s, defaults to `True`
     * *reporting_fns* (`list`) A list of reporting hooks to use

    :return: Metrics
    """
    SET_TRAIN_FLAG(True)
    reporting_fns = kwargs.get('reporting_fns', [])
    pg = create_progress_bar(steps)

    epoch_loss = tf.Variable(0.0)
    epoch_div = tf.Variable(0, dtype=tf.int32)
    nstep_loss = tf.Variable(0.0)
    nstep_div = tf.Variable(0, dtype=tf.int32)
    self.nstep_start = time.perf_counter()

    @tf.function
    def _train_step(inputs):
        """Replicated training step."""
        features, y = inputs
        loss = self.optimizer.update(self.model, features, y)
        batchsz = get_shape_as_list(y)[0]
        report_loss = loss * batchsz
        return report_loss, batchsz

    for inputs in pg(loader):
        step_report_loss, step_batchsz = _train_step(inputs)
        epoch_loss.assign_add(step_report_loss)
        nstep_loss.assign_add(step_report_loss)
        epoch_div.assign_add(step_batchsz)
        nstep_div.assign_add(step_batchsz)

        step = self.optimizer.global_step.numpy() + 1
        if step % self.nsteps == 0:
            metrics = self.calc_metrics(nstep_loss.numpy(), nstep_div.numpy())
            self.report(
                step, metrics, self.nstep_start,
                'Train', 'STEP', reporting_fns, self.nsteps
            )
            nstep_loss.assign(0.0)
            nstep_div.assign(0)
            self.nstep_start = time.perf_counter()

    epoch_loss = epoch_loss.numpy()
    epoch_div = epoch_div.numpy()
    metrics = self.calc_metrics(epoch_loss, epoch_div)
    return metrics
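# The trainer above delegates the forward/backward pass to `self.optimizer.update`.
# Below is a minimal sketch of what such an eager update step typically does under a
# GradientTape; it is an illustrative stand-in, not the library's actual EagerOptimizer.
import tensorflow as tf

class SimpleEagerOptimizer:
    def __init__(self, loss_fn, lr=0.001):
        self.loss_fn = loss_fn
        self.optimizer = tf.keras.optimizers.Adam(lr)
        # Counts calls to apply_gradients, analogous to a global step
        self.global_step = self.optimizer.iterations

    def update(self, model, features, y):
        with tf.GradientTape() as tape:
            loss = self.loss_fn(model, features, y)
        grads = tape.gradient(loss, model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss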
def predict_input_fn():
    SET_TRAIN_FLAG(False)
    # Include the char features so the 3-argument map below unpacks correctly
    dataset = tf.data.Dataset.from_tensor_slices((X_test, Xch_test, y_test))
    dataset = dataset.batch(1)
    dataset = dataset.map(lambda x, xch, y: ({
        'word': x,
        'char': xch,
        'lengths': tf.compat.v1.count_nonzero(x, axis=1)
    }, y))
    return dataset
def eval_input_fn():
    SET_TRAIN_FLAG(False)
    dataset = tf.data.Dataset.from_tensor_slices((X_valid, Xch_valid, y_valid))
    dataset = dataset.batch(args.batchsz)
    dataset = dataset.map(lambda x, xch, y: ({
        'word': x,
        'char': xch,
        'lengths': tf.compat.v1.count_nonzero(x, axis=1)
    }, y))
    return dataset
def recover_last_checkpoint(self):
    latest = os.path.join(self.base_dir, 'seq2seq-model-tf-%d' % os.getpid())
    print('Reloading ' + latest)
    g = tf.Graph()
    with g.as_default():
        SET_TRAIN_FLAG(None)
        sess = tf.Session()
        self.model = self.model.load(latest, predict=True, beam=self.beam, session=sess)
def train_input_fn():
    SET_TRAIN_FLAG(True)
    dataset = tf.data.Dataset.from_tensor_slices((X_train, Xch_train, y_train))
    dataset = dataset.shuffle(buffer_size=SHUF_BUF_SZ)
    dataset = dataset.batch(args.batchsz)
    dataset = dataset.map(lambda x, xch, y: ({
        'word': x,
        'char': xch,
        'lengths': tf.compat.v1.count_nonzero(x, axis=1)
    }, y))
    dataset = dataset.prefetch(NUM_PREFETCH)
    return dataset
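# A quick sanity check of how these input functions are consumed: pull one batch and
# confirm the feature dictionary carries the `word`, `char` and `lengths` keys the model
# expects. This assumes X_train, Xch_train and y_train are already-vectorized numpy
# arrays as in the surrounding script; the shapes in the comments are illustrative.
ds = train_input_fn()
features, y = next(iter(ds))
print(features['word'].shape)     # e.g. [batchsz, mxlen]
print(features['char'].shape)     # e.g. [batchsz, mxlen, mxwlen]
print(features['lengths'].shape)  # [batchsz]
print(y.shape)                    # label tensor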
def recover_last_checkpoint(self):
    """Recover the last saved checkpoint

    :return: None
    """
    latest = os.path.join(self.base_dir, 'seq2seq-model-tf-%d' % os.getpid())
    # logger.info('Reloading %s', latest)
    g = tf.Graph()
    with g.as_default():
        SET_TRAIN_FLAG(None)
        sess = create_session()
        self.model = self.model.load(latest, predict=True, beam=self.beam, session=sess)
def _test(self, loader, steps=0, **kwargs):
    """Test an epoch of data using either the input loader or using `tf.dataset`

    In non-`tf.dataset` mode, we cycle the loader data feed, and pull a batch and feed it to the feed dict
    When we use `tf.dataset`s under the hood, this function simply uses the loader to know how many steps to evaluate.

    :param loader: A data feed
    :param kwargs: See below

    :Keyword Arguments:
     * *dataset* (`bool`) Set to `True` if using `tf.dataset`s, defaults to `True`
     * *reporting_fns* (`list`) A list of reporting hooks to use
     * *verbose* (`dict`) A dictionary containing `console` boolean and `file` name if on

    :return: Metrics
    """
    cm = ConfusionMatrix(self.model.labels)
    total_loss = 0
    total_norm = 0
    verbose = kwargs.get("verbose", None)

    pg = create_progress_bar(steps)
    SET_TRAIN_FLAG(False)
    for features, y in pg(loader):
        logits = self.model(features)
        y_ = tf.argmax(logits, axis=1, output_type=tf.int32)
        cm.add_batch(y, y_)
        lossv = tf.compat.v1.losses.sparse_softmax_cross_entropy(labels=y, logits=logits).numpy()
        batchsz = int(y.shape[0])
        assert len(y_) == batchsz
        total_loss += lossv * batchsz
        total_norm += batchsz

    metrics = cm.get_all_metrics()
    metrics['avg_loss'] = total_loss / float(total_norm)
    verbose_output(verbose, cm)

    return metrics
def main():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_dir", type=str, required=True, help='Training directory')
    parser.add_argument("--valid_dir", type=str, required=True, help='Validation directory')
    parser.add_argument("--train_md", type=str, help="Training metadata YAML, defaults to `{train_dir}/md.yml`")
    parser.add_argument("--valid_md", type=str, help="Validation metadata YAML, defaults to `{valid_dir}/md.yml`")
    parser.add_argument("--label_file", type=str, help="JSON file mapping labels to integers", default="labels.json")
    parser.add_argument("--dataset_key", default="tlm", help="dataset key for basedir")
    parser.add_argument("--embed_type", type=str, default='default',
                        choices=["default", "positional", "learned-positional"],
                        help="register label of the embeddings")
    parser.add_argument("--d_model", type=int, default=512, help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument("--num_heads", type=int, default=8, help="Number of heads")
    parser.add_argument("--num_layers", type=int, default=8, help="Number of layers")
    parser.add_argument("--num_train_workers", type=int, default=4, help="Number train workers")
    parser.add_argument("--distribute", type=str, default="mirror", choices=["mirror", "tpu", "nccl"])
    parser.add_argument("--tpu_ep", type=str, help="The TPU endpoint if using `distribute=tpu`")
    parser.add_argument("--nctx", type=int, default=256, help="Max input length")
    parser.add_argument("--file_type", default='tfrecord', choices=['json', 'tfrecord'], help="Glob pattern for data")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch Size")
    parser.add_argument("--subword_model_file", type=str, help="The BPE model file", required=True)
    parser.add_argument("--subword_vocab_file", type=str, help="The BPE subword vocab", required=True)
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--ffn_pdrop", type=float, default=0.0, help="Dropout in the dense stack")
    parser.add_argument("--layer_drop", type=float, default=0.0, help="LayerDrop to apply")
    parser.add_argument("--optim", default="adamw", type=str, help="Optimizer to use (defaults to adamw)")
    parser.add_argument("--lr", type=float, default=4.0e-4, help="Learning rate")
    parser.add_argument("--clip", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--weight_decay", type=float, default=1.0e-2, help="Weight decay")
    parser.add_argument("--epochs", type=int, default=32, help="Num training epochs")
    parser.add_argument("--restart", type=str2bool, help="Option allows you to restart from a previous checkpoint")
    parser.add_argument("--warmup_steps", type=int, default=10000, help="Num warmup steps")
    parser.add_argument("--saves_per_epoch", type=int, default=10, help="The number of checkpoints to save per epoch")
    parser.add_argument('--rpr_k',
                        help='Relative attention positional sizes, pass 0 if you do not want relative attention',
                        type=int, default=[8], nargs='+')
    parser.add_argument('--rpr_value_on', type=str2bool, default=True,
                        help="In relative attention, whether to add the positional correction to the values in "
                             "addition to the correction to the attention matrix")
    parser.add_argument('--windowed_ra', type=str2bool, default=False,
                        help="Whether to prevent attention beyond rpr_k")
    parser.add_argument("--strategy", help="Training strategy, defaults to `mirror`", choices=["mirror"])
    parser.add_argument("--npz", help="Should we write out NPZ files?", type=str2bool, default=False)
    parser.add_argument("--tb", help="Turn on tensorboard?", type=str2bool, default=False)
    parser.add_argument("--convert_only", help="Should we just convert this file to NPZ and exit?",
                        type=str2bool, default=False)
    args = parser.parse_args()
    SET_TRAIN_FLAG(True)

    if args.convert_only:
        args.restart = True

    if args.basedir is None:
        args.basedir = f'lm-{args.dataset_key}-bpe-{os.getpid()}'
    logging.basicConfig(level=logging.INFO)
    logger.info(f"Writing results to {args.basedir}")

    if args.tb:
        logdir = f"logs/scalars/{os.getpid()}"
        file_writer = tf.summary.create_file_writer(logdir + "/metrics")
        file_writer.set_as_default()
        logger.info(f"Set up tensorboard logdir {logdir}")

    strategy = create_distribute_strategy(args.distribute, args.tpu_ep)
    num_replicas = strategy.num_replicas_in_sync
    logger.info(f"Using {num_replicas} replicas in this job.")
    vectorizer = BPEVectorizer1D(model_file=args.subword_model_file, vocab_file=args.subword_vocab_file, mxlen=args.nctx)
    vocab = {'x': vectorizer.vocab}
    preproc_data = baseline.embeddings.load_embeddings('x', dsz=args.d_model, known_vocab=vocab['x'],
                                                       preserve_vocab_indices=True,
                                                       embed_type=args.embed_type)
    vocabs = preproc_data['vocab']

    train_md = args.train_md if args.train_md else os.path.join(args.train_dir, 'md.yml')
    num_train_samples = get_num_samples(train_md)
    valid_md = args.valid_md if args.valid_md else os.path.join(args.valid_dir, 'md.yml')
    num_valid_samples = get_num_samples(valid_md)
    labels = read_json_tf(args.label_file)
    num_labels = len(labels)

    def dataset_train_fn(input_context):
        global_batchsz = args.batch_size
        base_batchsz = input_context.get_per_replica_batch_size(global_batchsz)
        ds = get_dataset(args.train_dir, args.file_type, args.num_train_workers).batch(base_batchsz)
        return ds.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)
    train_loader = strategy.experimental_distribute_datasets_from_function(dataset_train_fn)

    def dataset_test_fn(input_context):
        global_batchsz = args.batch_size
        base_batchsz = input_context.get_per_replica_batch_size(global_batchsz)
        ds = get_dataset(args.valid_dir, args.file_type, args.num_train_workers, shuffle=False).batch(base_batchsz)
        return ds.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)
    valid_loader = strategy.experimental_distribute_datasets_from_function(dataset_test_fn)

    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    embeddings = {'x': preproc_data['embeddings']}
    logger.info("Loaded embeddings")
    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)

    if len(args.rpr_k) == 0 or args.rpr_k[0] < 1:
        args.rpr_k = None
    elif len(args.rpr_k) == 1:
        args.rpr_k = args.rpr_k[0]

    model = TransformerTagger(num_labels, embeddings, **vars(args))

    logger.info("Loaded model and loss")

    steps_per_epoch = num_train_samples // args.batch_size
    steps_per_valid_epoch = num_valid_samples // args.batch_size
    update_on = steps_per_epoch // args.saves_per_epoch
    report_on = max(10, update_on) // 10
    logger.info(f"Steps per epoch: {steps_per_epoch}. Saving checkpoint every {update_on} steps.")

    lr_decay = CosineDecaySchedulerTensorFlow(steps_per_epoch * args.epochs, lr=args.lr)
    linear_warmup = WarmupLinearSchedulerTensorFlow(args.warmup_steps, lr=args.lr)
    lr_sched = CompositeLRSchedulerTensorFlow(linear_warmup, lr_decay)
    optimizer = EagerOptimizer(loss_function, optim=args.optim, lr_function=lr_sched,
                               weight_decay=args.weight_decay, clip=args.clip, lr=args.lr)
    checkpoint = tf.train.Checkpoint(optimizer=optimizer.optimizer, model=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint, directory=args.basedir, max_to_keep=5)

    start_epoch = 0
    if args.restart:
        # The global step gets automatically updated here
        # so we don't have to worry about our LR regimen
        checkpoint.restore(checkpoint_manager.latest_checkpoint)
        current_step = optimizer.global_step
        start_epoch = current_step // steps_per_epoch

    def _replicated_train_step(inputs):
        """This runs on a single replica"""
        x, y = inputs
        per_replica_loss = optimizer.update(model, {'x': x}, y, num_replicas)
        return per_replica_loss

    @tf.function
    def _distributed_train_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        per_replica_loss = strategy.run(_replicated_train_step, args=(inputs,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

    def _replicated_test_step(inputs):
        """This runs on a single replica"""
        x, y = inputs
        per_replica_loss = loss_function(model, {'x': x}, y) / num_replicas
        return per_replica_loss

    @tf.function
    def _distributed_test_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        per_replica_loss = strategy.run(_replicated_test_step, args=(inputs,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

    timer = Timer()

    with strategy.scope():
        for epoch in range(start_epoch, args.epochs):
            SET_TRAIN_FLAG(True)
            logger.info('Starting epoch %d', epoch + 1)
            avg_loss = Average('average_train_loss')
            metrics = {}
            timer.start()
            train_iter = iter(train_loader)
            for i in range(steps_per_epoch):
                try:
                    loss = _distributed_train_step(next(train_iter))
                    avg_loss.update(loss.numpy().item())
                    tf.summary.scalar("train_loss", data=loss, step=optimizer.global_step)
                except Exception as e:
                    logger.error(f"Exception at training step {i+1}/{steps_per_epoch}. Skipping")
                    pass

                if args.convert_only:
                    logger.warning("Convert only flag specified. Stopping after one step")
                    steps = optimizer.global_step.numpy()
                    npz_checkpoint = os.path.join(args.basedir, f'checkpoint-step-{steps}.npz')
                    save_tlm_output_npz(model, npz_checkpoint)
                    return

                steps = optimizer.global_step.numpy()
                if (steps + 1) % report_on == 0:
                    logger.info(avg_loss)
                if (steps + 1) % update_on == 0:
                    elapsed = timer.elapsed(True)
                    logger.info('elapsed time this epoch %d min', elapsed)
                    logger.info('elapsed step time %f steps/min', i / elapsed)
                    checkpoint_manager.save()
                    if args.npz:
                        npz_checkpoint = os.path.join(args.basedir, f'checkpoint-step-{steps}.npz')
                        save_tlm_output_npz(model, npz_checkpoint)

            # How much time elapsed in minutes
            elapsed = timer.elapsed(True)
            train_token_loss = avg_loss.avg  # This is the average training token-level loss across all machines
            # This is the token-level training perplexity
            train_token_ppl = math.exp(train_token_loss)
            metrics['train_elapsed_min'] = elapsed
            metrics['average_train_loss'] = train_token_loss
            metrics['train_ppl'] = train_token_ppl
            metrics['lr'] = float(lr_sched(tf.cast(optimizer.global_step, tf.float32)).numpy().item())

            avg_valid_loss = Average('average_valid_loss')
            timer.start()
            SET_TRAIN_FLAG(False)
            valid_iter = iter(valid_loader)
            for i in range(steps_per_valid_epoch):
                try:
                    valid_loss = _distributed_test_step(next(valid_iter))
                    tf.summary.scalar('valid_loss', data=valid_loss, step=optimizer.global_step)
                    avg_valid_loss.update(valid_loss.numpy().item())
                except Exception as e:
                    logger.error(f"Exception at validation step {i+1}/{steps_per_valid_epoch}. Skipping")
                    pass

            valid_token_loss = avg_valid_loss.avg
            valid_token_ppl = math.exp(valid_token_loss)

            elapsed = timer.elapsed(True)
            metrics['valid_elapsed_min'] = elapsed
            metrics['average_valid_loss'] = valid_token_loss
            metrics['average_valid_word_ppl'] = valid_token_ppl
            logger.info(json.dumps(metrics, indent=4))
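# The script above hands a `loss_function` to EagerOptimizer and calls it directly in
# the replicated test step, but its definition is not shown in this section. Below is a
# minimal sketch of what a token-level tagging loss could look like, assuming that
# `model({'x': x})` returns per-token logits of shape [B, T, num_labels] and that padded
# positions carry label index 0; both are assumptions, not taken from the source.
import tensorflow as tf

def loss_function(model, features, labels):
    logits = model(features)
    loss_mask = tf.cast(labels != 0, tf.float32)  # ignore padded positions
    per_token = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    return tf.reduce_sum(per_token * loss_mask) / tf.reduce_sum(loss_mask)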
def _train(self, loader, steps=0, **kwargs):
    """Train an epoch of data using either the input loader or using `tf.dataset`

    In non-`tf.dataset` mode, we cycle the loader data feed, and pull a batch and feed it to the feed dict
    When we use `tf.dataset`s under the hood, this function simply uses the loader to know how many steps to train.
    The `TRAIN_FLAG` is set directly via `SET_TRAIN_FLAG` in either case

    :param loader: A data feed
    :param kwargs: See below

    :Keyword Arguments:
     * *dataset* (`bool`) Set to `True` if using `tf.dataset`s, defaults to `True`
     * *reporting_fns* (`list`) A list of reporting hooks to use

    :return: Metrics
    """
    strategy = self.strategy
    num_replicas = strategy.num_replicas_in_sync

    def _replicated_train_step(inputs):
        """Replicated training step."""
        features, y = inputs
        per_replica_loss = self.optimizer.update(self.model, features, y, num_replicas)
        per_replica_batchsz = tf.cast(get_shape_as_list(y)[0], tf.float32)
        per_replica_report_loss = per_replica_loss * per_replica_batchsz
        return per_replica_report_loss, per_replica_batchsz

    with strategy.scope():
        SET_TRAIN_FLAG(True)
        reporting_fns = kwargs.get('reporting_fns', [])
        epoch_loss = tf.Variable(0.0)
        epoch_div = tf.Variable(0.0)
        nstep_loss = tf.Variable(0.0)
        nstep_div = tf.Variable(0.0)
        self.nstep_start = time.time()

        @tf.function
        def _distributed_train_step(inputs):
            per_replica_loss, per_replica_batchsz = strategy.experimental_run_v2(_replicated_train_step, args=(inputs,))
            step_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)
            step_batchsz = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_batchsz, axis=None)
            return step_loss, step_batchsz

        train_iter = iter(loader)
        for i in range(steps):
            step_loss, step_batchsz = _distributed_train_step(next(train_iter))
            epoch_loss.assign_add(step_loss)
            nstep_loss.assign_add(step_loss)
            epoch_div.assign_add(step_batchsz)
            nstep_div.assign_add(step_batchsz)

            step = self.optimizer.global_step.numpy() + 1
            if step % self.nsteps == 0:
                metrics = self.calc_metrics(nstep_loss.numpy(), nstep_div.numpy())
                self.report(step, metrics, self.nstep_start, 'Train', 'STEP', reporting_fns, self.nsteps)
                nstep_loss.assign(0.0)
                nstep_div.assign(0.0)
                self.nstep_start = time.time()

        epoch_loss = epoch_loss.numpy()
        epoch_div = epoch_div.numpy()
        metrics = self.calc_metrics(epoch_loss, epoch_div)
        return metrics
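# A standalone sketch of the reduction pattern used above (not part of the trainer):
# each replica returns (loss * batchsz, batchsz); summing both across replicas and
# dividing recovers the global mean loss. The toy values below are illustrative only.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def replica_fn():
    per_replica_loss = tf.constant(0.5)    # pretend mean loss on this replica
    per_replica_batchsz = tf.constant(8.0)
    return per_replica_loss * per_replica_batchsz, per_replica_batchsz

loss_sum, batch_sum = strategy.run(replica_fn)
loss_sum = strategy.reduce(tf.distribute.ReduceOp.SUM, loss_sum, axis=None)
batch_sum = strategy.reduce(tf.distribute.ReduceOp.SUM, batch_sum, axis=None)
print(float(loss_sum / batch_sum))  # global mean loss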
def fit_eager_distributed(model_params, ts, vs, es=None, **kwargs):
    """
    Train a classifier using TensorFlow with `tf.dataset`.  This
    is the default behavior for training.

    :param model_params: The model (or parameters to create the model) to train
    :param ts: A training data set
    :param vs: A validation data set
    :param es: A test data set, can be None
    :param kwargs: See below

    :Keyword Arguments:
        * *do_early_stopping* (``bool``) -- Stop after evaluation data is no longer improving. Defaults to True
        * *verbose* (`dict`) A dictionary containing `console` boolean and `file` name if on
        * *epochs* (``int``) -- how many epochs.  Default to 20
        * *outfile* -- Model output file, defaults to classifier-model.pyth
        * *patience* -- How many epochs where evaluation is no longer improving before we give up
        * *reporting* -- Callbacks which may be used on reporting updates
        * *nsteps* (`int`) -- If we should report every n-steps, this should be passed
        * *ema_decay* (`float`) -- If we are doing an exponential moving average, what decay to use
        * *clip* (`int`) -- If we are doing gradient clipping, what value to use
        * *optim* (`str`) -- The name of the optimizer we are using
        * *lr* (`float`) -- The learning rate we are using
        * *mom* (`float`) -- If we are using SGD, what value to use for momentum
        * *beta1* (`float`) -- Adam-specific hyper-param, defaults to `0.9`
        * *beta2* (`float`) -- Adam-specific hyper-param, defaults to `0.999`
        * *epsilon* (`float`) -- Adam-specific hyper-param, defaults to `1e-8`

    :return: None
    """
    do_early_stopping = bool(kwargs.get('do_early_stopping', True))
    #verbose = kwargs.get('verbose', {'console': kwargs.get('verbose_console', False), 'file': kwargs.get('verbose_file', None)})
    epochs = int(kwargs.get('epochs', 20))
    model_file = get_model_file('classify', 'tf', kwargs.get('basedir'))

    batchsz = kwargs['batchsz']
    lengths_key = model_params.get('lengths_key')
    test_batchsz = kwargs.get('test_batchsz', batchsz)

    train_dataset = tf.data.Dataset.from_tensor_slices(to_tensors(ts, lengths_key))
    train_dataset = train_dataset.shuffle(buffer_size=SHUF_BUF_SZ)
    train_dataset = train_dataset.batch(batchsz, drop_remainder=True)
    train_dataset = train_dataset.prefetch(NUM_PREFETCH)

    valid_dataset = tf.data.Dataset.from_tensor_slices(to_tensors(vs, lengths_key))
    valid_dataset = valid_dataset.batch(batchsz, drop_remainder=True)
    valid_dataset = valid_dataset.prefetch(NUM_PREFETCH)

    best_metric = 0
    if do_early_stopping:
        early_stopping_metric = kwargs.get('early_stopping_metric', 'acc')
        early_stopping_cmp, best_metric = get_metric_cmp(early_stopping_metric, kwargs.get('early_stopping_cmp'))
        patience = kwargs.get('patience', epochs)
        print('Doing early stopping on [%s] with patience [%d]' % (early_stopping_metric, patience))

    reporting_fns = listify(kwargs.get('reporting', []))
    print('reporting', reporting_fns)

    SET_TRAIN_FLAG(True)
    trainer = ClassifyTrainerDistributedTf(model_params, **kwargs)
    train_dataset = trainer.distribute(train_dataset)
    valid_dataset = trainer.distribute(valid_dataset)

    last_improved = 0

    for epoch in range(epochs):
        trainer.train(train_dataset, reporting_fns, steps=len(ts))
        test_metrics = trainer.test(valid_dataset, reporting_fns, phase='Valid', steps=len(vs))

        if do_early_stopping is False:
            trainer.checkpoint()
            #trainer.model.save(model_file)

        elif early_stopping_cmp(test_metrics[early_stopping_metric], best_metric):
            last_improved = epoch
            best_metric = test_metrics[early_stopping_metric]
            print('New best %.3f' % best_metric)
            trainer.checkpoint()
            #trainer.model.save(model_file)

        elif (epoch - last_improved) > patience:
            print('Stopping due to persistent failures to improve')
            break

    if do_early_stopping is True:
        print('Best performance on %s: %.3f at epoch %d' % (early_stopping_metric, best_metric, last_improved))

    if es is not None:
        print('Reloading best checkpoint')
        trainer.recover_last_checkpoint()
        trainer.reset_strategy_to_eval()
        test_dataset = tf.data.Dataset.from_tensor_slices(to_tensors(es, lengths_key))
        test_dataset = test_dataset.batch(test_batchsz, drop_remainder=False)
        test_dataset = test_dataset.prefetch(NUM_PREFETCH)
        test_dataset = trainer.distribute(test_dataset)
        trainer.test(test_dataset, reporting_fns, phase='Test', verbose=False, steps=len(es))
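# The early-stopping logic above relies on `get_metric_cmp` returning a comparison
# function plus an initial best value. A self-contained sketch of that idea follows; it
# is an assumption about the helper's behavior inferred from its usage here, not the
# library's actual implementation, and the metric-name grouping is illustrative.
def simple_metric_cmp(metric_name, user_cmp=None):
    if user_cmp is not None:
        return user_cmp, 0
    # loss-like metrics improve when they decrease, accuracy-like ones when they increase
    if metric_name in ('avg_loss', 'loss', 'perplexity', 'ppl'):
        return (lambda new, best: new < best), float('inf')
    return (lambda new, best: new > best), 0

cmp_fn, best = simple_metric_cmp('acc')
print(cmp_fn(0.81, best))  # True: 0.81 improves on the initial best of 0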
def _test(self, loader, steps=0, **kwargs):
    """Test an epoch of data using either the input loader or using `tf.dataset`

    In non-`tf.dataset` mode, we cycle the loader data feed, and pull a batch and feed it to the feed dict
    When we use `tf.dataset`s under the hood, this function simply uses the loader to know how many steps to evaluate.

    :param loader: A data feed
    :param kwargs: See below

    :Keyword Arguments:
     * *dataset* (`bool`) Set to `True` if using `tf.dataset`s, defaults to `True`
     * *reporting_fns* (`list`) A list of reporting hooks to use
     * *verbose* (`dict`) A dictionary containing `console` boolean and `file` name if on

    :return: Metrics
    """
    strategy = self.strategy
    #cm = ConfusionMatrix(self.model.labels)
    #nc = len(self.model.labels)

    def _replica_test_step(inputs):
        features, y = inputs
        y = tf.cast(y, tf.int64)
        ##per_replica_cm = tf.zeros((nc, nc), dtype=tf.int64)
        logits = self.model(features)
        y_ = tf.argmax(logits, axis=1, output_type=tf.int64)
        ##indices = tf.stack((y, y_), axis=-1)
        ##dense_shape = tf.cast(tf.shape(per_replica_cm), tf.int64)
        ##sparse_ups = tf.SparseTensor(indices=indices, values=tf.ones(get_shape_as_list(indices)[0], dtype=tf.int64),
        ##                             dense_shape=dense_shape)
        ##per_replica_cm = tf.compat.v1.sparse_add(per_replica_cm, sparse_ups)
        per_replica_acc = tf.reduce_sum(tf.cast(y == y_, tf.float32))
        per_replica_loss = tf.compat.v1.losses.sparse_softmax_cross_entropy(labels=y, logits=logits)
        per_replica_batchsz = tf.cast(get_shape_as_list(y)[0], tf.float32)
        per_replica_report_loss = per_replica_loss * per_replica_batchsz
        return per_replica_report_loss, per_replica_batchsz, per_replica_acc  ##, per_replica_cm

    @tf.function
    def _distributed_test_step(inputs):
        per_replica_loss, per_replica_batchsz, per_replica_acc = strategy.experimental_run_v2(_replica_test_step, args=(inputs,))
        step_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)
        step_batchsz = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_batchsz, axis=None)
        # step_cm
        step_acc = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_acc, axis=None)
        return step_loss, step_batchsz, step_acc  # step_cm

    with strategy.scope():
        total_loss = tf.Variable(0.0)
        total_acc = tf.Variable(0.0)
        total_norm = tf.Variable(0.0)
        SET_TRAIN_FLAG(False)

        test_iter = iter(loader)
        for i in range(steps):
            #step_loss, step_batchsz, distributed_cm = _distributed_test_step(next(test_iter))
            step_loss, step_batchsz, distributed_acc = _distributed_test_step(next(test_iter))
            total_loss.assign_add(step_loss)
            total_norm.assign_add(step_batchsz)
            total_acc.assign_add(distributed_acc)
            #cm._cm += distributed_cm.numpy()

        #metrics = cm.get_all_metrics()
        total_loss = total_loss.numpy()
        total_norm = total_norm.numpy()
        total_acc = total_acc.numpy()
        metrics = {}
        metrics['avg_loss'] = total_loss / float(total_norm)
        metrics['acc'] = total_acc / float(total_norm)
        #verbose_output(verbose, cm)
        return metrics