def callback(self):
    meta_info = {"intersect_method": self.model_param.intersect_method,
                 "join_method": self.model_param.join_method}
    self.callback_metric(metric_name=self.metric_name,
                         metric_namespace=self.metric_namespace,
                         metric_data=[Metric("intersect_count", self.intersect_num),
                                      Metric("intersect_rate", self.intersect_rate),
                                      Metric("unmatched_count", self.unmatched_num),
                                      Metric("unmatched_rate", self.unmatched_rate)])
    self.tracker.set_metric_meta(metric_namespace=self.metric_namespace,
                                 metric_name=self.metric_name,
                                 metric_meta=MetricMeta(name=self.metric_name,
                                                        metric_type=self.metric_type,
                                                        extra_metas=meta_info))
def callback_info(self):
    class_weight = None
    classes = None
    if self.class_weight_dict:
        class_weight = {str(k): v for k, v in self.class_weight_dict.items()}
        classes = sorted([str(k) for k in self.class_weight_dict.keys()])
        # LOGGER.debug(f"callback class weight is: {class_weight}")

    metric_meta = MetricMeta(name='train',
                             metric_type=self.metric_type,
                             extra_metas={
                                 "weight_mode": self.weight_mode,
                                 "class_weight": class_weight,
                                 "classes": classes,
                                 "sample_weight_name": self.sample_weight_name
                             })
    self.callback_metric(metric_name=self.metric_name,
                         metric_namespace=self.metric_namespace,
                         metric_data=[Metric(self.metric_name, 0)])
    self.tracker.set_metric_meta(metric_namespace=self.metric_namespace,
                                 metric_name=self.metric_name,
                                 metric_meta=metric_meta)
def _display_result(self, block_num=None):
    # both branches of the original only differed in which block number was
    # reported, so resolve that value first and emit the metrics once
    if block_num is None:
        block_num = self.block_num
    self.callback_metric(metric_name=self.metric_name,
                         metric_namespace=self.metric_namespace,
                         metric_data=[Metric("Coverage", self.coverage),
                                      Metric("Block number", block_num)])
    self.tracker.set_metric_meta(metric_namespace=self.metric_namespace,
                                 metric_name=self.metric_name,
                                 metric_meta=MetricMeta(self.metric_name,
                                                        metric_type="INTERSECTION"))
def fit(self, data_inst, validate_data=None):
    # init binning obj
    self.aggregator = HomoBoostArbiterAggregator()
    self.binning_obj = HomoFeatureBinningServer()
    self.federated_binning()

    # initializing
    self.feature_num = self.sync_feature_num()

    if self.task_type == consts.CLASSIFICATION:
        label_mapping = HomoLabelEncoderArbiter().label_alignment()
        LOGGER.info('label mapping is {}'.format(label_mapping))
        self.booster_dim = len(label_mapping) if len(label_mapping) > 2 else 1

    if self.n_iter_no_change:
        self.check_convergence_func = converge_func_factory("diff", self.tol)

    # sync start round and end round
    self.sync_start_round_and_end_round()

    LOGGER.info('begin to fit a boosting tree')
    self.preprocess()
    for epoch_idx in range(self.start_round, self.boosting_round):
        LOGGER.info('cur epoch idx is {}'.format(epoch_idx))

        for class_idx in range(self.booster_dim):
            model = self.fit_a_learner(epoch_idx, class_idx)

        global_loss = self.aggregator.aggregate_loss(suffix=(epoch_idx,))
        self.history_loss.append(global_loss)
        LOGGER.debug('cur epoch global loss is {}'.format(global_loss))

        self.callback_metric("loss", "train", [Metric(epoch_idx, global_loss)])

        if self.n_iter_no_change:
            should_stop = self.aggregator.broadcast_converge_status(
                self.check_convergence, (global_loss,), suffix=(epoch_idx,))
            LOGGER.debug('stop flag sent')
            if should_stop:
                break

    self.callback_meta("loss", "train",
                       MetricMeta(name="train",
                                  metric_type="LOSS",
                                  extra_metas={"Best": min(self.history_loss)}))
    self.postprocess()
    self.callback_list.on_train_end()
    self.set_summary(self.generate_summary())
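# The n_iter_no_change branch above obtains a "diff"-style convergence checker
# from converge_func_factory and broadcasts its verdict to all parties each
# epoch. A minimal standalone sketch of that criterion (hypothetical
# DiffConverge name; not FATE's actual class) is:


class DiffConverge:
    """Converged once the absolute change in loss drops below tol."""

    def __init__(self, tol=1e-4):
        self.tol = tol
        self.pre_loss = None

    def is_converge(self, loss):
        if self.pre_loss is None:
            # first round: nothing to compare against yet
            self.pre_loss = loss
            return False
        converged = abs(self.pre_loss - loss) < self.tol
        self.pre_loss = loss
        return converged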
def record_step_best(self, step_best, host_mask, guest_mask, data_instances, model):
    metas = {"host_mask": host_mask.tolist(),
             "guest_mask": guest_mask.tolist(),
             "score_name": self.score_name}
    metas["number_in"] = int(sum(host_mask) + sum(guest_mask))
    metas["direction"] = self.direction
    metas["n_count"] = int(self.n_count)

    host_anonym = [
        anonymous_generator.generate_anonymous(fid=i, role='host', model=model)
        for i in range(len(host_mask))]
    guest_anonym = [
        anonymous_generator.generate_anonymous(fid=i, role='guest', model=model)
        for i in range(len(guest_mask))]
    metas["host_features_anonym"] = host_anonym
    metas["guest_features_anonym"] = guest_anonym

    model_info = self.models_trained[step_best]
    loss = model_info.get_loss()
    ic_val = model_info.get_score()
    metas["loss"] = loss
    metas["current_ic_val"] = ic_val
    metas["fit_intercept"] = model.fit_intercept

    model_key = model_info.get_key()
    model_dict = self._get_model(model_key)

    if self.role != consts.ARBITER:
        all_features = data_instances.schema.get('header')
        metas["all_features"] = all_features
        metas["to_enter"] = self.get_to_enter(host_mask, guest_mask, all_features)
        model_param = list(model_dict.get('model').values())[0].get(
            model.model_param_name)
        param_dict = MessageToDict(model_param)
        metas["intercept"] = param_dict.get("intercept", None)
        metas["weight"] = param_dict.get("weight", {})
        metas["header"] = param_dict.get("header", [])
        if self.n_step == 0 and self.direction == "forward":
            metas["intercept"] = self.intercept
        self.update_summary_client(model, host_mask, guest_mask, all_features,
                                   host_anonym, guest_anonym)
    else:
        self.update_summary_arbiter(model, loss, ic_val)

    metric_name = f"stepwise_{self.n_step}"
    metric = [Metric(metric_name, float(self.n_step))]
    model.callback_metric(metric_name=metric_name,
                          metric_namespace=self.metric_namespace,
                          metric_data=metric)
    model.tracker.set_metric_meta(metric_name=metric_name,
                                  metric_namespace=self.metric_namespace,
                                  metric_meta=MetricMeta(name=metric_name,
                                                         metric_type=self.metric_type,
                                                         extra_metas=metas))
    LOGGER.info(f"metric_name: {metric_name}, metas: {metas}")
    return
def callback(self, metas):
    metric = [Metric(self.metric_name, 0)]
    self.callback_metric(metric_name=self.metric_name,
                         metric_namespace=self.metric_namespace,
                         metric_data=metric)
    self.tracker.set_metric_meta(metric_name=self.metric_name,
                                 metric_namespace=self.metric_namespace,
                                 metric_meta=MetricMeta(name=self.metric_name,
                                                        metric_type=self.metric_type,
                                                        extra_metas=metas))
def __save_curve_data(self, x_axis_list, y_axis_list, metric_name, metric_namespace):
    points = []
    for i, value in enumerate(x_axis_list):
        if isinstance(value, float):
            value = np.round(value, self.round_num)
        points.append((value, np.round(y_axis_list[i], self.round_num)))
    points.sort(key=lambda x: x[0])

    metric_points = [Metric(point[0], point[1]) for point in points]
    self.tracker.log_metric_data(metric_namespace, metric_name, metric_points)
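# __save_curve_data rounds the x values (floats only) and all y values, then
# sorts by x before handing the points to the tracker. A minimal
# self-contained sketch of that preparation step, assuming round_num=6:

import numpy as np


def prepare_curve_points(x_axis_list, y_axis_list, round_num=6):
    points = []
    for x, y in zip(x_axis_list, y_axis_list):
        if isinstance(x, float):
            x = np.round(x, round_num)
        points.append((x, np.round(y, round_num)))
    points.sort(key=lambda p: p[0])  # curve points must be ordered by x
    return points


# e.g. prepare_curve_points([0.3333333333, 0.1], [1.0, 2.0])
# -> [(0.1, 2.0), (0.333333, 1.0)]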
def callback_loss(self, iter_num, loss):
    metric_meta = MetricMeta(name='train',
                             metric_type="LOSS",
                             extra_metas={
                                 "unit_name": "iters",
                             })
    self.callback_meta(metric_name='loss',
                       metric_namespace='train',
                       metric_meta=metric_meta)
    self.callback_metric(metric_name='loss',
                         metric_namespace='train',
                         metric_data=[Metric(iter_num, loss)])
def callback_info(self):
    metric_meta = MetricMeta(name='train',
                             metric_type=self.metric_type,
                             extra_metas={"label_encoder": self.label_encoder})
    self.callback_metric(metric_name=self.metric_name,
                         metric_namespace=self.metric_namespace,
                         metric_data=[Metric(self.metric_name, 0)])
    self.tracker.set_metric_meta(metric_namespace=self.metric_namespace,
                                 metric_name=self.metric_name,
                                 metric_meta=metric_meta)
def callback_dbi(self, iter_num, dbi):
    metric_meta = MetricMeta(name='train',
                             metric_type="DBI",
                             extra_metas={
                                 "unit_name": "iters",
                             })
    self.callback_meta(metric_name='DBI',
                       metric_namespace='train',
                       metric_meta=metric_meta)
    self.callback_metric(metric_name='DBI',
                         metric_namespace='train',
                         metric_data=[Metric(iter_num, dbi)])
def __save_single_value(self, result, metric_name, metric_namespace, eval_name):
    metric_type = 'EVALUATION_SUMMARY'
    if eval_name in consts.ALL_CLUSTER_METRICS:
        metric_type = 'CLUSTERING_EVALUATION_SUMMARY'
    self.tracker.log_metric_data(
        metric_namespace, metric_name,
        [Metric(eval_name, np.round(result, self.round_num))])
    self.tracker.set_metric_meta(
        metric_namespace, metric_name,
        MetricMeta(name=metric_name, metric_type=metric_type))
def callback_loss(self, iter_num, loss):
    # noinspection PyTypeChecker
    metric_meta = MetricMeta(
        name="train",
        metric_type="LOSS",
        extra_metas={
            "unit_name": "iters",
        },
    )
    self.callback_meta(
        metric_name="loss", metric_namespace="train", metric_meta=metric_meta
    )
    self.callback_metric(
        metric_name="loss",
        metric_namespace="train",
        metric_data=[Metric(iter_num, loss)],
    )
def server_callback_loss(self, iter_num, loss):
    # noinspection PyTypeChecker
    metric_meta = MetricMeta(
        name="train",
        metric_type="LOSS",
        extra_metas={
            "unit_name": "iters",
        },
    )
    self.callback_meta(metric_name="loss",
                       metric_namespace="train",
                       metric_meta=metric_meta)
    self.callback_metric(
        metric_name="loss",
        metric_namespace="train",
        metric_data=[Metric(iter_num, loss)],
    )
    self._summary["loss_history"].append(loss)
def __sample(self, data_inst, sample_ids=None): """ Random sample method, a line's occur probability is decide by fraction support down sample and up sample if use down sample: should give a float ratio between [0, 1] otherwise: should give a float ratio larger than 1.0 Parameters ---------- data_inst : Table The input data sample_ids : None or list if None, will sample data from the class instance's parameters, otherwise, it will be sample transform process, which means use the samples_ids to generate data Returns ------- new_data_inst: Table the output sample data, same format with input sample_ids: list, return only if sample_ids is None """ LOGGER.info("start to run random sampling") return_sample_ids = False if self.method == "downsample": if sample_ids is None: return_sample_ids = True idset = [ key for key, value in data_inst.mapValues( lambda val: None).collect() ] if self.fraction < 0 or self.fraction > 1: raise ValueError( "sapmle fractions should be a numeric number between 0 and 1inclusive" ) sample_num = max(1, int(self.fraction * len(idset))) sample_ids = resample(idset, replace=False, n_samples=sample_num, random_state=self.random_state) sample_dtable = session.parallelize(zip(sample_ids, range(len(sample_ids))), include_key=True, partition=data_inst.partitions) new_data_inst = data_inst.join(sample_dtable, lambda v1, v2: v1) callback(self.tracker, "random", [Metric("count", new_data_inst.count())], summary_dict=self._summary_buf) if return_sample_ids: return new_data_inst, sample_ids else: return new_data_inst elif self.method == "upsample": data_set = list(data_inst.collect()) idset = [key for (key, value) in data_set] id_maps = dict(zip(idset, range(len(idset)))) if sample_ids is None: return_sample_ids = True if self.fraction <= 0: raise ValueError( "sapmle fractions should be a numeric number large than 0" ) sample_num = int(self.fraction * len(idset)) sample_ids = resample(idset, replace=True, n_samples=sample_num, random_state=self.random_state) new_data = [] for i in range(len(sample_ids)): index = id_maps[sample_ids[i]] new_data.append((i, data_set[index][1])) new_data_inst = session.parallelize(new_data, include_key=True, partition=data_inst.partitions) callback(self.tracker, "random", [Metric("count", new_data_inst.count())], summary_dict=self._summary_buf) if return_sample_ids: return new_data_inst, sample_ids else: return new_data_inst else: raise ValueError("random sampler not support method {} yet".format( self.method))
def __sample(self, data_inst, sample_ids=None): """ Stratified sample method, a line's occur probability is decide by fractions Input should be Table, every line should be an instance object with label To use this method, a list of ratio should be give, and the list length equals to the number of distinct labels support down sample and up sample if use down sample: should give a list of (category, ratio), where ratio is between [0, 1] otherwise: should give a list (category, ratio), where the float ratio should no less than 1.0 Parameters ---------- data_inst : Table The input data sample_ids : None or list if None, will sample data from the class instance's parameters, otherwise, it will be sample transform process, which means use the samples_ids the generate data Returns ------- new_data_inst: Table the output sample data, sample format with input sample_ids: list, return only if sample_ids is None """ LOGGER.info("start to run stratified sampling") return_sample_ids = False if self.method == "downsample": if sample_ids is None: idset = [[] for i in range(len(self.fractions))] for label, fraction in self.fractions: if fraction < 0 or fraction > 1: raise ValueError( "sapmle fractions should be a numeric number between 0 and 1inclusive" ) return_sample_ids = True for key, inst in data_inst.collect(): label = inst.label if label not in self.label_mapping: raise ValueError( "label not specify sample rate! check it please") idset[self.label_mapping[label]].append(key) sample_ids = [] callback_sample_metrics = [] callback_original_metrics = [] for i in range(len(idset)): label_name = self.labels[i] callback_original_metrics.append( Metric(label_name, len(idset[i]))) if idset[i]: sample_num = max( 1, int(self.fractions[i][1] * len(idset[i]))) _sample_ids = resample(idset[i], replace=False, n_samples=sample_num, random_state=self.random_state) sample_ids.extend(_sample_ids) callback_sample_metrics.append( Metric(label_name, len(_sample_ids))) else: callback_sample_metrics.append(Metric(label_name, 0)) random.shuffle(sample_ids) callback(self.tracker, "stratified", callback_sample_metrics, callback_original_metrics, self._summary_buf) sample_dtable = session.parallelize(zip(sample_ids, range(len(sample_ids))), include_key=True, partition=data_inst.partitions) new_data_inst = data_inst.join(sample_dtable, lambda v1, v2: v1) if return_sample_ids: return new_data_inst, sample_ids else: return new_data_inst elif self.method == "upsample": data_set = list(data_inst.collect()) ids = [key for (key, inst) in data_set] id_maps = dict(zip(ids, range(len(ids)))) return_sample_ids = False if sample_ids is None: idset = [[] for i in range(len(self.fractions))] for label, fraction in self.fractions: if fraction <= 0: raise ValueError( "sapmle fractions should be a numeric number greater than 0" ) for key, inst in data_set: label = inst.label if label not in self.label_mapping: raise ValueError( "label not specify sample rate! 
check it please") idset[self.label_mapping[label]].append(key) return_sample_ids = True sample_ids = [] callback_sample_metrics = [] callback_original_metrics = [] for i in range(len(idset)): label_name = self.labels[i] callback_original_metrics.append( Metric(label_name, len(idset[i]))) if idset[i]: sample_num = max( 1, int(self.fractions[i][1] * len(idset[i]))) _sample_ids = resample(idset[i], replace=True, n_samples=sample_num, random_state=self.random_state) sample_ids.extend(_sample_ids) callback_sample_metrics.append( Metric(label_name, len(_sample_ids))) else: callback_sample_metrics.append(Metric(label_name, 0)) random.shuffle(sample_ids) callback(self.tracker, "stratified", callback_sample_metrics, callback_original_metrics, self._summary_buf) new_data = [] for i in range(len(sample_ids)): index = id_maps[sample_ids[i]] new_data.append((i, data_set[index][1])) new_data_inst = session.parallelize(new_data, include_key=True, partition=data_inst.partitions) if return_sample_ids: return new_data_inst, sample_ids else: return new_data_inst else: raise ValueError( "Stratified sampler not support method {} yet".format( self.method))
def fit(self, data):
    # LOGGER.debug(f"fit receives data is {data}")
    if not isinstance(data, dict) or len(data) <= 1:
        raise ValueError("Union module must receive more than one table as input.")
    empty_count = 0
    combined_table = None
    combined_schema = None
    metrics = []

    for (key, local_table) in data.items():
        LOGGER.debug("table to combine name: {}".format(key))
        num_data = local_table.count()
        LOGGER.debug("table count: {}".format(num_data))
        metrics.append(Metric(key, num_data))
        self.add_summary(key, num_data)

        if num_data == 0:
            LOGGER.warning("Table {} is empty.".format(key))
            if combined_table is None:
                combined_table = local_table
                combined_schema = local_table.schema
            empty_count += 1
            continue

        local_is_data_instance = self.check_is_data_instance(local_table)
        if self.is_data_instance is None or combined_table is None:
            self.is_data_instance = local_is_data_instance
        LOGGER.debug(f"self.is_data_instance is {self.is_data_instance}, "
                     f"local_is_data_instance is {local_is_data_instance}")
        if self.is_data_instance != local_is_data_instance:
            raise ValueError(
                "Cannot combine DataInstance and non-DataInstance object. Union aborted.")

        if self.is_data_instance:
            self.is_empty_feature = data_overview.is_empty_feature(local_table)
            if self.is_empty_feature:
                LOGGER.warning("Table {} has empty feature.".format(key))
            else:
                self.check_schema_content(local_table.schema)

        if combined_table is None or combined_table.count() == 0:
            # first non-empty table to combine
            combined_table = local_table
            combined_schema = local_table.schema
            if self.keep_duplicate:
                combined_table = combined_table.map(lambda k, v: (f"{k}_{key}", v))
                combined_table.schema = combined_schema
        else:
            self.check_id(local_table, combined_table)
            self.check_label_name(local_table, combined_table)
            self.check_header(local_table, combined_table)
            if self.keep_duplicate:
                local_table = local_table.map(lambda k, v: (f"{k}_{key}", v))

            combined_table = combined_table.union(local_table, self._keep_first)
            combined_table.schema = combined_schema

        # only check feature length if not empty
        if self.is_data_instance and not self.is_empty_feature:
            self.feature_count = len(combined_schema.get("header"))
            # LOGGER.debug(f"feature count: {self.feature_count}")
            combined_table.mapValues(self.check_feature_length)

    if combined_table is None:
        LOGGER.warning("All tables provided are empty or have empty features.")
        first_table = list(data.values())[0]
        combined_table = first_table.join(first_table)

    num_data = combined_table.count()
    metrics.append(Metric("Total", num_data))
    self.add_summary("Total", num_data)
    LOGGER.info(f"Result total data entry count: {num_data}")

    self.callback_metric(metric_name=self.metric_name,
                         metric_namespace=self.metric_namespace,
                         metric_data=metrics)
    self.tracker.set_metric_meta(metric_namespace=self.metric_namespace,
                                 metric_name=self.metric_name,
                                 metric_meta=MetricMeta(name=self.metric_name,
                                                        metric_type=self.metric_type))

    LOGGER.info("Union operation finished. Total {} empty tables encountered.".format(empty_count))
    return combined_table
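# When keep_duplicate is set, fit() suffixes every key with its source-table
# name so that colliding ids survive the union instead of being deduplicated.
# A minimal sketch of that renaming, assuming plain dicts in place of Tables:

tables = {"t1": {"id1": 1, "id2": 2}, "t2": {"id1": 10}}

combined = {}
for name, table in tables.items():
    for k, v in table.items():
        # f"{k}_{name}" mirrors the lambda k, v: (f"{k}_{key}", v) map above
        combined[f"{k}_{name}"] = v

# combined == {"id1_t1": 1, "id2_t1": 2, "id1_t2": 10}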
def fit(self, data_inst, validate_data=None):

    self.callback_list.on_train_begin(data_inst, validate_data)

    # collect data from table to form data loader
    if not self.component_properties.is_warm_start:
        self._build_model()
        cur_epoch = 0
    else:
        self.model.warm_start()
        self.callback_warm_start_init_iter(self.history_iter_epoch)
        cur_epoch = self.history_iter_epoch + 1

    self.prepare_batch_data(self.batch_generator, data_inst)
    if not self.input_shape:
        self.model.set_empty()

    self._set_loss_callback_info()
    while cur_epoch < self.epochs:
        self.iter_epoch = cur_epoch
        LOGGER.debug("cur epoch is {}".format(cur_epoch))
        self.callback_list.on_epoch_begin(cur_epoch)
        epoch_loss = 0

        for batch_idx in range(len(self.data_x)):
            # hetero NN model
            batch_loss = self.model.train(self.data_x[batch_idx],
                                          self.data_y[batch_idx],
                                          cur_epoch, batch_idx)
            epoch_loss += batch_loss

        epoch_loss /= len(self.data_x)

        LOGGER.debug("epoch {} loss is {}".format(cur_epoch, epoch_loss))

        self.callback_metric("loss", "train", [Metric(cur_epoch, epoch_loss)])
        self.history_loss.append(epoch_loss)

        self.callback_list.on_epoch_end(cur_epoch)
        if self.callback_variables.stop_training:
            LOGGER.debug('early stopping triggered')
            break

        if self.hetero_nn_param.selector_param.method:
            # when using selective bp, the loss convergence check is disabled
            is_converge = False
        else:
            is_converge = self.converge_func.is_converge(epoch_loss)

        self._summary_buf["is_converged"] = is_converge
        self.transfer_variable.is_converge.remote(is_converge,
                                                  role=consts.HOST,
                                                  idx=0,
                                                  suffix=(cur_epoch,))

        if is_converge:
            LOGGER.debug("Training process is converged in epoch {}".format(cur_epoch))
            break

        cur_epoch += 1

    if cur_epoch == self.epochs:
        LOGGER.debug("Training process reached max training epochs {} and did not converge".format(self.epochs))

    self.callback_list.on_train_end()
    # if self.validation_strategy and self.validation_strategy.has_saved_best_model():
    #     self.load_model(self.validation_strategy.cur_best_model)

    self.set_summary(self._get_model_summary())
def fit(self, data_inst, validate_data=None):

    LOGGER.debug('in training, partitions is {}'.format(data_inst.partitions))
    LOGGER.info('start to fit a ftl model, '
                'run mode is {}, '
                'communication efficient mode is {}'.format(self.mode, self.comm_eff))

    self.check_host_number()

    data_loader, self.x_shape, self.data_num, self.overlap_num = self.prepare_data(
        self.init_intersect_obj(), data_inst, guest_side=True)
    self.input_dim = self.x_shape[0]

    # cache data_loader for faster validation
    self.cache_dataloader[self.get_dataset_key(data_inst)] = data_loader

    self.partitions = data_inst.partitions
    LOGGER.debug('self partitions is {}'.format(self.partitions))

    self.initialize_nn(input_shape=self.x_shape)
    self.feat_dim = self.nn._model.output_shape[1]
    self.constant_k = 1 / self.feat_dim
    self.callback_list.on_train_begin(train_data=data_inst, validate_data=validate_data)

    self.callback_meta("loss", "train",
                       MetricMeta(name="train",
                                  metric_type="LOSS",
                                  extra_metas={"unit_name": "iters"}))

    # compute intermediate results for the first epoch
    self.phi, self.phi_product, self.overlap_ua, self.send_components = self.batch_compute_components(
        data_loader)

    for epoch_idx in range(self.epochs):

        LOGGER.debug('fitting epoch {}'.format(epoch_idx))
        self.callback_list.on_epoch_begin(epoch_idx)

        host_components = self.exchange_components(self.send_components, epoch_idx=epoch_idx)

        loss = None

        for local_round_idx in range(self.local_round):

            if self.comm_eff:
                LOGGER.debug('running local iter {}'.format(local_round_idx))

            grads = self.compute_backward_gradients(host_components,
                                                    data_loader,
                                                    epoch_idx=epoch_idx,
                                                    local_round=local_round_idx)
            self.update_nn_weights(grads, data_loader, epoch_idx, decay=self.comm_eff)

            if local_round_idx == 0:
                loss = self.compute_loss(host_components, epoch_idx,
                                         len(data_loader.get_overlap_indexes()))

            if local_round_idx + 1 != self.local_round:
                self.phi, self.overlap_ua = self.compute_phi_and_overlap_ua(data_loader)

        self.callback_metric("loss", "train", [Metric(epoch_idx, loss)])
        self.history_loss.append(loss)

        # update variables for the next epoch
        if epoch_idx + 1 == self.epochs:
            # only need to update phi in the last epoch
            self.phi, _ = self.compute_phi_and_overlap_ua(data_loader)
        else:
            # compute phi, phi_product, overlap_ua etc. for the next epoch
            self.phi, self.phi_product, self.overlap_ua, self.send_components = self.batch_compute_components(
                data_loader)

        self.callback_list.on_epoch_end(epoch_idx)

        # check n_iter_no_change
        if self.n_iter_no_change is True:
            if self.check_convergence(loss):
                self.sync_stop_flag(epoch_idx, stop_flag=True)
                break
            else:
                self.sync_stop_flag(epoch_idx, stop_flag=False)

        LOGGER.debug('fitting epoch {} done, loss is {}'.format(epoch_idx, loss))

    self.callback_list.on_train_end()
    self.callback_meta("loss", "train",
                       MetricMeta(name="train",
                                  metric_type="LOSS",
                                  extra_metas={"Best": min(self.history_loss)}))

    self.set_summary(self.generate_summary())
    LOGGER.debug('fitting ftl model done')
def fit(self, data_inst, validate_data=None):

    LOGGER.info('begin to fit a hetero boosting model, model is {}'.format(self.model_name))

    self.start_round = 0
    self.on_training = True
    self.data_inst = data_inst
    self.data_bin, self.bin_split_points, self.bin_sparse_points = self.prepare_data(data_inst)
    self.y = self.get_label(self.data_bin)

    if not self.is_warm_start:
        self.feature_name_fid_mapping = self.gen_feature_fid_mapping(data_inst.schema)
        self.classes_, self.num_classes, self.booster_dim = self.check_label()
        self.loss = self.get_loss_function()
        self.y_hat, self.init_score = self.get_init_score(self.y, self.num_classes)
    else:
        classes_, num_classes, booster_dim = self.check_label()
        self.prepare_warm_start(data_inst, classes_)

    LOGGER.info('class index is {}'.format(self.classes_))

    self.sync_booster_dim()
    self.generate_encrypter()

    self.callback_list.on_train_begin(data_inst, validate_data)
    self.callback_meta("loss", "train",
                       MetricMeta(name="train",
                                  metric_type="LOSS",
                                  extra_metas={"unit_name": "iters"}))

    self.preprocess()

    for epoch_idx in range(self.start_round, self.boosting_round):

        LOGGER.info('cur epoch idx is {}'.format(epoch_idx))

        self.callback_list.on_epoch_begin(epoch_idx)

        for class_idx in range(self.booster_dim):

            # fit a booster
            model = self.fit_a_learner(epoch_idx, class_idx)

            booster_meta, booster_param = model.get_model()
            if booster_meta is not None and booster_param is not None:
                self.booster_meta = booster_meta
                self.boosting_model_list.append(booster_param)

            # update predict score
            cur_sample_weights = model.get_sample_weights()
            self.y_hat = self.get_new_predict_score(self.y_hat, cur_sample_weights, dim=class_idx)

        # compute loss
        loss = self.compute_loss(self.y_hat, self.y)
        self.history_loss.append(loss)
        LOGGER.info("round {} loss is {}".format(epoch_idx, loss))
        self.callback_metric("loss", "train", [Metric(epoch_idx, loss)])

        # check validation
        validation_strategy = self.callback_list.get_validation_strategy()
        if validation_strategy:
            validation_strategy.set_precomputed_train_scores(
                self.score_to_predict_result(data_inst, self.y_hat))

        self.callback_list.on_epoch_end(epoch_idx)

        should_stop = False
        if self.n_iter_no_change and self.check_convergence(loss):
            should_stop = True
            self.is_converged = True
        self.sync_stop_flag(self.is_converged, epoch_idx)
        if self.stop_training or should_stop:
            break

    self.postprocess()
    self.callback_list.on_train_end()
    self.callback_meta("loss", "train",
                       MetricMeta(name="train",
                                  metric_type="LOSS",
                                  extra_metas={"Best": min(self.history_loss)}))
    # get summary
    self.set_summary(self.generate_summary())
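# get_new_predict_score above carries out the additive boosting update: each
# round's learner output is folded into the running score for its class
# dimension. A generic sketch of that update over plain lists (the
# learning-rate scaling is an assumption, standard for gradient boosting but
# not shown in this snippet; FATE's version operates on Tables):


def get_new_predict_score_sketch(y_hat, learner_scores, learning_rate=0.3):
    # y_hat and learner_scores are parallel lists of per-sample scores
    return [old + learning_rate * new for old, new in zip(y_hat, learner_scores)]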