def train_mnist(config, checkpoint_dir=None):
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    per_worker_batch_size = 64
    num_workers = get_num_workers()
    global_batch_size = per_worker_batch_size * num_workers
    multi_worker_dataset = mnist_dataset(global_batch_size)
    with strategy.scope():
        multi_worker_model = build_and_compile_cnn_model(config)
    multi_worker_model.fit(multi_worker_dataset,
                           epochs=2,
                           steps_per_epoch=70,
                           callbacks=[
                               TuneReportCheckpointCallback(
                                   {"mean_accuracy": "accuracy"})
                           ])
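# A minimal launch sketch (not part of the original example): assuming `train_mnist`
# is used as a Ray Tune function trainable with the legacy function API, a run could
# look roughly like this. The search-space key "hidden" and the sample count are
# illustrative placeholders, not values from the original code.
from ray import tune

analysis = tune.run(
    train_mnist,
    config={"hidden": tune.choice([32, 64, 128])},  # hypothetical hyperparameter
    metric="mean_accuracy",  # matches the key reported by TuneReportCheckpointCallback above
    mode="max",
    num_samples=4,
)
print("best config:", analysis.best_config)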
Example #2
    def train_model(self):
        '''
        ---- This method overrides the one from Deep_CellSpike. It is nearly identical,
        ---- except that callbacks_list now contains a TuneReportCheckpointCallback()
        ---- object. The TuneReportCheckpointCallback class handles reporting metrics
        ---- and saving checkpoints; the metric it reports is the validation loss.
        '''
        if self.verb == 0: print('train silently, myRank=', self.myRank)
        hpar = self.hparams
        callbacks_list = [TuneReportCheckpointCallback(metrics=['val_loss'])]

        if self.useHorovod:
            # Horovod: broadcast initial variable states from rank 0 to all other processes.
            # This is necessary to ensure consistent initialization of all workers when
            # training is started with random weights or restored from a checkpoint.
            callbacks_list.append(
                self.hvd.callbacks.BroadcastGlobalVariablesCallback(0))
            # Note: This callback must be in the list before the ReduceLROnPlateau,
            # TensorBoard or other metrics-based callbacks.
            callbacks_list.append(self.hvd.callbacks.MetricAverageCallback())
            # Scale the learning rate `lr = 1.0` ---> `lr = 1.0 * hvd.size()` gradually
            # during the first warmup_epochs epochs. See https://arxiv.org/abs/1706.02677 for details.

            multiAgent = hpar['train_conf']['multiAgent']
            if multiAgent['warmup_epochs'] > 0:
                callbacks_list.append(
                    self.hvd.callbacks.LearningRateWarmupCallback(
                        warmup_epochs=multiAgent['warmup_epochs'],
                        verbose=self.verb))
                if self.verb:
                    print('added LearningRateWarmupCallback(%d epochs)' %
                          multiAgent['warmup_epochs'])

        lrCb = MyLearningTracker()
        callbacks_list.append(lrCb)

        trlsCb = None
        if self.train_loss_EOE:
            print('enable end-of-epoch true train-loss callback')
            assert self.cellName is None  # not tested
            genConf = copy.deepcopy(self.sumRec['inpGen']['train'])
            # the end-of-epoch loss needs far fewer statistics, so use 1/8 of the samples
            genConf['numLocalSamples'] //= 8
            genConf['name'] = 'EOF' + genConf['name']

            inpGen = CellSpike_input_generator(genConf, verb=1)
            trlsCb = MyEpochEndLoss(inpGen)
            callbacks_list.append(trlsCb)

        if self.checkPtOn and self.myRank == 0:
            outF5w = self.outPath + '/' + self.prjName + '.weights_best.h5'
            chkPer = 1
            ckpt = ModelCheckpoint(outF5w,
                                   save_best_only=True,
                                   save_weights_only=True,
                                   verbose=1,
                                   period=chkPer,
                                   monitor='val_loss')
            callbacks_list.append(ckpt)
            if self.verb: print('enabled ModelCheckpoint, period=', chkPer)

        LRconf = hpar['train_conf']['LRconf']
        if LRconf['patience'] > 0:
            pati, fact = LRconf['patience'], LRconf['reduceFactor']
            redu_lr = ReduceLROnPlateau(monitor='val_loss',
                                        factor=fact,
                                        patience=pati,
                                        min_lr=0.0,
                                        verbose=self.verb,
                                        min_delta=LRconf['min_delta'])
            callbacks_list.append(redu_lr)
            if self.verb:
                print('enabled ReduceLROnPlateau, patience=%d, factor=%.2f' %
                      (pati, fact))

        if self.earlyStopPatience > 0:
            earlyStop = EarlyStopping(monitor='val_loss',
                                      patience=self.earlyStopPatience,
                                      verbose=self.verb,
                                      min_delta=LRconf['min_delta'])
            callbacks_list.append(earlyStop)
            if self.verb:
                print('enabled EarlyStopping, patience=',
                      self.earlyStopPatience)

        #pprint(hpar)
        if self.verb:
            print('\nTrain_model  goalEpochs=%d' % (self.goalEpochs),
                  '  modelDesign=', self.modelDesign, 'localBS=',
                  hpar['train_conf']['localBS'], 'globBS=',
                  hpar['train_conf']['localBS'] * self.numRanks)

        fitVerb = 1  # prints live progress bar:  [=========>....] - ETA: xx s
        if self.verb == 0: fitVerb = 2  # prints a 1-line summary at each epoch end

        if self.numRanks > 1:  # multi-rank run: reduce Keras verbosity
            if self.verb:
                fitVerb = 2
            else:
                fitVerb = 0  # keras is silent

        startTm = time.time()
        hir = self.model.fit(  # TF-2.1
            self.inpGenD['train'],
            callbacks=callbacks_list,
            epochs=self.goalEpochs,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False,
            shuffle=self.shuffle_data,
            verbose=fitVerb,
            validation_data=self.inpGenD['val'])
        fitTime = time.time() - startTm
        hir = hir.history
        totEpochs = len(hir['loss'])
        earlyStopOccured = totEpochs < self.goalEpochs

        #print('hir keys',hir.keys(),lrCb.hir)
        for obs in hir:
            rec = [float(x) for x in hir[obs]]
            self.train_hirD[obs].extend(rec)

        # this is a hack, 'lr' is returned by fit only when --reduceLr is used
        if 'lr' not in hir:
            self.train_hirD['lr'].extend(lrCb.hir)

        if trlsCb:  # add train-loss for Kris
            nn = len(self.train_hirD['loss'])
            self.train_hirD['train_loss'] = trlsCb.hir[-nn:]

        self.train_hirD['stepTimeHist'] = self.inpGenD['train'].stepTime

        #report performance for the last epoch
        lossT = self.train_hirD['loss'][-1]
        lossV = self.train_hirD['val_loss'][-1]
        lossVbest = min(self.train_hirD['val_loss'])

        if self.verb:
            print('end-hpar:')
            pprint(hpar)
            print(
                '\n End Val Loss=%s:%.3f, best:%.3f' %
                (hpar['train_conf']['lossName'], lossV, lossVbest),
                ', %d totEpochs, fit=%.1f min, earlyStop=%r' %
                (totEpochs, fitTime / 60., earlyStopOccured))
        self.train_sec = fitTime

        xxV = np.array(self.train_hirD['loss'][-5:])
        trainLoss_avr_last5 = -1
        if xxV.shape[0] > 3: trainLoss_avr_last5 = xxV.mean()

        # add info to summary
        rec = {}
        rec['earlyStopOccured'] = int(earlyStopOccured)
        rec['fitTime_min'] = fitTime / 60.
        rec['totEpochs'] = totEpochs
        rec['trainLoss_avr_last5'] = float(trainLoss_avr_last5)
        rec['train_loss'] = float(lossT)
        rec['val_loss'] = float(lossV)
        rec['steps_per_epoch'] = len(self.inpGenD['train'])
        rec['state'] = 'model_trained'
        rec['rank'] = self.myRank
        rec['num_ranks'] = self.numRanks
        rec['num_open_files'] = len(self.inpGenD['train'].conf['cellList'])
        self.sumRec.update(rec)
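# MyLearningTracker is not shown in this excerpt. A plausible minimal implementation,
# consistent with how `lrCb.hir` is consumed above (one learning-rate value appended
# per epoch), could look like the sketch below; treat it as illustrative only, not
# the author's actual class.
from tensorflow.keras.callbacks import Callback
from tensorflow.keras import backend as K

class MyLearningTracker(Callback):
    def __init__(self):
        super().__init__()
        self.hir = []  # one learning-rate value per finished epoch

    def on_epoch_end(self, epoch, logs=None):
        # record the optimizer's current learning rate
        self.hir.append(float(K.get_value(self.model.optimizer.lr)))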
Example #3
def build_model_and_train(config,
                          checkpoint_dir=None,
                          full_config=None,
                          ntrain=None,
                          ntest=None,
                          name=None,
                          seeds=False):
    from ray import tune
    from ray.tune.integration.keras import TuneReportCheckpointCallback
    from raytune.search_space import set_raytune_search_parameters

    if seeds:
        # Set seeds for reproducibility
        random.seed(1234)
        np.random.seed(1234)
        tf.random.set_seed(1234)

    full_config, config_file_stem = parse_config(full_config)

    if config is not None:
        full_config = set_raytune_search_parameters(search_space=config,
                                                    config=full_config)

    strategy, num_gpus = get_strategy()

    ds_train, num_train_steps = get_datasets(
        full_config["train_test_datasets"], full_config, num_gpus, "train")
    ds_test, num_test_steps = get_datasets(full_config["train_test_datasets"],
                                           full_config, num_gpus, "test")
    ds_val, ds_info = get_heptfds_dataset(
        full_config["validation_datasets"][0],
        full_config,
        num_gpus,
        "test",
        full_config["setup"]["num_events_validation"],
        supervised=False,
    )
    ds_val = ds_val.batch(5)

    if ntrain:
        ds_train = ds_train.take(ntrain)
        num_train_steps = ntrain
    if ntest:
        ds_test = ds_test.take(ntest)
        num_test_steps = ntest

    print("num_train_steps", num_train_steps)
    print("num_test_steps", num_test_steps)
    total_steps = num_train_steps * full_config["setup"]["num_epochs"]
    print("total_steps", total_steps)

    callbacks = prepare_callbacks(
        full_config,
        tune.get_trial_dir(),
        ds_val,
    )

    # remove the CustomCallback at the end of the list
    callbacks = callbacks[:-1]

    with strategy.scope():
        lr_schedule, optim_callbacks = get_lr_schedule(full_config,
                                                       steps=total_steps)
        callbacks.append(optim_callbacks)
        opt = get_optimizer(full_config, lr_schedule)

        model = make_model(full_config, dtype=tf.dtypes.float32)

        # Run model once to build the layers
        model.build((1, full_config["dataset"]["padded_num_elem_size"],
                     full_config["dataset"]["num_input_features"]))

        full_config = set_config_loss(full_config,
                                      full_config["setup"]["trainable"])
        configure_model_weights(model, full_config["setup"]["trainable"])
        model.build((1, full_config["dataset"]["padded_num_elem_size"],
                     full_config["dataset"]["num_input_features"]))

        loss_dict, loss_weights = get_loss_dict(full_config)
        model.compile(
            loss=loss_dict,
            optimizer=opt,
            sample_weight_mode="temporal",
            loss_weights=loss_weights,
            metrics={
                "cls": [
                    FlattenedCategoricalAccuracy(name="acc_unweighted",
                                                 dtype=tf.float64),
                    FlattenedCategoricalAccuracy(use_weights=True,
                                                 name="acc_weighted",
                                                 dtype=tf.float64),
                ]
            },
        )
        model.summary()

        callbacks.append(
            TuneReportCheckpointCallback(metrics=[
                "adam_beta_1",
                "charge_loss",
                "cls_acc_unweighted",
                "cls_loss",
                "cos_phi_loss",
                "energy_loss",
                "eta_loss",
                "learning_rate",
                "loss",
                "pt_loss",
                "sin_phi_loss",
                "val_charge_loss",
                "val_cls_acc_unweighted",
                "val_cls_acc_weighted",
                "val_cls_loss",
                "val_cos_phi_loss",
                "val_energy_loss",
                "val_eta_loss",
                "val_loss",
                "val_pt_loss",
                "val_sin_phi_loss",
            ]))

        try:
            model.fit(
                ds_train.repeat(),
                validation_data=ds_test.repeat(),
                epochs=full_config["setup"]["num_epochs"],
                callbacks=callbacks,
                steps_per_epoch=num_train_steps,
                validation_steps=num_test_steps,
            )
        except tf.errors.ResourceExhaustedError:
            logging.warning(
                "Resource exhausted, skipping this hyperparameter configuration."
            )
            skiplog_file_path = Path(full_config["raytune"]["local_dir"]
                                     ) / name / "skipped_configurations.txt"
            lines = [
                "{}: {}\n".format(item[0], item[1]) for item in config.items()
            ]

            with open(skiplog_file_path, "a") as f:
                f.write("#" * 80 + "\n")
                for line in lines:
                    f.write(line)
                    logging.warning(line[:-1])
                f.write("#" * 80 + "\n\n")