def _build_for_predict(self, ds):
    """Build the inference program for ``ds`` inside a fresh Program pair.

    The dataset is tagged as ``'predict'``. ``self.model_fn`` is traced in
    ``RunMode.PREDICT`` inside a new main/startup ``F.Program`` pair. The
    resulting graph is then post-processed:

    * every ``dropout`` / ``batch_norm`` op gets ``is_test=True``;
    * the graph is rejected if it contains an optimizer op.

    Args:
        ds: dataset exposing ``features()`` and ``data_schema``. Its
            ``name`` attribute is overwritten with ``'predict'``.

    Returns:
        tuple: ``(ProgramPair, model_spec)``. The predict program is stored
        in the ``train_program`` slot of the ``ProgramPair``, and the startup
        program in ``startup_program``.

    Raises:
        RuntimeError: if an ``sgd``, ``adam`` or ``adagrad`` op is found in
            the predict graph.
    """
    ds.name = 'predict'
    program = F.Program()
    startup_prog = F.Program()
    with F.program_guard(program, startup_prog):
        #share var with Train net
        log.info('Building Predict Graph')
        fea = ds.features()
        fea = unflatten(fea, ds.data_schema)
        model_spec = _build_net(self.model_fn, fea, RunMode.PREDICT,
                                self.params, self.run_config)
        log.info('Done')

    # The program is not cloned with ``for_test=True``, so ops with a
    # train/inference switch are flipped to inference mode by hand.
    test_mode_ops = {'dropout', 'batch_norm'}
    # Parameter-update ops must never run during prediction.
    optimizer_ops = {'sgd', 'adam', 'adagrad'}
    for op in program.global_block().ops:
        if op.type in test_mode_ops:
            op._set_attr('is_test', True)
        if op.type in optimizer_ops:
            raise RuntimeError('Found optimizer op in predict graph, op: %s' %
                               repr(op))

    log.info(
        'Predict with: \n> Run_config: %s\n> Params: %s\n> Train_model_spec: %s\n'
        % (repr(self.run_config), repr(self.params), repr(model_spec)))
    return ProgramPair(
        train_program=program, startup_program=startup_prog), model_spec
def _build_for_train(self, train_dataset):
    """Build the training program for ``train_dataset`` inside a fresh Program pair.

    The dataset is tagged as ``'train'``. ``self.model_fn`` is traced in
    ``RunMode.TRAIN`` under a new main/startup ``F.Program`` pair, with a
    ``collection.Collections`` context active. That context gathers the
    scalar and histogram summaries registered while the net is built.

    Args:
        train_dataset: dataset exposing ``features()`` and ``data_schema``.
            Its ``name`` attribute is overwritten with ``'train'``.

    Returns:
        tuple: ``(ProgramPair, model_spec, summary_record)``.
        ``summary_record`` is a ``SummaryRecord`` built from the
        ``SUMMARY_SCALAR`` and ``SUMMARY_HISTOGRAM`` collections. Either
        field may be ``None`` if ``collections.get`` found nothing for that
        key.
    """
    train_dataset.name = 'train'
    train_program = F.Program()
    startup_prog = F.Program()
    with F.program_guard(train_program, startup_prog):
        with collection.Collections() as collections:
            log.info('Building Train Graph...')
            fea = train_dataset.features()
            fea = unflatten(fea, train_dataset.data_schema)
            model_spec = _build_net(self.model_fn, fea, RunMode.TRAIN,
                                    self.params, self.run_config)
            log.info('Building Train Graph: Done')
            # Read the summaries once, while the collection context is
            # still active; they are reused for the SummaryRecord below.
            scalars = collections.get(collection.Key.SUMMARY_SCALAR)
            histograms = collections.get(collection.Key.SUMMARY_HISTOGRAM)

    log.info(
        'Train with: \n> Run_config: %s\n> Params: %s\n> Train_model_spec: %s\n'
        % (repr(self.run_config), repr(self.params), repr(model_spec)))

    summary_record = SummaryRecord(scalar=scalars, histogram=histograms)
    return ProgramPair(
        train_program=train_program,
        startup_program=startup_prog), model_spec, summary_record