def trainModel(self, cfg):
    """Train self.model, assembling the callback list from *cfg* flags.

    Always installs the logger plus interval-checkpoint, LR-schedule,
    log-to-file and print callbacks.  `cfg.include_eval` additionally
    appends a test-set evaluation callback; `cfg.include_validation`
    appends early stopping and best-model checkpointing and trains with
    validation data, otherwise training runs without validation at
    verbose level 2.

    NOTE(review): this body was reconstructed from a source line that had
    been collapsed onto one physical line — the statement nesting (in
    particular that the validation-aware fit call sits inside the
    `include_validation` branch) follows the visible token order; confirm
    against the original file.
    """
    callbacks = [
        self.log,
        cb.MyModelCheckpointInterval(cfg),
        cb.MyLearningRateScheduler(cfg),
        cb.SaveLog2File(cfg),
        cb.PrintCallBack(),
    ]

    # Optional test-set evaluation during training.
    if cfg.include_eval:
        callbacks.append(cb.EvaluateTest(self.genTest, m.EvalResults, cfg))

    if cfg.include_validation:
        # Validation-driven callbacks only make sense when a validation
        # generator is consumed by fit_generator below.
        callbacks.append(cb.MyEarlyStopping(cfg))
        callbacks.append(cb.MyModelCheckpointBest(cfg))
        self.model.fit_generator(
            generator=self.genTrain.begin(),
            steps_per_epoch=self.genTrain.nb_batches,
            validation_data=self.genVal.begin(),
            validation_steps=self.genVal.nb_batches,
            epochs=cfg.epoch_end,
            initial_epoch=cfg.epoch_begin,
            callbacks=callbacks,
        )
    else:
        # No validation data: train-only loop, one log line per epoch.
        self.model.fit_generator(
            generator=self.genTrain.begin(),
            steps_per_epoch=self.genTrain.nb_batches,
            verbose=2,
            epochs=cfg.epoch_end,
            initial_epoch=cfg.epoch_begin,
            callbacks=callbacks,
        )
# data genTrain = DataGenerator(imagesMeta=data.trainGTMeta, cfg=cfg, data_type='train', do_meta=False) # models Models = methods.AllModels(cfg, mode='train', do_hoi=True) _, _, model_hoi = Models.get_models() sys.stdout.flush() #if False: # train callbacks = [callbacks.MyModelCheckpointInterval(cfg), \ callbacks.MyLearningRateScheduler(cfg), \ callbacks.MyModelCheckpointWeightsInterval(cfg),\ callbacks.SaveLog2File(cfg), \ callbacks.PrintCallBack()] if cfg.dataset == 'TUPPMI': model_hoi.fit_generator(generator = genTrain.begin(), \ steps_per_epoch = genTrain.nb_batches, \ verbose = 2,\ epochs = cfg.epoch_end, initial_epoch=cfg.epoch_begin, callbacks=callbacks) else: genTest = DataGenerator(imagesMeta=data.valGTMeta, cfg=cfg, data_type='test', do_meta=False,