Ejemplo n.º 1
0
    def _build_for_predict(self, ds):
        """Build the prediction graph for dataset *ds*.

        Creates a fresh main/startup ``Program`` pair, runs ``self.model_fn``
        in ``RunMode.PREDICT`` under a program guard, then flips train-only
        ops (dropout, batch_norm) into test mode in place so the predict
        program keeps sharing variables with the train net.

        Args:
            ds: dataset object exposing ``features()`` and ``data_schema``;
                its ``name`` attribute is set to ``'predict'`` as a side
                effect.

        Returns:
            tuple: ``(ProgramPair, model_spec)``. NOTE: the predict program
            is stored in the ``train_program`` field of ``ProgramPair``
            because that is the pair's only main-program slot.

        Raises:
            RuntimeError: if an optimizer op is found in the graph — a sign
                that training logic leaked into predict mode.
        """
        ds.name = 'predict'
        program = F.Program()
        startup_prog = F.Program()
        with F.program_guard(program, startup_prog):
            # Share variables with the train net by building into a fresh
            # program under the same scope.
            log.info('Building Predict Graph')
            fea = ds.features()
            fea = unflatten(fea, ds.data_schema)
            model_spec = _build_net(self.model_fn, fea, RunMode.PREDICT,
                                    self.params, self.run_config)
            log.info('Done')

        # Patch test-time attrs manually rather than via
        # `program.clone(for_test=True)`: cloning would break variable
        # sharing with the train program.
        optimizer_ops = {'sgd', 'adam', 'adagrad'}
        for op in program.global_block().ops:
            if op.type == 'dropout':
                op._set_attr('is_test', True)
            if op.type == 'batch_norm':
                op._set_attr('is_test', True)
            if op.type in optimizer_ops:
                # Fixed message: this is the predict graph, not eval.
                raise RuntimeError(
                    'Found optimizer op in predict graph, op: %s' % repr(op))

        log.info(
            'Predict with: \n> Run_config: %s\n> Params: %s\n> Train_model_spec: %s\n'
            % (repr(self.run_config), repr(self.params), repr(model_spec)))
        return ProgramPair(train_program=program,
                           startup_program=startup_prog), model_spec
Ejemplo n.º 2
0
    def _build_for_train(self, train_dataset):
        """Build the training graph for *train_dataset*.

        Creates a fresh main/startup ``Program`` pair, runs
        ``self.model_fn`` in ``RunMode.TRAIN`` inside a
        ``collection.Collections`` context so summary tensors registered
        during model building can be collected afterwards.

        Args:
            train_dataset: dataset object exposing ``features()`` and
                ``data_schema``; its ``name`` attribute is set to
                ``'train'`` as a side effect.

        Returns:
            tuple: ``(ProgramPair, model_spec, SummaryRecord)`` where the
            ``SummaryRecord`` carries the scalar and histogram summary
            tensors registered while building the net.
        """
        train_dataset.name = 'train'
        train_program = F.Program()
        startup_prog = F.Program()
        with F.program_guard(train_program, startup_prog):
            with collection.Collections() as collections:
                log.info('Building Train Graph...')
                fea = train_dataset.features()
                fea = unflatten(fea, train_dataset.data_schema)
                model_spec = _build_net(self.model_fn, fea, RunMode.TRAIN,
                                        self.params, self.run_config)
                log.info('Building Train Graph: Done')

        # NOTE(review): a dead `skip_opt` set (SKIP_OPTIMIZE ops plus the
        # summary tensors) was computed here but never read; it has been
        # removed. Confirm no caller relied on a side effect of those
        # `collections.get(...)` calls (they appear to be pure lookups).
        log.info(
            'Train with: \n> Run_config: %s\n> Params: %s\n> Train_model_spec: %s\n'
            % (repr(self.run_config), repr(self.params), repr(model_spec)))

        # `collections` stays usable after its `with` block exits; fetch the
        # summary tensors registered during _build_net.
        summary_record = SummaryRecord(
            scalar=collections.get(collection.Key.SUMMARY_SCALAR),
            histogram=collections.get(collection.Key.SUMMARY_HISTOGRAM),
        )
        return ProgramPair(
            train_program=train_program,
            startup_program=startup_prog), model_spec, summary_record