Example #1
0
from megskull.opr.all import (Conv2D, Pooling2D, FullyConnected, Softmax,
                              CrossEntropyLoss, Dropout, ElementwiseAffine)
from megskull.opr.helper.elemwise_trans import ReLU, Identity
from megskull.graph.query import GroupNode
from megskull.opr.netsrc import DataProvider
import megskull.opr.helper.param_init as pinit
from megskull.opr.helper.param_init import AutoGaussianParamInitializer as G
from megskull.opr.helper.param_init import ConstantParamInitializer as C
from megskull.opr.regularizer import BatchNormalization as BN
import megskull.opr.arith as arith
from megskull.network import NetworkVisitor
import megskull.opr.all as O
import megbrain as mgb
from megskull.utils.logconf import get_logger

# Module-level logger, created via megskull's logconf helper and named after this module.
logger = get_logger(__name__)


class Shake_fprop(O.NonTrainableMLPOperatorNodeBase):
    """Operator node derived from ``O.NonTrainableMLPOperatorNodeBase``.

    Only the constructor is visible here: it converts the input to a varnode,
    stores an optional ``alpha`` on the instance and registers the node with
    the base class. NOTE(review): presumably the forward half of a
    shake-shake style operator, given the name -- TODO confirm.
    """

    # Operator attributes declared through the base-class helper, with impure=True.
    opr_attribute = O.NonTrainableMLPOperatorNodeBase.OprAttribute(impure=True)

    # Class-level default for the per-instance value assigned in __init__.
    _alpha = None

    def __init__(self, name, inpvar, *, alpha=None):
        """Create the node from a single input.

        :param name: node name; when ``None``, ``"ShakeFprop(<input name>)"``
            is built from the converted input's ``name`` attribute.
        :param inpvar: input, passed through ``O.as_varnode`` before use.
        :param alpha: optional value stored as ``self._alpha``; it is set
            before the base-class constructor runs.
        """
        src_var = O.as_varnode(inpvar)
        self._alpha = alpha
        node_name = (
            "ShakeFprop({})".format(src_var.name) if name is None else name
        )
        super().__init__(node_name, src_var)
Example #2
0
import megbrain as mgb
import megskull
from megskull.utils import logconf
from meghair.utils import io
from meghair.utils.misc import ensure_dir

from neupeak.train.utils import TrainClock
from neupeak.utils.inference import get_fprop_env, FunctionMaker
from neupeak.train.logger.tensorboard_logger import TensorBoardLogger
from neupeak.train.logger.worklog_logger import WorklogLogger, log_rate_limited
from neupeak.utils.fs import change_dir, make_symlink_if_not_exists
from neupeak.dataset.server import create_remote_combiner_dataset_auto_desc as create_remote_dataset
from meghair.train.interaction import parse_devices
from megskull.opr.all import Argmax as argmax

# Module-level logger, created via megskull's logconf helper and named after this module.
logger = logconf.get_logger(__name__)


def get_inf_iter_from_dataset(ds):
    """Return an iterator that cycles over ``ds``'s minibatches indefinitely.

    Each time one epoch's minibatches are exhausted, a fresh epoch is started
    by calling ``ds.get_epoch_minibatch_iter()`` again. The returned iterator
    is lazy: no epoch is started until the first item is requested.

    :param ds: object providing ``get_epoch_minibatch_iter()``, which returns
        an iterable over one epoch of minibatches.
    :return: a generator yielding minibatches across consecutive epochs. If an
        epoch produces no minibatches at all, the generator stops instead of
        looping forever without yielding.
    """
    def _cycle_epochs():
        while True:
            produced_any = False
            for minibatch in ds.get_epoch_minibatch_iter():
                produced_any = True
                yield minibatch
            if not produced_any:
                # An empty epoch would otherwise make this loop spin forever
                # without yielding, hanging the caller's next() call.
                return

    # A generator is already an iterator, so no iter() wrapper is needed.
    return _cycle_epochs()


class Session:
    def __init__(self, config, devices, net=None, train_func=None):
        setproctitle(config.exp_name)

        # log dirs