def on_epoch_end(self, data: Data) -> None:
    """Write epoch-level averages of every collected test value.

    For each key in ``self.test_results``, the values gathered under every non-empty dataset id are averaged
    along axis 0 and written (with logging) through a ``DSData`` view for that dataset id. Afterwards the values
    from all dataset ids, including the empty id, are pooled in order and their axis-0 mean is written under the
    plain key on ``data``.

    Args:
        data: The epoch-end data dictionary that results are written into.
    """
    for key, per_ds in self.test_results.items():
        pooled = []
        for ds_id, vals in per_ds.items():
            if ds_id != '':
                DSData(ds_id, data).write_with_log(key, np.mean(np.array(vals), axis=0))
            # Every dataset (including the anonymous '' one) contributes to the overall average.
            pooled.extend(vals)
        data.write_with_log(key, np.mean(np.array(pooled), axis=0))
def on_batch_end(self, data: Data) -> None:
    """Forward the batch-end event to the per-dataset trace, then run the parent hook.

    When a non-empty dataset id is active, the per-dataset trace receives a ``DSData`` view of ``data``. The parent
    ``on_batch_end`` then runs with ``data.per_instance_enabled`` switched off, and the flag is switched back on
    afterwards, even if the parent hook raises.

    Args:
        data: The batch data dictionary.
    """
    if self.system.ds_id != '':
        self.fe_per_ds_trace.on_batch_end(DSData(self.system.ds_id, data))
    # Block the main process from writing per-instance info since we already have the more detailed key
    data.per_instance_enabled = False
    try:
        super().on_batch_end(data)
    finally:
        # Re-enable even on failure so the flag never leaks into later traces that share this data object.
        data.per_instance_enabled = True
def on_batch_end(self, data: Data) -> None:
    """Log monitored values and training throughput on logging steps.

    Nothing happens unless ``self.system.log_steps`` is truthy and the current global step is a multiple of
    ``log_steps`` or is step 1. On such a step:

    * Every key of ``self.inputs`` present in the batch is written with logging. The write goes through a
      ``DSData`` view when a non-empty dataset id is active.
    * After step 1, the time elapsed since ``self.step_start`` is appended to ``self.elapse_times``.
      ``"steps/sec"`` is then written as ``log_steps`` divided by the sum of ``self.elapse_times``, rounded to
      two decimals.
    * Finally ``self.elapse_times`` is reset and ``self.step_start`` is restarted.

    Args:
        data: The batch data dictionary.
    """
    system = self.system
    if not system.log_steps:
        return
    if not (system.global_step % system.log_steps == 0 or system.global_step == 1):
        return

    sink = DSData(system.ds_id, data) if system.ds_id != '' else data
    for name in self.inputs:
        if name not in sink:
            continue
        sink.write_with_log(name, sink[name])

    if system.global_step > 1:
        self.elapse_times.append(time.perf_counter() - self.step_start)
        rate = round(system.log_steps / np.sum(self.elapse_times), 2)
        sink.write_with_log("steps/sec", rate)
    # Start a fresh timing window for the next logging interval.
    self.elapse_times = []
    self.step_start = time.perf_counter()
def on_ds_end(self, data: Data) -> None:
    """Notify the per-dataset trace that the current dataset has finished.

    The per-dataset trace's ``on_epoch_end`` is invoked with a ``DSData`` view of ``data``. This happens only when
    a non-empty dataset id is active; otherwise the method does nothing.

    Args:
        data: The data dictionary for the end of this dataset.
    """
    ds_id = self.system.ds_id
    if ds_id == '':
        return
    self.fe_per_ds_trace.on_epoch_end(DSData(ds_id, data))
def on_batch_begin(self, data: Data) -> None:
    """Run the parent batch-begin hook, then forward the event to the per-dataset trace.

    The forwarding (with a ``DSData`` view of ``data``) only happens when a non-empty dataset id is active.

    Args:
        data: The batch data dictionary.
    """
    super().on_batch_begin(data)
    ds_id = self.system.ds_id
    if ds_id == '':
        return
    self.fe_per_ds_trace.on_batch_begin(DSData(ds_id, data))