import tensorflow as tf
from ray.tune.integration.keras import TuneReportCheckpointCallback


def train_mnist(config, checkpoint_dir=None):
    # Helpers get_num_workers(), mnist_dataset() and build_and_compile_cnn_model()
    # are defined elsewhere in this module.
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    per_worker_batch_size = 64
    num_workers = get_num_workers()
    global_batch_size = per_worker_batch_size * num_workers
    multi_worker_dataset = mnist_dataset(global_batch_size)

    with strategy.scope():
        multi_worker_model = build_and_compile_cnn_model(config)

    multi_worker_model.fit(
        multi_worker_dataset,
        epochs=2,
        steps_per_epoch=70,
        callbacks=[
            TuneReportCheckpointCallback({"mean_accuracy": "accuracy"})
        ])
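# A minimal launch sketch (not part of the original source): how train_mnist
# could be handed to Ray Tune's function API via tune.run. The search-space
# keys "lr" and "hidden" are hypothetical placeholders and must match whatever
# build_and_compile_cnn_model() actually reads from `config`.
def run_mnist_tuning_example(num_samples=4):
    import ray
    from ray import tune

    ray.init(ignore_reinit_error=True)
    analysis = tune.run(
        train_mnist,
        config={
            "lr": tune.loguniform(1e-4, 1e-1),      # hypothetical hyperparameter
            "hidden": tune.choice([32, 64, 128]),   # hypothetical hyperparameter
        },
        num_samples=num_samples,
    )
    # "mean_accuracy" is the metric reported by TuneReportCheckpointCallback above.
    return analysis.get_best_config(metric="mean_accuracy", mode="max")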
def train_model(self):
    '''
    This method overrides the one from Deep_CellSpike. It is almost identical,
    except that callbacks_list now contains a TuneReportCheckpointCallback()
    object. The TuneReportCheckpointCallback class handles reporting metrics
    and checkpoints; the metric it reports is the validation loss.
    '''
    if self.verb == 0:
        print('train silently, myRank=', self.myRank)

    hpar = self.hparams
    callbacks_list = [TuneReportCheckpointCallback(metrics=['val_loss'])]

    if self.useHorovod:
        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        callbacks_list.append(
            self.hvd.callbacks.BroadcastGlobalVariablesCallback(0))
        # Note: this callback must be in the list before the ReduceLROnPlateau,
        # TensorBoard or other metrics-based callbacks.
        callbacks_list.append(self.hvd.callbacks.MetricAverageCallback())

        # Scale the learning rate `lr = 1.0` ---> `lr = 1.0 * hvd.size()` during
        # the first five epochs. See https://arxiv.org/abs/1706.02677 for details.
        multiAgent = hpar['train_conf']['multiAgent']
        if multiAgent['warmup_epochs'] > 0:
            callbacks_list.append(
                self.hvd.callbacks.LearningRateWarmupCallback(
                    warmup_epochs=multiAgent['warmup_epochs'],
                    verbose=self.verb))
            if self.verb:
                print('added LearningRateWarmupCallback(%d epochs)' %
                      multiAgent['warmup_epochs'])

    lrCb = MyLearningTracker()
    callbacks_list.append(lrCb)

    trlsCb = None
    if self.train_loss_EOE:
        print('enable end-epoch true train-loss as callBack')
        assert self.cellName is None  # not tested
        genConf = copy.deepcopy(self.sumRec['inpGen']['train'])
        # we need much less statistics for the end-of-epoch loss
        genConf['numLocalSamples'] //= 8  # needs less data
        genConf['name'] = 'EOF' + genConf['name']
        inpGen = CellSpike_input_generator(genConf, verb=1)
        trlsCb = MyEpochEndLoss(inpGen)
        callbacks_list.append(trlsCb)

    if self.checkPtOn and self.myRank == 0:
        outF5w = self.outPath + '/' + self.prjName + '.weights_best.h5'
        chkPer = 1
        ckpt = ModelCheckpoint(outF5w, save_best_only=True,
                               save_weights_only=True, verbose=1,
                               period=chkPer, monitor='val_loss')
        callbacks_list.append(ckpt)
        if self.verb:
            print('enabled ModelCheckpoint, save_freq=', chkPer)

    LRconf = hpar['train_conf']['LRconf']
    if LRconf['patience'] > 0:
        pati, fact = LRconf['patience'], LRconf['reduceFactor']
        redu_lr = ReduceLROnPlateau(monitor='val_loss', factor=fact,
                                    patience=pati, min_lr=0.0,
                                    verbose=self.verb,
                                    min_delta=LRconf['min_delta'])
        callbacks_list.append(redu_lr)
        if self.verb:
            print('enabled ReduceLROnPlateau, patience=%d, factor=%.2f' %
                  (pati, fact))

    if self.earlyStopPatience > 0:
        earlyStop = EarlyStopping(monitor='val_loss',
                                  patience=self.earlyStopPatience,
                                  verbose=self.verb,
                                  min_delta=LRconf['min_delta'])
        callbacks_list.append(earlyStop)
        if self.verb:
            print('enabled EarlyStopping, patience=', self.earlyStopPatience)

    #pprint(hpar)
    if self.verb:
        print('\nTrain_model goalEpochs=%d' % (self.goalEpochs),
              ' modelDesign=', self.modelDesign,
              'localBS=', hpar['train_conf']['localBS'],
              'globBS=', hpar['train_conf']['localBS'] * self.numRanks)

    fitVerb = 1  # prints live: [=========>....] - ETA: xx s
    if self.verb == 2:
        fitVerb = 1
    if self.verb == 0:
        fitVerb = 2  # prints 1-line summary at epoch end
    if self.numRanks > 1:  # change the logic
        if self.verb:
            fitVerb = 2
        else:
            fitVerb = 0  # keras is silent

    startTm = time.time()
    hir = self.model.fit(  # TF-2.1
        self.inpGenD['train'],
        callbacks=callbacks_list,
        epochs=self.goalEpochs,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False,
        shuffle=self.shuffle_data,
        verbose=fitVerb,
        validation_data=self.inpGenD['val'])
    fitTime = time.time() - startTm

    hir = hir.history
    totEpochs = len(hir['loss'])
    earlyStopOccured = totEpochs < self.goalEpochs
    #print('hir keys',hir.keys(),lrCb.hir)
    for obs in hir:
        rec = [float(x) for x in hir[obs]]
        self.train_hirD[obs].extend(rec)

    # this is a hack: 'lr' is returned by fit only when --reduceLr is used
    if 'lr' not in hir:
        self.train_hirD['lr'].extend(lrCb.hir)

    if trlsCb:  # add train-loss for Kris
        nn = len(self.train_hirD['loss'])
        self.train_hirD['train_loss'] = trlsCb.hir[-nn:]

    self.train_hirD['stepTimeHist'] = self.inpGenD['train'].stepTime

    # report performance for the last epoch
    lossT = self.train_hirD['loss'][-1]
    lossV = self.train_hirD['val_loss'][-1]
    lossVbest = min(self.train_hirD['val_loss'])

    if self.verb:
        print('end-hpar:')
        pprint(hpar)
        print('\n End Val Loss=%s:%.3f, best:%.3f' %
              (hpar['train_conf']['lossName'], lossV, lossVbest),
              ', %d totEpochs, fit=%.1f min, earlyStop=%r' %
              (totEpochs, fitTime / 60., earlyStopOccured))
    self.train_sec = fitTime

    xxV = np.array(self.train_hirD['loss'][-5:])
    trainLoss_avr_last5 = -1
    if xxV.shape[0] > 3:
        trainLoss_avr_last5 = xxV.mean()

    # add info to summary
    rec = {}
    rec['earlyStopOccured'] = int(earlyStopOccured)
    rec['fitTime_min'] = fitTime / 60.
    rec['totEpochs'] = totEpochs
    rec['trainLoss_avr_last5'] = float(trainLoss_avr_last5)
    rec['train_loss'] = float(lossT)
    rec['val_loss'] = float(lossV)
    rec['steps_per_epoch'] = len(self.inpGenD['train'])
    rec['state'] = 'model_trained'
    rec['rank'] = self.myRank
    rec['num_ranks'] = self.numRanks
    rec['num_open_files'] = len(self.inpGenD['train'].conf['cellList'])
    self.sumRec.update(rec)
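# A sketch (not from the original source) of how this train_model() override
# could be driven from a Ray Tune trial. `CellSpikeTuneTrainer` stands in for
# the concrete subclass of Deep_CellSpike that defines train_model(), and its
# constructor and setup call are hypothetical; only the reporting path via
# TuneReportCheckpointCallback(metrics=['val_loss']) is taken from the code above.
def cellspike_trainable(config, checkpoint_dir=None):
    trainer = CellSpikeTuneTrainer(hparams=config)  # hypothetical constructor
    trainer.build_model()                           # hypothetical setup step
    trainer.train_model()  # reports 'val_loss' to Tune each epoch via the callback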
def build_model_and_train(config, checkpoint_dir=None, full_config=None,
                          ntrain=None, ntest=None, name=None, seeds=False):
    from ray import tune
    from ray.tune.integration.keras import TuneReportCheckpointCallback
    from raytune.search_space import set_raytune_search_parameters

    if seeds:
        # Set seeds for reproducibility
        random.seed(1234)
        np.random.seed(1234)
        tf.random.set_seed(1234)

    full_config, config_file_stem = parse_config(full_config)

    if config is not None:
        full_config = set_raytune_search_parameters(search_space=config,
                                                    config=full_config)

    strategy, num_gpus = get_strategy()

    ds_train, num_train_steps = get_datasets(
        full_config["train_test_datasets"], full_config, num_gpus, "train")
    ds_test, num_test_steps = get_datasets(
        full_config["train_test_datasets"], full_config, num_gpus, "test")
    ds_val, ds_info = get_heptfds_dataset(
        full_config["validation_datasets"][0],
        full_config,
        num_gpus,
        "test",
        full_config["setup"]["num_events_validation"],
        supervised=False,
    )
    ds_val = ds_val.batch(5)

    if ntrain:
        ds_train = ds_train.take(ntrain)
        num_train_steps = ntrain
    if ntest:
        ds_test = ds_test.take(ntest)
        num_test_steps = ntest

    print("num_train_steps", num_train_steps)
    print("num_test_steps", num_test_steps)
    total_steps = num_train_steps * full_config["setup"]["num_epochs"]
    print("total_steps", total_steps)

    callbacks = prepare_callbacks(
        full_config,
        tune.get_trial_dir(),
        ds_val,
    )
    callbacks = callbacks[:-1]  # remove the CustomCallback at the end of the list

    with strategy.scope():
        lr_schedule, optim_callbacks = get_lr_schedule(full_config, steps=total_steps)
        callbacks.append(optim_callbacks)
        opt = get_optimizer(full_config, lr_schedule)

        model = make_model(full_config, dtype=tf.dtypes.float32)

        # Run model once to build the layers
        model.build((1, full_config["dataset"]["padded_num_elem_size"],
                     full_config["dataset"]["num_input_features"]))

        full_config = set_config_loss(full_config, full_config["setup"]["trainable"])
        configure_model_weights(model, full_config["setup"]["trainable"])
        model.build((1, full_config["dataset"]["padded_num_elem_size"],
                     full_config["dataset"]["num_input_features"]))

        loss_dict, loss_weights = get_loss_dict(full_config)
        model.compile(
            loss=loss_dict,
            optimizer=opt,
            sample_weight_mode="temporal",
            loss_weights=loss_weights,
            metrics={
                "cls": [
                    FlattenedCategoricalAccuracy(name="acc_unweighted",
                                                 dtype=tf.float64),
                    FlattenedCategoricalAccuracy(use_weights=True,
                                                 name="acc_weighted",
                                                 dtype=tf.float64),
                ]
            },
        )
        model.summary()

        callbacks.append(
            TuneReportCheckpointCallback(
                metrics=[
                    "adam_beta_1",
                    "charge_loss",
                    "cls_acc_unweighted",
                    "cls_loss",
                    "cos_phi_loss",
                    "energy_loss",
                    "eta_loss",
                    "learning_rate",
                    "loss",
                    "pt_loss",
                    "sin_phi_loss",
                    "val_charge_loss",
                    "val_cls_acc_unweighted",
                    "val_cls_acc_weighted",
                    "val_cls_loss",
                    "val_cos_phi_loss",
                    "val_energy_loss",
                    "val_eta_loss",
                    "val_loss",
                    "val_pt_loss",
                    "val_sin_phi_loss",
                ],
            ),
        )

        try:
            model.fit(
                ds_train.repeat(),
                validation_data=ds_test.repeat(),
                epochs=full_config["setup"]["num_epochs"],
                callbacks=callbacks,
                steps_per_epoch=num_train_steps,
                validation_steps=num_test_steps,
            )
        except tf.errors.ResourceExhaustedError:
            logging.warning(
                "Resource exhausted, skipping this hyperparameter configuration.")
            skiplog_file_path = (Path(full_config["raytune"]["local_dir"])
                                 / name / "skipped_configurations.txt")
            lines = ["{}: {}\n".format(item[0], item[1]) for item in config.items()]
            with open(skiplog_file_path, "a") as f:
                f.write("#" * 80 + "\n")
                for line in lines:
                    f.write(line)
                    logging.warning(line[:-1])
                f.write("#" * 80 + "\n\n")
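# A minimal launch sketch (not from the original source) for the trainable
# above, assuming Ray Tune's tune.with_parameters helper. The search space,
# config file path, experiment name and resource numbers are illustrative.
def run_mlpf_tuning_example():
    import ray
    from ray import tune

    ray.init(ignore_reinit_error=True)
    search_space = {"lr": tune.loguniform(1e-5, 1e-2)}  # hypothetical search space
    tune.run(
        tune.with_parameters(
            build_model_and_train,
            full_config="parameters/config.yaml",  # hypothetical config file path
            name="mlpf_raytune",                   # hypothetical experiment name
            seeds=True,
        ),
        config=search_space,
        num_samples=2,
        resources_per_trial={"cpu": 8, "gpu": 1},  # illustrative resources
        local_dir="./ray_results",
        name="mlpf_raytune",
    )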