Exemplo n.º 1
0
 def parse_args(self, **kwargs):
     """Populate kernel/stride/pad and pooling options from ``kwargs``.

     Every option is looked up with ``cget(kwargs, key, type, fallback)``.
     ``kernel``, ``stride`` and ``pad`` are then passed through
     ``num_as_tuple``.  When ``pad`` is not supplied, its fallback is built
     from the first two kernel entries, each halved and truncated toward
     zero with ``int()``.
     """
     self.kernel = num_as_tuple(cget(kwargs, "kernel", tuple, (1, 1)))
     self.stride = num_as_tuple(cget(kwargs, "stride", tuple, (1, 1)))

     # "same"-style fallback padding; int(x / 2) truncates toward zero
     k_rows, k_cols = self.kernel[0], self.kernel[1]
     pad_fallback = (int(k_rows / 2), int(k_cols / 2))
     self.pad = num_as_tuple(cget(kwargs, "pad", tuple, pad_fallback))

     self.pool_type = cget(kwargs, "pool_type", str, 'max')
     self.global_pool = cget(kwargs, "global_pool", bool, False)
Exemplo n.º 2
0
 def get_weight_bias(self, **kwargs):
     """Create the weight symbol and, unless disabled, the bias symbol.

     ``weight_initializer``, ``bias_initializer`` (fallback ``None``) and
     ``no_bias`` (fallback ``False``) are read from ``kwargs`` via ``cget``.
     The symbols are created with ``mx.sym.var`` under the names
     ``self.name + '_weight'`` and ``self.name + '_bias'``.

     Returns:
         tuple: ``(self.weight, self.bias, self.no_bias)``.  ``self.bias``
         is ``None`` when ``no_bias`` is truthy.
     """
     def _as_is(value):
         # identity callable handed to cget in place of a type converter
         return value

     w_init = cget(kwargs, 'weight_initializer', _as_is, None)
     b_init = cget(kwargs, 'bias_initializer', _as_is, None)
     self.no_bias = cget(kwargs, 'no_bias', bool, False)

     self.weight = mx.sym.var(self.name + '_weight', init=w_init)
     if self.no_bias:
         self.bias = None
     else:
         self.bias = mx.sym.var(self.name + '_bias', init=b_init)

     return self.weight, self.bias, self.no_bias
Exemplo n.º 3
0
 def __init__(self, **kwargs):
     """Store the layout string and the index of each axis letter in it.

     ``layout`` is read via ``cget`` with fallback ``"NCHW"``.  Each
     ``*_id`` attribute is ``self.layout.find(letter)``, so it is ``-1``
     when that letter is absent (for example ``time_id`` with ``"NCHW"``).
     """
     self.layout = cget(kwargs, "layout", str, "NCHW")
     position_of = self.layout.find

     # NOTE(review): -1 signals "axis not present"; used directly as an
     # index it would select the last dimension, so callers should check.
     self.channel_id = position_of("C")
     self.height_id = position_of("H")
     self.width_id = position_of("W")
     self.time_id = position_of("T")
     self.batch_id = position_of("N")
Exemplo n.º 4
0
 def parse_args(self, **kwargs):
     """Populate kernel/stride/dilate/pad and grouping options from ``kwargs``.

     Values come from ``cget(kwargs, key, type, fallback)``.  ``kernel``,
     ``stride``, ``dilate`` and ``pad`` are passed through ``num_as_tuple``.
     The ``pad`` fallback halves the first two kernel entries and truncates
     toward zero with ``int()``.  ``num_group`` and ``num_filter`` both
     fall back to ``1``.
     """
     spatial_fallback = (1, 1)
     self.kernel = num_as_tuple(cget(kwargs, "kernel", tuple, spatial_fallback))
     self.stride = num_as_tuple(cget(kwargs, "stride", tuple, spatial_fallback))
     self.dilate = num_as_tuple(cget(kwargs, "dilate", tuple, spatial_fallback))

     first_k, second_k = self.kernel[0], self.kernel[1]
     half_kernel = (int(first_k / 2), int(second_k / 2))
     self.pad = num_as_tuple(cget(kwargs, "pad", tuple, half_kernel))

     self.num_group = cget(kwargs, "num_group", int, 1)
     self.num_filter = cget(kwargs, "num_filter", int, 1)
Exemplo n.º 5
0
 def __init__(self, **kwargs):
     """Store ``mode`` read via ``cget`` (type ``str``, fallback ``'train'``)."""
     selected_mode = cget(kwargs, 'mode', str, 'train')
     self.mode = selected_mode
Exemplo n.º 6
0
 def parse_args(self, **kwargs):
     """Read ``eps``, ``fix_gamma``, ``momentum`` and ``use_global_stats``.

     Each option is fetched with ``cget`` and stored under the attribute of
     the same name.  The lookups run in the order listed in the table below.
     """
     # (key / attribute name, converter, fallback)
     option_table = (
         ("eps", float, 2e-5),
         ("fix_gamma", bool, True),
         ("momentum", float, 0.9),
         ("use_global_stats", bool, False),
     )
     for key, converter, fallback in option_table:
         setattr(self, key, cget(kwargs, key, converter, fallback))
Exemplo n.º 7
0
 def parse_args(self, **kwargs):
     """Store ``prob`` read via ``cget`` (type ``float``, fallback ``0.0``)."""
     probability = cget(kwargs, "prob", float, 0.0)
     self.prob = probability