Example 1
    def forward(self, data_batch, is_train=None):
        assert self.binded and self.params_initialized

        # get current_shapes
        if self._curr_module.label_shapes is not None:
            current_shapes = dict(self._curr_module.data_shapes + self._curr_module.label_shapes)
        else:
            current_shapes = dict(self._curr_module.data_shapes)

        # get input_shapes
        if data_batch.provide_label is not None:
            input_shapes = dict(data_batch.provide_data + data_batch.provide_label)
        else:
            input_shapes = dict(data_batch.provide_data)

        # decide if shape changed
        shape_changed = False
        for k, v in current_shapes.items():
            if v != input_shapes[k]:
                shape_changed = True

        if shape_changed:
            module = Module(self._symbol, self._data_names, self._label_names,
                            logger=self.logger, context=self._context,
                            work_load_list=self._work_load_list,
                            fixed_param_names=self._fixed_param_names)
            module.bind(data_batch.provide_data, data_batch.provide_label, self._curr_module.for_training,
                        self._curr_module.inputs_need_grad, force_rebind=False,
                        shared_module=self._curr_module)
            self._curr_module = module

        self._curr_module.forward(data_batch, is_train=is_train)
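The forward() above hinges on one check: the shapes the module is currently bound with are compared, name by name, against the shapes the incoming batch provides, and any mismatch triggers a rebind against a freshly created Module that shares weights with the old one. A minimal, self-contained sketch of just that comparison (plain Python, hypothetical shape values):

# Sketch of the shape-change check used in forward() above (hypothetical shapes).
def shapes_changed(current_shapes, input_shapes):
    """Return True if any named input differs from the shape the module is bound to."""
    return any(v != input_shapes[k] for k, v in current_shapes.items())

current_shapes = {'data': (32, 3, 112, 112), 'softmax_label': (32,)}
same_batch = {'data': (32, 3, 112, 112), 'softmax_label': (32,)}
bigger_batch = {'data': (64, 3, 112, 112), 'softmax_label': (64,)}

print(shapes_changed(current_shapes, same_batch))    # False -> keep the current module
print(shapes_changed(current_shapes, bigger_batch))  # True  -> rebind a new Module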
Example 3
    def forward(self, data_batch, is_train=None):
        assert self.binded and self.params_initialized

        # get current_shapes
        if self._curr_module.label_shapes is not None:
            print(self._curr_module.data_shapes)
            print(self._curr_module.label_shapes)
            print(data_batch.provide_data)
            print(data_batch.provide_label)
            current_shapes = [
                dict(self._curr_module.data_shapes[i] +
                     self._curr_module.label_shapes[i])
                for i in range(len(self._context))
            ]
        else:
            current_shapes = [
                dict(self._curr_module.data_shapes[i])
                for i in range(len(self._context))
            ]

        # get input_shapes
        if is_train:
            input_shapes = [
                dict(data_batch.provide_data[i] + data_batch.provide_label[i])
                for i in range(len(self._context))
            ]
        else:
            input_shapes = [
                dict(data_batch.provide_data[i])
                for i in range(len(data_batch.provide_data))
            ]

        # decide if shape changed
        shape_changed = len(current_shapes) != len(input_shapes)
        for pre, cur in zip(current_shapes, input_shapes):
            for k, v in pre.items():
                if v != cur[k]:
                    shape_changed = True

        if shape_changed:
            # self._curr_module.reshape(data_batch.provide_data, data_batch.provide_label)
            module = Module(self._symbol,
                            self._data_names,
                            self._label_names,
                            logger=self.logger,
                            context=[
                                self._context[i]
                                for i in range(len(data_batch.provide_data))
                            ],
                            work_load_list=self._work_load_list,
                            fixed_param_names=self._fixed_param_names)
            module.bind(data_batch.provide_data,
                        data_batch.provide_label,
                        self._curr_module.for_training,
                        self._curr_module.inputs_need_grad,
                        force_rebind=False,
                        shared_module=self._curr_module)
            self._curr_module = module

        self._curr_module.forward(data_batch, is_train=is_train)
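Example 3 performs the same check per device slice: both sides are lists of per-context shape dicts, and a difference in the number of slices alone is enough to force a rebind. A stand-alone sketch with made-up shapes:

# Per-device variant of the shape check in Example 3 (made-up shapes).
current_shapes = [{'data': (16, 3, 112, 112)}, {'data': (16, 3, 112, 112)}]  # bound on 2 devices
input_shapes = [{'data': (16, 3, 112, 112)}]                                 # batch only fills 1 device

shape_changed = len(current_shapes) != len(input_shapes)
for pre, cur in zip(current_shapes, input_shapes):
    for k, v in pre.items():
        if v != cur[k]:
            shape_changed = True

print(shape_changed)  # True: fewer device slices than the module was bound with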
Example 4
    def bind(self, data_shapes, label_shapes=None, for_training=True,
             inputs_need_grad=False, force_rebind=False, shared_module=None, grad_req="write"):
        # in case we already initialized params, keep it
        if self.params_initialized:
            arg_params, aux_params = self.get_params()

        # force rebinding is typically used when one wants to switch from
        # training to prediction phase.
        if force_rebind:
            self._reset_bind()

        if self.binded:
            self.logger.warning('Already bound, ignoring bind()')
            return

        assert shared_module is None, 'shared_module for KTModule is not supported'

        self.for_training = for_training
        self.inputs_need_grad = inputs_need_grad
        self.binded = True

        module = Module(self._symbol, self._data_names, self._label_names, logger=self.logger,
                        context=self._context, work_load_list=self._work_load_list)
        module.bind(self._data_shapes, self._label_shapes, for_training, inputs_need_grad,
                    force_rebind=False, shared_module=None)
        self._curr_module = module

        # copy back saved params, if already initialized
        if self.params_initialized:
            self.set_params(arg_params, aux_params)
    def __init__(self,
                 symbol,
                 data_names,
                 label_names,
                 logger=logging,
                 context=ctx.cpu(),
                 work_load_list=None,
                 asymbol=None,
                 args=None):
        super(ParallModule, self).__init__(logger=logger)
        self._symbol = symbol
        self._asymbol = asymbol
        self._data_names = data_names
        self._label_names = label_names
        self._context = context
        self._work_load_list = work_load_list
        self._num_classes = config.num_classes
        self._batch_size = args.batch_size
        self._verbose = args.verbose
        self._emb_size = config.emb_size
        self._local_class_start = args.local_class_start
        assert self._local_class_start == 0
        self._iter = 0

        self._backbone_module = None

        self._num_workers = config.num_workers
        self._num_ctx = len(self._context)
        self._ctx_num_classes = args.ctx_num_classes
        self._nd_cache = {}
        self._ctx_single_gpu = self._context[-1]
        self._fixed_param_names = None
        self._backbone_module = Module(
            self._symbol,
            self._data_names,
            self._label_names,
            logger=self.logger,
            context=self._context,
            work_load_list=self._work_load_list,
            fixed_param_names=self._fixed_param_names)
        self._arcface_modules = []
        self._ctx_class_start = []
        for i in range(len(self._context)):
            args._ctxid = i
            _module = Module(self._asymbol(args),
                             self._data_names,
                             self._label_names,
                             logger=self.logger,
                             context=self._context[i],
                             work_load_list=self._work_load_list,
                             fixed_param_names=self._fixed_param_names)
            self._arcface_modules.append(_module)
            _c = args.local_class_start + i * args.ctx_num_classes
            self._ctx_class_start.append(_c)
        self._usekv = False

        if self._usekv:
            self._distkv = mx.kvstore.create('dist_sync')
            self._kvinit = {}
    def bind(self,
             data_shapes,
             label_shapes=None,
             for_training=True,
             inputs_need_grad=False,
             force_rebind=False,
             shared_module=None):
        # in case we already initialized params, keep it
        if self.params_initialized:
            arg_params, aux_params = self.get_params()

        # force rebinding is typically used when one wants to switch from
        # training to prediction phase.
        if force_rebind:
            self._reset_bind()

        if self.binded:
            self.logger.warning('Already bound, ignoring bind()')
            return

        assert shared_module is None, 'shared_module for MutableModule is not supported'

        self.for_training = for_training
        self.inputs_need_grad = inputs_need_grad
        self.binded = True

        max_shapes_dict = dict(self._max_data_shapes + self._max_label_shapes)
        max_data_shapes = list()
        for name, shape in data_shapes:
            if name in max_shapes_dict:
                max_data_shapes.append((name, max_shapes_dict[name]))
            else:
                max_data_shapes.append((name, shape))
        max_label_shapes = list()
        for name, shape in label_shapes:
            if name in max_shapes_dict:
                max_label_shapes.append((name, max_shapes_dict[name]))
            else:
                max_label_shapes.append((name, shape))

        module = Module(self._symbol,
                        self._data_names,
                        self._label_names,
                        logger=self.logger,
                        context=self._context,
                        work_load_list=self._work_load_list,
                        fixed_param_names=self._fixed_param_names)
        module.bind(max_data_shapes,
                    max_label_shapes,
                    for_training,
                    inputs_need_grad,
                    force_rebind=False,
                    shared_module=None)
        self._curr_module = module

        # copy back saved params, if already initialized
        if self.params_initialized:
            self.set_params(arg_params, aux_params)
Example 7
    def bind(self, data_shapes, label_shapes=None, for_training=True,
             inputs_need_grad=False, force_rebind=False, shared_module=None):
        # in case we already initialized params, keep it
        if self.params_initialized:
            arg_params, aux_params = self.get_params()

        # force rebinding is typically used when one wants to switch from
        # training to prediction phase.
        if force_rebind:
            self._reset_bind()

        if self.binded:
            self.logger.warning('Already bound, ignoring bind()')
            return

        assert shared_module is None, 'shared_module for MutableModule is not supported'

        self.for_training = for_training
        self.inputs_need_grad = inputs_need_grad
        self.binded = True

        max_shapes_dict = dict()
        if self._max_data_shapes is not None:
            max_shapes_dict.update(dict(self._max_data_shapes))
        if self._max_label_shapes is not None:
            max_shapes_dict.update(dict(self._max_label_shapes))

        max_data_shapes = list()
        for name, shape in data_shapes:
            if name in max_shapes_dict:
                max_data_shapes.append((name, max_shapes_dict[name]))
            else:
                max_data_shapes.append((name, shape))

        max_label_shapes = list()
        if label_shapes is not None:
            for name, shape in label_shapes:
                if name in max_shapes_dict:
                    max_label_shapes.append((name, max_shapes_dict[name]))
                else:
                    max_label_shapes.append((name, shape))

        if len(max_label_shapes) == 0:
            max_label_shapes = None

        module = Module(self._symbol, self._data_names, self._label_names, logger=self.logger,
                        context=self._context, work_load_list=self._work_load_list,
                        fixed_param_names=self._fixed_param_names)
        module.bind(max_data_shapes, max_label_shapes, for_training, inputs_need_grad,
                    force_rebind=False, shared_module=None)
        self._curr_module = module

        # copy back saved params, if already initialized
        if self.params_initialized:
            self.set_params(arg_params, aux_params)
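The bind() in Example 7 allocates the inner Module against the union of the user-supplied maximum shapes and the shapes of the current batch, so the executors are sized once for the largest expected input. The merge itself is a dict lookup with a fallback; a small stand-alone sketch (hypothetical names and shapes):

# Sketch of how max_data_shapes is assembled in bind() above (hypothetical shapes).
max_shapes_dict = {'data': (1, 3, 600, 1000)}                     # user-declared maxima
data_shapes = [('data', (1, 3, 480, 640)), ('im_info', (1, 3))]   # shapes of the current batch

max_data_shapes = []
for name, shape in data_shapes:
    # prefer the declared maximum when one exists, otherwise keep the batch shape
    max_data_shapes.append((name, max_shapes_dict.get(name, shape)))

print(max_data_shapes)  # [('data', (1, 3, 600, 1000)), ('im_info', (1, 3))]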
Example 8
    def __init__(self, symbol, bn_symbol, batch_size, fc7_model, size,
                 rank, local_rank, memory_bank_list, memory_optimizer,
                 backbone_grad_rescale, memory_lr_scale_list,
                 embedding_size=512, head_num=1, logger=logging, ):
        # configure horovod
        self.memory_lr_scale_list = memory_lr_scale_list
        self.size = size
        self.rank = rank
        self.local_rank = local_rank
        self.gpu = mx.gpu(self.local_rank)
        self.cpu = mx.cpu()                                     # `device_id` is not needed for CPU.
        self.nd_cache = {}
        self.embedding_size = embedding_size
        self.batch_size = batch_size
        self.num_update = 0

        self.batch_end_param = namedtuple(
            'batch_end_param',
            ['loss_list', 'num_epoch_list', 'epoch', 'num_update'])

        self.symbol = symbol
        # self.bn_symbol = bn_symbol
        #
        self.logger = logger
        self.backbone_module = Module(self.symbol,    ['data'], ['softmax_label'], logger=self.logger, context=self.gpu)
        # self.bn_module       = Module(self.bn_symbol, ['data'], None, logger=self.logger, context=self.gpu)
        self.head_num = head_num
        self.memory_bank_list = memory_bank_list
        self.memory_optimizer = memory_optimizer
        self.memory_lr = None
        self.loss_cache = None
        self.grad_cache = None

        assert isinstance(self.memory_bank_list, list)

        # init
        self.fc7_model = fc7_model

        # fp16
        self.backbone_grad_rescale = backbone_grad_rescale

        self.binded = False
        self.for_training = False
        self.inputs_need_grad = False
        self.params_initialized = False
        self.optimizer_initialized = False
        self._total_exec_bytes = 0

        self.global_label = None
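batch_end_param above is a namedtuple factory; presumably it is used to package per-batch statistics for end-of-batch callbacks elsewhere in the class. A tiny illustration of the pattern (the field values here are made up):

from collections import namedtuple

# Same factory as in the constructor above; values are illustrative only.
batch_end_param = namedtuple(
    'batch_end_param',
    ['loss_list', 'num_epoch_list', 'epoch', 'num_update'])

p = batch_end_param(loss_list=[0.73], num_epoch_list=[0], epoch=0, num_update=100)
print(p.epoch, p.num_update)  # fields are read by name in the callback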
Example 9
def train_net(sym, prefix, ctx, pretrained, epoch, begin_epoch, end_epoch, imdb, batch_size, thread_num,
              net=12, with_cls = True, with_bbox = True, with_landmark = False, frequent=50, initialize=True, base_lr=0.01, lr_epoch = [6,14]):
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    train_data = ImageLoader(imdb, net, batch_size, thread_num, True, shuffle=True, ctx=ctx)

    if not initialize:
        args, auxs = load_param(pretrained, epoch, convert=True)

    if initialize:
        print "init weights and bias:"
        data_shape_dict = dict(train_data.provide_data + train_data.provide_label)
        arg_shape, _, aux_shape = sym.infer_shape(**data_shape_dict)
        arg_shape_dict = dict(zip(sym.list_arguments(), arg_shape))
        aux_shape_dict = dict(zip(sym.list_auxiliary_states(), aux_shape))
        init = mx.init.Xavier(factor_type="in", rnd_type='gaussian', magnitude=2)
        args = dict()
        auxs = dict()
        print('hello3')

        for k in sym.list_arguments():
            if k in data_shape_dict:
                continue

            #print 'init', k

            args[k] = mx.nd.zeros(arg_shape_dict[k])
            init(k, args[k])
            if k.startswith('fc'):
                args[k][:] /= 10

            '''
            if k.endswith('weight'):
                if k.startswith('conv'):
                    args[k] = mx.random.normal(loc=0, scale=0.001, shape=arg_shape_dict[k])
                else:
                    args[k] = mx.random.normal(loc=0, scale=0.01, shape=arg_shape_dict[k])
            else: # bias
                args[k] = mx.nd.zeros(shape=arg_shape_dict[k])
            '''

        for k in sym.list_auxiliary_states():
            auxs[k] = mx.nd.zeros(aux_shape_dict[k])
            #print aux_shape_dict[k]
            init(k, auxs[k])

    lr_factor = 0.1
    image_num = len(imdb)
    
    lr_epoch_diff = [epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch]
    lr = base_lr * (lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [int(epoch * image_num / batch_size) for epoch in lr_epoch_diff]
    print('lr', lr, 'lr_epoch', lr_epoch, 'lr_epoch_diff', lr_epoch_diff)
    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(lr_iters, lr_factor)

    data_names = [k[0] for k in train_data.provide_data]
    label_names = [k[0] for k in train_data.provide_label]

    batch_end_callback = mx.callback.Speedometer(train_data.batch_size, frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix,period=10)
    eval_metrics = mx.metric.CompositeEvalMetric()
    eval_metrics.add(metric_human14.LANDMARK_MSE())
    eval_metrics.add(metric_human14.LANDMARK_L1())
    
    optimizer_params = {'momentum': 0.9,
                        'wd': 0.00001,
                        'learning_rate': lr,
                        'lr_scheduler': lr_scheduler,
                        'rescale_grad': 1.0}

    mod = Module(sym, data_names=data_names, label_names=label_names, logger=logger, context=ctx)
    mod.fit(train_data, eval_metric=eval_metrics, epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            optimizer='sgd', optimizer_params=optimizer_params,
            arg_params=args, aux_params=auxs, begin_epoch=begin_epoch, num_epoch=end_epoch)
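The learning-rate setup in train_net converts epoch milestones into iteration counts for MultiFactorScheduler and pre-applies any decay steps that were already passed when resuming from begin_epoch. The arithmetic in isolation, with illustrative numbers (no MXNet needed):

# Illustrative numbers: 16000 images, batch size 128, resuming at epoch 8,
# milestones at epochs 6 and 14, base_lr 0.01, decay factor 0.1.
base_lr, lr_factor = 0.01, 0.1
lr_epoch, begin_epoch = [6, 14], 8
image_num, batch_size = 16000, 128

lr_epoch_diff = [e - begin_epoch for e in lr_epoch if e > begin_epoch]  # [6]
lr = base_lr * (lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))      # 0.001, the epoch-6 drop already applied
lr_iters = [int(e * image_num / batch_size) for e in lr_epoch_diff]     # [750] iterations until the next drop

print(lr, lr_iters)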
Example 10
class ParallModule(BaseModule):
    def __init__(self, symbol, data_names, label_names,
                 logger=logging, context=ctx.cpu(), work_load_list=None,
                 asymbol = None,
                 args = None):
        super(ParallModule, self).__init__(logger=logger)
        self._symbol = symbol
        self._asymbol = asymbol
        self._data_names = data_names
        self._label_names = label_names
        self._context = context
        self._work_load_list = work_load_list
        self._num_classes = config.num_classes
        self._batch_size = args.batch_size
        self._verbose = args.verbose
        self._emb_size = config.emb_size
        self._local_class_start = args.local_class_start
        self._iter = 0

        self._curr_module = None

        self._num_workers = config.num_workers
        self._num_ctx = len(self._context)
        self._ctx_num_classes = args.ctx_num_classes
        self._nd_cache = {}
        self._ctx_cpu = mx.cpu()
        self._ctx_single_gpu = self._context[-1]
        self._fixed_param_names = None
        self._curr_module = Module(self._symbol, self._data_names, self._label_names, logger=self.logger,
                        context=self._context, work_load_list=self._work_load_list,
                        fixed_param_names=self._fixed_param_names)
        self._arcface_modules = []
        self._ctx_class_start = []
        for i in range(len(self._context)):

          args._ctxid = i
          _module = Module(self._asymbol(args), self._data_names, self._label_names, logger=self.logger,
                          context=mx.gpu(i), work_load_list=self._work_load_list,
                          fixed_param_names=self._fixed_param_names)
          self._arcface_modules.append(_module)
          _c = args.local_class_start + i*args.ctx_num_classes
          self._ctx_class_start.append(_c)
        self._usekv = False
        if self._usekv:
          self._distkv = mx.kvstore.create('dist_sync')
          self._kvinit = {}


    def _reset_bind(self):
        self.binded = False
        self._curr_module = None

    @property
    def data_names(self):
        return self._data_names

    @property
    def output_names(self):
        return self._symbol.list_outputs()

    @property
    def data_shapes(self):
        assert self.binded
        return self._curr_module.data_shapes

    @property
    def label_shapes(self):
        assert self.binded
        return self._curr_module.label_shapes

    @property
    def output_shapes(self):
        assert self.binded
        return self._curr_module.output_shapes

    def get_export_params(self):
        assert self.binded and self.params_initialized
        _g, _x = self._curr_module.get_params()
        g = _g.copy()
        x = _x.copy()
        return g, x

    def get_params(self):
        assert self.binded and self.params_initialized
        _g, _x = self._curr_module.get_params()
        g = _g.copy()
        x = _x.copy()
        for _module in self._arcface_modules:
          _g, _x = _module.get_params()
          ag = _g.copy()
          ax = _x.copy()
          g.update(ag)
          x.update(ax)
        return g, x

    def set_params(self, arg_params, aux_params, allow_missing=False, force_init=True,
                   allow_extra=False):
      g = arg_params
      x = aux_params
      #ag = {}
      #ax = {}
      rk = []
      for k in g:
        v = g[k]
        if k.startswith('fc7'):
          p1 = k.find('_')
          p2 = k.rfind('_')
          _ctxid = int(k[p1+1:p2])
          self._arcface_modules[_ctxid].set_params({k:v}, {})
          rk.append(k)
      for k in rk:
        del g[k]
      self._curr_module.set_params(g, x)
      #self._arcface_module.set_params(ag, ax)


    def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=None,
                    allow_missing=False, force_init=False, allow_extra=False):
        if self.params_initialized and not force_init:
            return
        assert self.binded, 'call bind before initializing the parameters'
        #TODO init the same weights with all work nodes
        self._curr_module.init_params(initializer=initializer, arg_params=None,
                                      aux_params=None, allow_missing=allow_missing,
                                      force_init=force_init, allow_extra=allow_extra)
        for _module in self._arcface_modules:
          #_initializer = initializer
          _initializer = mx.init.Normal(0.01)
          _module.init_params(initializer=_initializer, arg_params=None,
                                        aux_params=None, allow_missing=allow_missing,
                                        force_init=force_init, allow_extra=allow_extra)
        self.params_initialized = True


    def bind(self, data_shapes, label_shapes=None, for_training=True,
             inputs_need_grad=False, force_rebind=False, shared_module=None):
        print('in_bind', self.params_initialized, data_shapes, label_shapes)
        if self.params_initialized:
            arg_params, aux_params = self.get_params()

        # force rebinding is typically used when one wants to switch from
        # training to prediction phase.
        if force_rebind:
            self._reset_bind()

        if self.binded:
            self.logger.warning('Already bound, ignoring bind()')
            return

        assert shared_module is None, 'shared_module for MutableModule is not supported'
        self.for_training = for_training
        self.inputs_need_grad = inputs_need_grad
        self.binded = True
        self._curr_module.bind(data_shapes, label_shapes, for_training, inputs_need_grad,
                    force_rebind=False, shared_module=None)
        _data_shape = data_shapes[0][1]
        print('_data_shape', _data_shape, label_shapes)
        for _module in self._arcface_modules:
          _module.bind([('data', (_data_shape[0]*self._num_workers, self._emb_size))], [('softmax_label', (_data_shape[0]*self._num_workers,))], for_training, True,
                      force_rebind=False, shared_module=None)
        if self.params_initialized:
            self.set_params(arg_params, aux_params)

    def init_optimizer(self, kvstore='local', optimizer='sgd',
                       optimizer_params=(('learning_rate', 0.01),), force_init=False):
        assert self.binded and self.params_initialized
        if self.optimizer_initialized and not force_init:
            self.logger.warning('optimizer already initialized, ignoring.')
            return

        self._curr_module.init_optimizer(kvstore, optimizer, optimizer_params,
                                         force_init=force_init)
        for _module in self._arcface_modules:
          _module.init_optimizer(kvstore, optimizer, optimizer_params,
                                           force_init=force_init)
        self.optimizer_initialized = True

    def kv_push(self, key, value):
      #if value.context!=mx.cpu():
      #  value = value.as_in_context(mx.cpu())
      if not key in self._kvinit:
        self._distkv.init(key, nd.zeros_like(value))
        self._kvinit[key] = 1
      self._distkv.push(key, value)

    #get fc1 and partial fc7
    def forward(self, data_batch, is_train=None):
        #g,x = self.get_params()
        #print('{fc7_weight[0][0]}', self._iter, g['fc7_0_weight'].asnumpy()[0][0])
        #print('{pre_fc1_weight[0][0]}', self._iter, g['pre_fc1_weight'].asnumpy()[0][0])


        assert self.binded and self.params_initialized
        self._curr_module.forward(data_batch, is_train=is_train)
        if is_train:
          self._iter+=1
          fc1, label = self._curr_module.get_outputs(merge_multi_context=True)
          global_fc1 = fc1
          self.global_label = label.as_in_context(self._ctx_cpu)


          for i, _module in enumerate(self._arcface_modules):
            _label = self.global_label - self._ctx_class_start[i]
            db_global_fc1 = io.DataBatch([global_fc1], [_label])
            _module.forward(db_global_fc1) #fc7 with margin
        #print('forward end')


    def get_ndarray(self, context, name, shape):
      key = "%s_%s"%(name, context)
      #print(key)
      if not key in self._nd_cache:
        v = nd.zeros( shape=shape, ctx = context)
        self._nd_cache[key] = v
      else:
        v = self._nd_cache[key]
      return v

    def get_ndarray2(self, context, name, arr):
      key = "%s_%s"%(name, context)
      #print(key)
      if not key in self._nd_cache:
        v = nd.zeros( shape=arr.shape, ctx = context)
        self._nd_cache[key] = v
      else:
        v = self._nd_cache[key]
      arr.copyto(v)
      return v

    def backward(self, out_grads=None):
        #print('in backward')
        assert self.binded and self.params_initialized
        #tmp_ctx = self._ctx_cpu
        tmp_ctx = self._ctx_single_gpu
        fc7_outs = []
        ctx_fc7_max = self.get_ndarray(tmp_ctx, 'ctx_fc7_max', (self._batch_size, len(self._context)))
        #local_fc7_max = nd.zeros( (self.global_label.shape[0],1), ctx=mx.cpu())
        arcface_module_outputs = []
        for i, _module in enumerate(self._arcface_modules):
          #_fc7 = _module.get_outputs(merge_multi_context=True)[0]
          out = _module.get_outputs(merge_multi_context=True)
          #print(out[0].shape)
          #print(out[1].shape)
          arcface_module_outputs.append(out)
          _fc7 = out[0]
          fc7_outs.append(_fc7)
          _fc7_max = nd.max(_fc7, axis=1).as_in_context(tmp_ctx)
          ctx_fc7_max[:,i] = _fc7_max

        local_fc7_max = self.get_ndarray(tmp_ctx, 'local_fc7_max', (self._batch_size, 1))
        nd.max(ctx_fc7_max, axis=1, keepdims=True, out=local_fc7_max)
        global_fc7_max = local_fc7_max
        #local_fc7_sum = None
        local_fc7_sum = self.get_ndarray(tmp_ctx, 'local_fc7_sum', (self._batch_size,1))
        local_fc7_sum[:,:] = 0.0
        for i, _module in enumerate(self._arcface_modules):
          _max = self.get_ndarray2(fc7_outs[i].context, 'fc7_max', global_fc7_max)
          fc7_outs[i] = nd.broadcast_sub(fc7_outs[i], _max)
          fc7_outs[i] = nd.exp(fc7_outs[i])
          _sum = nd.sum(fc7_outs[i], axis=1, keepdims=True).as_in_context(tmp_ctx)
          local_fc7_sum += _sum
        global_fc7_sum = local_fc7_sum

        if self._iter%self._verbose==0:
          #_ctx = self._context[-1]
          _ctx = self._ctx_cpu
          _probs = []
          for i, _module in enumerate(self._arcface_modules):
            _prob = self.get_ndarray2(_ctx, '_fc7_prob_%d'%i, fc7_outs[i])
            _probs.append(_prob)
          fc7_prob = self.get_ndarray(_ctx, 'test_fc7_prob', (self._batch_size, self._ctx_num_classes*len(self._context)))
          nd.concat(*_probs, dim=1, out=fc7_prob)
          fc7_pred = nd.argmax(fc7_prob, axis=1)
          local_label = self.global_label - self._local_class_start
          #local_label = self.get_ndarray2(_ctx, 'test_label', local_label)
          _pred = nd.equal(fc7_pred, local_label)
          print('{fc7_acc}', self._iter, nd.mean(_pred).asnumpy()[0])


        #local_fc1_grad = []
        #fc1_grad_ctx = self._ctx_cpu
        fc1_grad_ctx = self._ctx_single_gpu
        local_fc1_grad = self.get_ndarray(fc1_grad_ctx, 'local_fc1_grad', (self._batch_size,self._emb_size))
        local_fc1_grad[:,:] = 0.0
        total_eloss = []
        celoss_verbose = 1000
        if self._iter%celoss_verbose==0:
          fc7_celoss = self.get_ndarray(tmp_ctx, 'test_fc7_celoss', (self._batch_size,))
          fc7_celoss[:] = 0.0

        for i, _module in enumerate(self._arcface_modules):
          _sum = self.get_ndarray2(fc7_outs[i].context, 'fc7_sum', global_fc7_sum)
          fc7_outs[i] = nd.broadcast_div(fc7_outs[i], _sum)
          a = i*self._ctx_num_classes
          b = (i+1)*self._ctx_num_classes
          _label = self.global_label - self._ctx_class_start[i]
          _label = self.get_ndarray2(fc7_outs[i].context, 'label', _label)
          onehot_label = self.get_ndarray(fc7_outs[i].context, 'label_onehot', (self._batch_size, self._ctx_num_classes))
          nd.one_hot(_label, depth=self._ctx_num_classes, on_value = 1.0, off_value = 0.0, out=onehot_label)
          #print(fc7_outs[i].shape, onehot_label.shape)

          if self._iter%celoss_verbose==0:
            _ce_loss = fc7_outs[i] * onehot_label
            _ce_loss = nd.sum(_ce_loss, axis=1)
            fc7_celoss += _ce_loss.as_in_context(tmp_ctx)
          fc7_outs[i] -= onehot_label

          out = arcface_module_outputs[i]
          out_grads = [fc7_outs[i]]
          for j in range(1, len(out)):
              eloss = out[j]
              #print('eloss%d:'%j, eloss.shape)
              #print(out_grads[0].shape)
              #egrad_shape = (out_grads[0].shape[0], eloss.shape[0])
              egrad_shape = eloss.shape
              egrad = self.get_ndarray(fc7_outs[i].context, 'egrad%d'%j, egrad_shape)
              #egrad[:][:] = 1.0/egrad_shape[0]
              egrad[:][:] = 1.0
              out_grads.append(egrad)
              if self._iter%self._verbose==0:
                  total_eloss.append(np.mean(eloss.asnumpy()))

          _module.backward(out_grads = out_grads)
          #ctx_fc1_grad = _module.get_input_grads()[0].as_in_context(mx.cpu())
          ctx_fc1_grad = self.get_ndarray2(fc1_grad_ctx, 'ctx_fc1_grad_%d'%i, _module.get_input_grads()[0])
          local_fc1_grad += ctx_fc1_grad

        if self._iter%self._verbose==0 and len(total_eloss)>0:
          print('{eloss}', self._iter, np.mean(total_eloss))
        #if self._iter%self._verbose==0:
        if self._iter%celoss_verbose==0:
          ce_loss = nd.log(fc7_celoss) * -1.0
          ce_loss = nd.mean(ce_loss)
          print('CELOSS,%d,%f'% (self._iter, ce_loss.asnumpy()))

        global_fc1_grad = local_fc1_grad
        self._curr_module.backward(out_grads = [global_fc1_grad])


    def update(self):
        assert self.binded and self.params_initialized and self.optimizer_initialized
        self._curr_module.update()
        for i, _module in enumerate(self._arcface_modules):
          _module.update()
        mx.nd.waitall()


    def get_outputs(self, merge_multi_context=True):
        assert self.binded and self.params_initialized
        return self._curr_module.get_outputs(merge_multi_context=merge_multi_context)
        #return self._arcface_module.get_outputs(merge_multi_context=merge_multi_context)

    def get_input_grads(self, merge_multi_context=True):
        assert self.binded and self.params_initialized and self.inputs_need_grad
        return self._curr_module.get_input_grads(merge_multi_context=merge_multi_context)

    def update_metric(self, eval_metric, labels):
        assert self.binded and self.params_initialized
        #self._curr_module.update_metric(eval_metric, labels)
        #label = labels[0]
        #print(label.shape)
        #self._arcface_module.update_metric(eval_metric, labels)

    def install_monitor(self, mon):
        """ Install monitor on all executors """
        assert self.binded
        self._curr_module.install_monitor(mon)

    def forward_backward(self, data_batch):
        """A convenient function that calls both ``forward`` and ``backward``."""
        self.forward(data_batch, is_train=True) # get fc1 and partial fc7
        self.backward()

    def fit(self, train_data, eval_data=None, eval_metric='acc',
            epoch_end_callback=None, batch_end_callback=None, kvstore='local',
            optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
            eval_end_callback=None,
            eval_batch_end_callback=None, initializer=Uniform(0.01),
            arg_params=None, aux_params=None, allow_missing=False,
            force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None,
            validation_metric=None, monitor=None, sparse_row_id_fn=None):
        """Trains the module parameters.

        Check out the `Module Tutorial <http://mxnet.io/tutorials/basic/module.html>`_ to see
        an end-to-end use-case.

        Parameters
        ----------
        train_data : DataIter
            Train DataIter.
        eval_data : DataIter
            If not ``None``, will be used as validation set and the performance
            after each epoch will be evaluated.
        eval_metric : str or EvalMetric
            Defaults to 'accuracy'. The performance measure used to display during training.
            Other possible predefined metrics are:
            'ce' (CrossEntropy), 'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy'.
        epoch_end_callback : function or list of functions
            Each callback will be called with the current `epoch`, `symbol`, `arg_params`
            and `aux_params`.
        batch_end_callback : function or list of function
            Each callback will be called with a `BatchEndParam`.
        kvstore : str or KVStore
            Defaults to 'local'.
        optimizer : str or Optimizer
            Defaults to 'sgd'.
        optimizer_params : dict
            Defaults to ``(('learning_rate', 0.01),)``. The parameters for
            the optimizer constructor.
            The default value is not a dict, just to avoid pylint warning on dangerous
            default values.
        eval_end_callback : function or list of function
            These will be called at the end of each full evaluation, with the metrics over
            the entire evaluation set.
        eval_batch_end_callback : function or list of function
            These will be called at the end of each mini-batch during evaluation.
        initializer : Initializer
            The initializer is called to initialize the module parameters when they are
            not already initialized.
        arg_params : dict
            Defaults to ``None``, if not ``None``, should be existing parameters from a trained
            model or loaded from a checkpoint (previously saved model). In this case,
            the value here will be used to initialize the module parameters, unless they
            are already initialized by the user via a call to `init_params` or `fit`.
            `arg_params` has a higher priority than `initializer`.
        aux_params : dict
            Defaults to ``None``. Similar to `arg_params`, except for auxiliary states.
        allow_missing : bool
            Defaults to ``False``. Indicates whether to allow missing parameters when `arg_params`
            and `aux_params` are not ``None``. If this is ``True``, then the missing parameters
            will be initialized via the `initializer`.
        force_rebind : bool
            Defaults to ``False``. Whether to force rebinding the executors if already bound.
        force_init : bool
            Defaults to ``False``. Indicates whether to force initialization even if the
            parameters are already initialized.
        begin_epoch : int
            Defaults to 0. Indicates the starting epoch. Usually, if resumed from a
            checkpoint saved at a previous training phase at epoch N, then this value should be
            N+1.
        num_epoch : int
            Number of epochs for training.
        sparse_row_id_fn : A callback function
            The function  takes `data_batch` as an input and returns a dict of
            str -> NDArray. The resulting dict is used for pulling row_sparse
            parameters from the kvstore, where the str key is the name of the param,
            and the value is the row id of the param to pull.

        Examples
        --------
        >>> # An example of using fit for training.
        >>> # Assume training dataIter and validation dataIter are ready
        >>> # Assume loading a previously checkpointed model
        >>> sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
        >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter, optimizer='sgd',
        ...     optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
        ...     arg_params=arg_params, aux_params=aux_params,
        ...     eval_metric='acc', num_epoch=10, begin_epoch=3)
        """
        assert num_epoch is not None, 'please specify number of epochs'
        assert arg_params is None and aux_params is None

        self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
                  for_training=True, force_rebind=force_rebind)
        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
                         allow_missing=allow_missing, force_init=force_init)
        self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                            optimizer_params=optimizer_params)

        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)
        epoch_eval_metric = copy.deepcopy(eval_metric)

        ################################################################################
        # training loop
        ################################################################################
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            epoch_eval_metric.reset()
            nbatch = 0
            data_iter = iter(train_data)
            end_of_batch = False
            next_data_batch = next(data_iter)
            while not end_of_batch:
                data_batch = next_data_batch
                if monitor is not None:
                    monitor.tic()
                self.forward_backward(data_batch)
                self.update()
                assert not isinstance(data_batch, list)

                #if isinstance(data_batch, list):
                #    #print('XXX')
                #    self.update_metric(eval_metric,
                #                       [db.label for db in data_batch],
                #                       pre_sliced=True)
                #    self.update_metric(epoch_eval_metric,
                #                       [db.label for db in data_batch],
                #                       pre_sliced=True)
                #else:
                #    #print('before update metric')
                #    self.update_metric(eval_metric, data_batch.label)
                #    self.update_metric(epoch_eval_metric, data_batch.label)
                #labels = data_batch.label
                #labels = [self.global_label]
                #self.update_metric(eval_metric, labels)
                #self.update_metric(epoch_eval_metric, labels)

                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                    self.prepare(next_data_batch, sparse_row_id_fn=sparse_row_id_fn)
                except StopIteration:
                    end_of_batch = True

                if monitor is not None:
                    monitor.toc_print()

                #if end_of_batch:
                #    eval_name_vals = epoch_eval_metric.get_name_value()

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                     eval_metric=None,
                                                     locals=locals())
                    batch_end_callback(batch_end_params)
                    #for callback in _as_list(batch_end_callback):
                    #    callback(batch_end_params)
                nbatch += 1

            # one epoch of training is finished
            #for name, val in eval_name_vals:
            #    self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc-tic))

            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            # end of 1 epoch, reset the data-iter for another epoch
            train_data.reset()
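backward() in Example 10 computes a softmax whose class dimension is sharded across devices: each shard contributes its row-wise maximum and exp-sum, those are reduced to a global maximum and sum, and each shard then normalizes its own slice. The reduction can be checked on plain NumPy arrays (a sketch of the math only, not of the MXNet data movement or the margin logic):

import numpy as np

# Two "device" shards of fc7 logits for a batch of 4, 5 classes per shard.
np.random.seed(0)
shards = [np.random.randn(4, 5) for _ in range(2)]

# Reduce: global per-row max and per-row sum of exponentials, as in backward() above.
global_max = np.max(np.stack([s.max(axis=1) for s in shards]), axis=0)[:, None]
exp_shards = [np.exp(s - global_max) for s in shards]
global_sum = sum(e.sum(axis=1, keepdims=True) for e in exp_shards)
probs = [e / global_sum for e in exp_shards]

# The sharded result matches an ordinary softmax over the concatenated logits.
full = np.concatenate(shards, axis=1)
ref = np.exp(full - full.max(axis=1, keepdims=True))
ref /= ref.sum(axis=1, keepdims=True)
print(np.allclose(np.concatenate(probs, axis=1), ref))  # True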
Example 11
    def bind(self,
             data_shapes,
             label_shapes=None,
             for_training=True,
             inputs_need_grad=False,
             force_rebind=False,
             shared_module=None):
        # in case we already initialized params, keep it
        if self.params_initialized:
            arg_params, aux_params = self.get_params()

        # force rebinding is typically used when one wants to switch from
        # training to prediction phase.
        if force_rebind:
            self._reset_bind()

        if self.binded:
            self.logger.warning('Already bound, ignoring bind()')
            return

        assert shared_module is None, 'shared_module for MutableModule is not supported'

        self.for_training = for_training
        self.inputs_need_grad = inputs_need_grad
        self.binded = True

        max_shapes_dict = dict()
        if self._max_data_shapes is not None:
            # dict1.update(dict2) merges the key/value pairs of dict2 into dict1
            max_shapes_dict.update(dict(self._max_data_shapes))
        if self._max_label_shapes is not None:
            max_shapes_dict.update(dict(self._max_label_shapes))

        max_data_shapes = list()
        for name, shape in data_shapes:
            if name in max_shapes_dict:
                max_data_shapes.append((name, max_shapes_dict[name]))
            else:
                max_data_shapes.append((name, shape))

        max_label_shapes = list()
        if label_shapes is not None:
            for name, shape in label_shapes:
                if name in max_shapes_dict:
                    max_label_shapes.append((name, max_shapes_dict[name]))
                else:
                    max_label_shapes.append((name, shape))

        if len(max_label_shapes) == 0:
            max_label_shapes = None
        # reference: https://mxnet.incubator.apache.org/api/python/module/module.html#mxnet.module.Module
        # initialize the inner Module
        module = Module(self._symbol,
                        self._data_names,
                        self._label_names,
                        logger=self.logger,
                        context=self._context,
                        work_load_list=self._work_load_list,
                        fixed_param_names=self._fixed_param_names)
        module.bind(max_data_shapes,
                    max_label_shapes,
                    for_training,
                    inputs_need_grad,
                    force_rebind=False,
                    shared_module=None)
        self._curr_module = module

        # copy back saved params, if already initialized
        if self.params_initialized:
            self.set_params(arg_params, aux_params)
Example 12
def train_net(sym, prefix, ctx, pretrained, epoch, begin_epoch, end_epoch, imdb,
              net=12, frequent=50, initialize=True, base_lr=0.01):
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    train_data = ImageLoader(imdb, net, config.BATCH_SIZE, shuffle=True, ctx=ctx)

    if not initialize:
        args, auxs = load_param(pretrained, epoch, convert=True)

    if initialize:
        print("init weights and bias:")
        data_shape_dict = dict(train_data.provide_data + train_data.provide_label)
        arg_shape, _, aux_shape = sym.infer_shape(**data_shape_dict)
        arg_shape_dict = dict(zip(sym.list_arguments(), arg_shape))
        aux_shape_dict = dict(zip(sym.list_auxiliary_states(), aux_shape))
        init = mx.init.Xavier(factor_type="in", rnd_type='gaussian', magnitude=2)
        args = dict()
        auxs = dict()

        for k in sym.list_arguments():
            if k in data_shape_dict:
                continue

            print('init', k)

            args[k] = mx.nd.zeros(arg_shape_dict[k])
            init(k, args[k])
            if k.startswith('fc'):
                args[k][:] /= 10

            '''
            if k.endswith('weight'):
                if k.startswith('conv'):
                    args[k] = mx.random.normal(loc=0, scale=0.001, shape=arg_shape_dict[k])
                else:
                    args[k] = mx.random.normal(loc=0, scale=0.01, shape=arg_shape_dict[k])
            else: # bias
                args[k] = mx.nd.zeros(shape=arg_shape_dict[k])
            '''

        for k in sym.list_auxiliary_states():
            auxs[k] = mx.nd.zeros(aux_shape_dict[k])
            init(k, auxs[k])

    lr_factor = 0.1
    lr_epoch = config.LR_EPOCH
    lr_epoch_diff = [epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch]
    lr = base_lr * (lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [int(epoch * len(imdb) / config.BATCH_SIZE) for epoch in lr_epoch_diff]
    print('lr', lr, 'lr_epoch', lr_epoch, 'lr_epoch_diff', lr_epoch_diff)
    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(lr_iters, lr_factor)

    data_names = [k[0] for k in train_data.provide_data]
    label_names = [k[0] for k in train_data.provide_label]

    batch_end_callback = mx.callback.Speedometer(train_data.batch_size, frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    eval_metrics = mx.metric.CompositeEvalMetric()
    metric1 = metric.Accuracy()
    metric2 = metric.LogLoss()
    metric3 = metric.BBOX_MSE()
    for child_metric in [metric1, metric2, metric3]:
        eval_metrics.add(child_metric)
    optimizer_params = {'momentum': 0.9,
                        'wd': 0.00001,
                        'learning_rate': lr,
                        'lr_scheduler': lr_scheduler,
                        'rescale_grad': 1.0}

    mod = Module(sym, data_names=data_names, label_names=label_names, logger=logger, context=ctx)
    mod.fit(train_data, eval_metric=eval_metrics, epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            optimizer='sgd', optimizer_params=optimizer_params,
            arg_params=args, aux_params=auxs, begin_epoch=begin_epoch, num_epoch=end_epoch)
Example 13
def train_net(mode,
              sym,
              prefix,
              ctx,
              pretrained,
              epoch,
              begin_epoch,
              end_epoch,
              imdb,
              batch_size,
              thread_num,
              im_size,
              net=112,
              frequent=50,
              initialize=True,
              base_lr=0.01,
              lr_epoch=[6, 14]):
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    train_data = ImageLoader(imdb,
                             net,
                             batch_size,
                             thread_num,
                             shuffle=True,
                             ctx=ctx)

    if not initialize:
        args, auxs = load_param(pretrained, epoch, convert=True)

    if initialize:
        print "init weights and bias:"
        data_shape_dict = dict(train_data.provide_data +
                               train_data.provide_label)
        print(data_shape_dict)
        arg_shape, _, aux_shape = sym.infer_shape(**data_shape_dict)
        #print(arg_shape)
        #print(aux_shape)
        arg_shape_dict = dict(zip(sym.list_arguments(), arg_shape))
        aux_shape_dict = dict(zip(sym.list_auxiliary_states(), aux_shape))
        init = mx.init.Xavier(factor_type="in",
                              rnd_type='gaussian',
                              magnitude=2)
        args = dict()
        auxs = dict()
        #print 'hello3'

        for k in sym.list_arguments():
            if k in data_shape_dict:
                continue

            #print 'init', k

            args[k] = mx.nd.zeros(arg_shape_dict[k])
            init(k, args[k])
            if k.startswith('fc'):
                args[k][:] /= 10

        for k in sym.list_auxiliary_states():
            auxs[k] = mx.nd.zeros(aux_shape_dict[k])
            #print aux_shape_dict[k]
            init(k, auxs[k])

    lr_factor = 0.1
    #lr_epoch = config.LR_EPOCH
    lr_epoch_diff = [
        epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch
    ]
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [int(epoch * len(imdb) / batch_size) for epoch in lr_epoch_diff]
    print('lr', lr, 'lr_epoch', lr_epoch, 'lr_epoch_diff', lr_epoch_diff)
    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(lr_iters, lr_factor)

    data_names = [k[0] for k in train_data.provide_data]
    label_names = [k[0] for k in train_data.provide_label]

    batch_end_callback = mx.callback.Speedometer(train_data.batch_size,
                                                 frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    eval_metrics = mx.metric.CompositeEvalMetric()

    metric1 = metric.GenderAccuracy()
    metric2 = metric.GenderLogLoss()
    if mode == "gender_age":
        metric3 = metric.AGE_MAE()
        for child_metric in [metric1, metric2, metric3]:
            eval_metrics.add(child_metric)
    else:
        for child_metric in [metric1, metric2]:
            eval_metrics.add(child_metric)
    #eval_metrics = mx.metric.CompositeEvalMetric([metric.AccMetric(), metric.MAEMetric(), metric.CUMMetric()])
    optimizer_params = {
        'momentum': 0.9,
        'wd': 0.00001,
        'learning_rate': lr,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': 1.0
    }

    mod = Module(sym,
                 data_names=data_names,
                 label_names=label_names,
                 logger=logger,
                 context=ctx)
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=args,
            aux_params=auxs,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)
Example 14
def train_net(sym,
              prefix,
              ctx,
              pretrained,
              epoch,
              begin_epoch,
              end_epoch,
              imdb,
              net=12,
              frequent=50,
              initialize=True,
              base_lr=0.01):

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)  # log to standard output

    # training data loader
    train_data = ImageLoader(imdb,
                             net,
                             config.BATCH_SIZE,
                             shuffle=True,
                             ctx=ctx)

    if not initialize:  # if not initializing, load the pretrained parameters
        args, auxs = load_param(pretrained, epoch, convert=True)

    if initialize:
        print("init weights and bias:")
        data_shape_dict = dict(train_data.provide_data +
                               train_data.provide_label)
        arg_shape, _, aux_shape = sym.infer_shape(**data_shape_dict)
        arg_shape_dict = dict(zip(sym.list_arguments(), arg_shape))
        aux_shape_dict = dict(zip(sym.list_auxiliary_states(), aux_shape))

        # weight initialization with the Xavier initializer
        init = mx.init.Xavier(factor_type="in",
                              rnd_type='gaussian',
                              magnitude=2)
        args = dict()  # dict of argument parameters (network weights)
        auxs = dict()  # dict of auxiliary states

        for k in sym.list_arguments():
            if k in data_shape_dict:
                continue

            print('init', k)

            args[k] = mx.nd.zeros(arg_shape_dict[k])
            init(k, args[k])
            if k.startswith('fc'):
                args[k][:] /= 10
            '''
            if k.endswith('weight'):
                if k.startswith('conv'):
                    args[k] = mx.random.normal(loc=0, scale=0.001, shape=arg_shape_dict[k])
                else:
                    args[k] = mx.random.normal(loc=0, scale=0.01, shape=arg_shape_dict[k])
            else: # bias
                args[k] = mx.nd.zeros(shape=arg_shape_dict[k])
            '''

        for k in sym.list_auxiliary_states():
            auxs[k] = mx.nd.zeros(aux_shape_dict[k])
            init(k, auxs[k])

    lr_factor = 0.1
    lr_epoch = config.LR_EPOCH
    lr_epoch_diff = [
        epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch
    ]
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [
        int(epoch * len(imdb) / config.BATCH_SIZE) for epoch in lr_epoch_diff
    ]
    print('lr:{},lr_epoch:{},lr_epoch_diff:{}'.format(lr, lr_epoch,
                                                      lr_epoch_diff))
    # print('lr', lr, 'lr_epoch', lr_epoch, 'lr_epoch_diff', lr_epoch_diff)

    # MXNet dynamic learning rate: after each milestone in lr_iters, the learning rate is multiplied by lr_factor
    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(lr_iters, lr_factor)

    data_names = [k[0] for k in train_data.provide_data]
    label_names = [k[0] for k in train_data.provide_label]

    # report speed and metrics every 'frequent' batches
    batch_end_callback = mx.callback.Speedometer(train_data.batch_size,
                                                 frequent=frequent)
    # save a model checkpoint every 'period' epochs
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    # evaluation metrics
    eval_metrics = mx.metric.CompositeEvalMetric()
    metric1 = metric.Accuracy()
    metric2 = metric.LogLoss()
    metric3 = metric.BBOX_MSE()
    # add each metric with the add method
    for child_metric in [metric1, metric2, metric3]:
        eval_metrics.add(child_metric)
    # optimizer-related parameters
    optimizer_params = {
        'momentum': 0.9,
        'wd': 0.00001,
        'learning_rate': lr,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': 1.0,
        'clip_gradient': 5
    }
    # create a trainable module
    mod = Module(sym,
                 data_names=data_names,
                 label_names=label_names,
                 logger=logger,
                 context=ctx)
    # train the model
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=args,
            aux_params=auxs,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)
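    # Minimal follow-up sketch (not part of the original script): the checkpoints written by
    # mx.callback.do_checkpoint(prefix) can later be reloaded with
    #   sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, end_epoch)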
Esempio n. 15
0
class MemoryModuleGlint(object):
    def __init__(self, symbol, bn_symbol, batch_size, fc7_model, size,
                 rank, local_rank, memory_bank_list, memory_optimizer,
                 backbone_grad_rescale, memory_lr_scale_list,
                 embedding_size=512, head_num=1, logger=logging, ):
        # configure horovod
        self.memory_lr_scale_list = memory_lr_scale_list
        self.size = size
        self.rank = rank
        self.local_rank = local_rank
        self.gpu = mx.gpu(self.local_rank)
        self.cpu = mx.cpu()                                     # `device_id` is not needed for CPU.
        self.nd_cache = {}
        self.embedding_size = embedding_size
        self.batch_size = batch_size
        self.num_update = 0

        self.batch_end_param = namedtuple(
            'batch_end_param',
            ['loss_list', 'num_epoch_list', 'epoch', 'num_update'])

        self.symbol = symbol
        # self.bn_symbol = bn_symbol
        #
        self.logger = logger
        self.backbone_module = Module(self.symbol,    ['data'], ['softmax_label'], logger=self.logger, context=self.gpu)
        # self.bn_module       = Module(self.bn_symbol, ['data'], None, logger=self.logger, context=self.gpu)
        self.head_num = head_num
        self.memory_bank_list = memory_bank_list
        self.memory_optimizer = memory_optimizer
        self.memory_lr = None
        self.loss_cache = None
        self.grad_cache = None

        assert isinstance(self.memory_bank_list, list)

        # init
        self.fc7_model = fc7_model

        # fp16
        self.backbone_grad_rescale = backbone_grad_rescale

        self.binded = False
        self.for_training = False
        self.inputs_need_grad = False
        self.params_initialized = False
        self.optimizer_initialized = False
        self._total_exec_bytes = 0

        self.global_label = None

    def forward_backward(self, data_batch):
        total_feature, total_label = self.forward(data_batch, is_train=True)
        self.backward_all(total_feature, total_label)

    @staticmethod
    def sync_aux_params(params):
        pass

    @staticmethod
    def broadcast_parameters(params):
        rank_0_dict = {}
        # Run broadcasts.
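        # hvd.broadcast returns rank 0's value of each tensor, so every worker ends up with an
        # identical copy of the parameters before training starts.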
        for key, tensor in params.items():
            rank_0_dict[key] = hvd.broadcast(tensor, 0, key)
        return rank_0_dict

    @staticmethod
    def combine(data_batches):
        assert isinstance(data_batches, list), "data_batches must be a list."
        length = len(data_batches)
        total_data = [data_batches[0].data[0].reshape(0, -1)]
        total_label = [data_batches[0].label[0].reshape(0, 1)]
        data_shape = data_batches[0].data[0].shape
        if length > 1:
            for i in range(1, length):
                assert data_batches[i].data[0].shape[0] == data_batches[0].data[0].shape[0]
                total_data.append(data_batches[i].data[0].reshape(0, -1))
                total_label.append(data_batches[i].label[0].reshape(0, 1))
        # shuffle
        total_data = mx.nd.concat(*total_data, dim=1)
        total_data = total_data.reshape(-1, data_shape[1], data_shape[2], data_shape[3])
        total_label = mx.nd.concat(*total_label, dim=1)
        total_label = total_label.reshape(-1)
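        # The concat along dim=1 plus these reshapes interleave the per-head batches sample by
        # sample: head_num batches of shape (B, C, H, W) become one (B * head_num, C, H, W)
        # batch ordered [s0/head0, s0/head1, ..., s1/head0, ...], with labels interleaved the same way.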
        return mx.io.DataBatch([total_data], [total_label])

    def fit(self, train_data_list, optimizer_params, batch_end_callback=None, kvstore='local',
            initializer=Uniform(0.01),
            arg_params=None, aux_params=None, allow_missing=False,
            force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None):

        assert num_epoch is not None, 'please specify number of epochs'
        assert arg_params is None and aux_params is None

        provide_data_list = []
        provide_label_list = []
        for td in train_data_list:
            provide_data_list.append(td.provide_data)
            provide_label_list.append(td.provide_label)

        self.bind(data_shapes_list=provide_data_list, label_shapes_list=provide_label_list,
                  for_training=True)

        self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
                         allow_missing=allow_missing, force_init=force_init)
        self.init_optimizer(optimizer_params=optimizer_params)

        _arg_params, _aux_params = self.backbone_module.get_params()
        _arg_params_rank_0 = self.broadcast_parameters(_arg_params)
        _aux_params_rank_0 = self.broadcast_parameters(_aux_params)
        self.backbone_module.set_params(_arg_params_rank_0, _aux_params_rank_0)
        data_end_id = 0
        ################################################################################
        # training loop
        ################################################################################
        num_epoch_list = [0] * self.head_num
        for epoch in range(begin_epoch, num_epoch):
            nbatch = 0
            end_of_batch = False
            data_iter_list = []
            for i in range(self.head_num):
                train_data_list[i].reset()
                data_iter_list.append(iter(train_data_list[i]))
            next_data_batch_list = []
            for i in range(self.head_num):
                next_data_batch_list.append(next(data_iter_list[i]))
            while not end_of_batch:
                data_batch_list = next_data_batch_list
                data_batch = self.combine(data_batch_list)

                self.forward_backward(data_batch)
                self.update()
                assert not isinstance(data_batch, list)

                for i in range(self.head_num):
                    try:
                        next_data_batch_list[i] = next(data_iter_list[i])
                        self.prepare(next_data_batch_list[i], sparse_row_id_fn=None)
                    except StopIteration:
                        num_epoch_list[i] += 1
                        data_end_id += 1
                        if data_end_id != self.head_num:
                            train_data_list[i].reset()
                            data_iter_list[i] = iter(train_data_list[i])
                            next_data_batch_list[i] = next(data_iter_list[i])
                            logging.info('reset dataset_%d' % i)

                if batch_end_callback is not None:
                    batch_end_params = self.batch_end_param(
                        loss_list=self.loss_cache,
                        epoch=epoch,
                        num_update=self.num_update,
                        num_epoch_list=num_epoch_list
                    )
                    batch_end_callback(batch_end_params)

                nbatch += 1

    def get_params(self):
        _g, _x = self.backbone_module.get_params()
        g = _g.copy()
        x = _x.copy()
        # _g, _x = self.bn_module.get_params()
        # ag = _g.copy()
        # ax = _x.copy()
        # g.update(ag)
        # x.update(ax)
        return g, x

    def get_export_params(self):
        assert self.binded and self.params_initialized
        _g, _x = self.backbone_module.get_params()
        g = _g.copy()
        x = _x.copy()
        # _g, _x = self.bn_module.get_params()
        # ag = _g.copy()
        # ax = _x.copy()
        # g.update(ag)
        # x.update(ax)
        return g, x

    def get_ndarray2(self, context, name, arr):
        key = "%s_%s" % (name, context)
        if key not in self.nd_cache:
            v = nd.zeros(shape=arr.shape, ctx=context, dtype=arr.dtype)
            self.nd_cache[key] = v
        else:
            v = self.nd_cache[key]
        arr.copyto(v)
        return v

    def get_ndarray(self, context, name, shape, dtype='float32'):
        key = "%s_%s" % (name, context)
        if key not in self.nd_cache:
            v = nd.zeros(shape=shape, ctx=context, dtype=dtype)
            self.nd_cache[key] = v
        else:
            v = self.nd_cache[key]
        return v

    def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=None,
                    allow_missing=False, force_init=False, allow_extra=False):
        assert self.binded
        # backbone
        self.backbone_module.init_params(
            initializer=initializer, arg_params=arg_params,
            aux_params=aux_params, allow_missing=allow_missing,
            force_init=force_init, allow_extra=allow_extra)

        # self.bn_module.init_params(
        #     initializer=initializer, arg_params=arg_params,
        #     aux_params=aux_params, allow_missing=allow_missing,
        #     force_init=force_init, allow_extra=allow_extra)
        self.params_initialized = True

    def set_params(self, arg_params, aux_params, allow_missing=False, force_init=True,
                   allow_extra=False):
        self.init_params(
            initializer=None, arg_params=arg_params, aux_params=aux_params,
            allow_missing=allow_missing, force_init=force_init,
            allow_extra=allow_extra)

    def save_params(self, fname):
        arg_params, aux_params = self.get_params()
        save_dict = {('arg:%s' % k): v.as_in_context(cpu()) for k, v in arg_params.items()}
        save_dict.update({('aux:%s' % k): v.as_in_context(cpu()) for k, v in aux_params.items()})
        ndarray.save(fname, save_dict)

    def load_params(self, fname):
        save_dict = ndarray.load(fname)
        arg_params = {}
        aux_params = {}
        for k, value in save_dict.items():
            arg_type, name = k.split(':', 1)
            if arg_type == 'arg':
                arg_params[name] = value
            elif arg_type == 'aux':
                aux_params[name] = value
            else:
                raise ValueError("Invalid param file " + fname)
        self.set_params(arg_params, aux_params)

    def get_states(self, merge_multi_context=True):
        raise NotImplementedError

    def set_states(self, states=None, value=None):
        raise NotImplementedError

    def prepare(self, data_batch, sparse_row_id_fn=None):
        if sparse_row_id_fn is not None:
            warnings.warn(UserWarning("sparse_row_id_fn is not invoked for BaseModule."))

    def allgather(self, tensor, name, shape, dtype, context):
        """
        Implement in-place AllGather using AllReduce.
        """
        assert isinstance(tensor, nd.NDArray),          type(tensor)
        assert isinstance(name, str),                   type(name)
        assert isinstance(shape, tuple),                type(shape)
        assert isinstance(dtype, str),                  type(dtype)
        assert isinstance(context, mx.context.Context), type(context)
        total_tensor = self.get_ndarray(
            context=context, name=name, shape=shape, dtype=dtype)
        total_tensor[:] = 0                                      # zeroing the buffer before the all-reduce is essential
        total_tensor[self.rank * self.batch_size:
                     self.rank * self.batch_size + self.batch_size] = tensor
        hvd.allreduce_(total_tensor, average=False)              # all-reduce in-place
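        # Every other rank contributes zeros in this rank's slice, so the element-wise sum across
        # ranks reproduces a concatenation: e.g. with size=2 and batch_size=4, rank 0 fills rows
        # 0-3 and rank 1 fills rows 4-7 of the gathered buffer.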
        return total_tensor

    # pylint: enable=unused-argument
    def forward(self, data_batch, is_train=None):
        assert self.binded and self.params_initialized
        self.backbone_module.forward(data_batch, is_train=is_train)
        if is_train:
            self.num_update += 1
            fc1 = self.backbone_module.get_outputs()[0]
            label = data_batch.label[0]

            total_features = self.allgather(
                tensor=fc1,
                name='total_feature',
                shape=(self.batch_size * self.size, self.embedding_size),
                dtype='float32',
                context=self.gpu
            )
            total_labels = self.allgather(
                tensor=label,
                name='total_label',
                shape=(self.batch_size * self.size,),
                dtype='int32',
                context=self.cpu
            )

            # self.bn_module.forward(mx.io.DataBatch([total_features],  []), is_train=True)
            # total_features = self.bn_module.get_outputs(merge_multi_context=True)[0]
            return total_features, total_labels
        else:
            return None
            # raise ValueError

    def backward_all(self, total_feature, total_label, ):
        # get memory bank learning rate
        self.memory_lr = self.memory_optimizer.lr_scheduler(self.num_update)

        # reverse shuffle bn
        total_feature = total_feature.reshape(-1, self.embedding_size * self.head_num)
        # global_label
        total_label = total_label.reshape(-1, self.head_num)
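        # These reshapes undo the sample-wise interleaving done in combine(): each row of
        # total_feature now holds one sample's embeddings for all heads (head_num contiguous
        # blocks of embedding_size columns), and each row of total_label holds that sample's
        # head_num labels.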
        #
        self.grad_cache = self.get_ndarray(self.gpu, 'grad_cache', total_feature.shape)
        self.loss_cache = self.get_ndarray(self.gpu, 'loss_cache', [self.head_num])

        self.grad_cache[:] = 0
        self.loss_cache[:] = 0

        for head_id in range(self.head_num):
            _fc1_one_head = total_feature[
                            :,
                            head_id * self.embedding_size:
                            head_id * self.embedding_size + self.embedding_size
                            ]
            _label_one_head = total_label[:, head_id]

            grad, loss = self.backward(head_id, _fc1_one_head, _label_one_head)
            self.grad_cache[
                :,
                head_id * self.embedding_size:
                head_id * self.embedding_size + self.embedding_size
            ] = grad
            self.loss_cache[head_id] = loss

        total_feature_grad = self.grad_cache.reshape(-1, self.embedding_size)
        total_feature_grad = hvd.allreduce(total_feature_grad, average=False)

        # self.bn_module.backward(out_grads=[total_feature_grad / self.backbone_grad_rescale])
        # bn_input_grad = self.bn_module.get_input_grads()[0]

        fc1_grad = total_feature_grad[
            self.batch_size * self.rank:
            self.batch_size * self.rank + self.batch_size
        ]
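        # The allreduced gradient covers the full global batch; each rank slices out the rows of
        # its own local batch before backpropagating through its backbone replica.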
        self.backbone_module.backward(out_grads=[fc1_grad])

    def backward(self, head_id, fc1, label):

        memory_bank = self.memory_bank_list[head_id]
        this_rank_classes = int(memory_bank.num_sample)
        local_index, unique_sorted_global_label = memory_bank.sample(label)

        # Get local index
        _mapping_dict = {}
        local_sampled_class = local_index + self.rank * memory_bank.num_local
        global_label_set = set(unique_sorted_global_label)
        for idx, absolute_label in enumerate(local_sampled_class):
            if absolute_label in global_label_set:
                _mapping_dict[absolute_label] = idx + self.rank * memory_bank.num_sample

        label_list = list(label.asnumpy())
        mapping_label = []
        for i in range(len(label_list)):
            absolute_label = label_list[i]
            if absolute_label in _mapping_dict.keys():
                mapping_label.append(_mapping_dict[absolute_label])
            else:
                mapping_label.append(-1)

        mapping_label = nd.array(mapping_label, dtype=np.int32)
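        # Labels whose class was not drawn into this rank's sampled subset are mapped to -1, so
        # after the offset subtraction below they fall outside [0, num_sample) and (assuming the
        # usual one-hot behaviour of the margin FC head) get no positive one-hot entry on this rank.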

        # Get weight
        local_index = nd.array(local_index)
        local_index = self.get_ndarray2(self.gpu, "local_index_%d" % head_id, local_index)
        sample_weight, sample_weight_mom = memory_bank.get(local_index)

        # Sync to gpu
        _data       = self.get_ndarray2(self.gpu, "data_%d_%d"       % (self.rank, head_id), fc1)
        _weight     = self.get_ndarray2(self.gpu, 'weight_%d_%d'     % (self.rank, head_id), sample_weight)
        _weight_mom = self.get_ndarray2(self.gpu, 'weight_mom_%d_%d' % (self.rank, head_id), sample_weight_mom)

        # Attach grad
        _data.attach_grad()
        _weight.attach_grad()

        # Convert label
        _label = self.get_ndarray2(self.gpu, 'mapping_label_%d_%d' % (self.rank, head_id), mapping_label)
        _label = _label - int(self.rank * memory_bank.num_sample)
        _fc7, _one_hot = self.fc7_model.forward(_data, _weight, mapping_label=_label, depth=this_rank_classes)

        # Sync max
        max_fc7 = nd.max(_fc7, axis=1, keepdims=True)
        max_fc7 = nd.reshape(max_fc7, -1)

        total_max_fc7 = self.get_ndarray(
            context=self.gpu, name='total_max_fc7_%d' % head_id,
            shape=(max_fc7.shape[0], self.size), dtype='float32')
        total_max_fc7[:] = 0
        total_max_fc7[:, self.rank] = max_fc7
        hvd.allreduce_(total_max_fc7, average=False)

        global_max_fc7 = self.get_ndarray(
            context=self.gpu, name='global_max_fc7_%d' % head_id,
            shape=(max_fc7.shape[0], 1), dtype='float32')
        nd.max(total_max_fc7, axis=1, keepdims=True, out=global_max_fc7)

        # Calculate prob
        _fc7_grad = nd.broadcast_sub(_fc7, global_max_fc7)
        _fc7_grad = nd.exp(_fc7_grad)

        # Calculate sum
        sum_fc7 = nd.sum(_fc7_grad, axis=1, keepdims=True)
        global_sum_fc7 = hvd.allreduce(sum_fc7, average=False)

        # Calculate grad
        _fc7_grad = nd.broadcast_div(_fc7_grad, global_sum_fc7)

        # Calculate loss
        tmp = _fc7_grad * _one_hot
        tmp = nd.sum(tmp, axis=1, keepdims=True)
        tmp = self.get_ndarray2(self.gpu, 'ctx_loss_%d' % head_id, tmp)
        tmp = hvd.allreduce(tmp, average=False)
        global_loss = -nd.mean(nd.log(tmp + 1e-30))

        _fc7_grad = _fc7_grad - _one_hot
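        # The block above is a model-parallel softmax cross-entropy: logits are shifted by the
        # global row-wise max for numerical stability, exponentiated, and normalized by the sum of
        # exponentials allreduced over all ranks, i.e. prob = exp(fc7 - max) / global_sum. The
        # gradient w.r.t. the logits is prob - one_hot and the loss is -mean(log(p_true)).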

        # Backward
        _fc7.backward(out_grad=_fc7_grad)

        # Update center
        _weight_grad = _weight.grad
        self.memory_optimizer.update(weight=_weight, grad=_weight_grad, state=_weight_mom,
                                     learning_rate=self.memory_lr * self.memory_lr_scale_list[head_id])
        if memory_bank.gpu:
            memory_bank.set(
                index=local_index,
                updated_weight=_weight,
                updated_weight_mom=_weight_mom)
        else:
            memory_bank.set(
                index=local_index,
                updated_weight     = self.get_ndarray2(mx.cpu(), "cpu_weight_%d_%d"     % (self.rank, head_id), _weight),
                updated_weight_mom = self.get_ndarray2(mx.cpu(), "cpu_weight_mom_%d_%d" % (self.rank, head_id), _weight_mom))
        return _data.grad, global_loss

    def get_outputs(self, merge_multi_context=True):
        return self.backbone_module.get_outputs(merge_multi_context=merge_multi_context)

    def update(self):
        self.backbone_module.update()
        # self.bn_module.update()
        mx.nd.waitall()

    def bind(self, data_shapes_list=None, label_shapes_list=None, for_training=True,
             inputs_need_grad=False):
        assert data_shapes_list is not None and label_shapes_list is not None
        if self.binded:
            self.logger.warning('Already binded, ignoring bind()')
            return
        data_name = data_shapes_list[0][0][0]
        data_shapes = data_shapes_list[0][0][1]
        label_name = label_shapes_list[0][0][0]
        label_shapes = label_shapes_list[0][0][1]

        self.for_training = for_training
        self.inputs_need_grad = inputs_need_grad
        self.binded = True
        _backbone_data_shapes = [(data_name, (self.batch_size,) + data_shapes[1:])]
        _backbone_label_shapes = [(label_name, (self.batch_size,) + label_shapes[1:])]

        _bn_data_shapes = [(data_name, (self.batch_size * self.size, self.embedding_size))]
        self.backbone_module.bind(
            data_shapes=_backbone_data_shapes,
            label_shapes=_backbone_label_shapes,
            for_training=for_training,
            inputs_need_grad=inputs_need_grad)
        # self.bn_module.bind(
        #     data_shapes=_bn_data_shapes,
        #     for_training=for_training,
        #     inputs_need_grad=True
        # )

    def init_optimizer(self, optimizer_params, force_init=False):
        assert self.binded and self.params_initialized
        if self.optimizer_initialized and not force_init:
            self.logger.warning('optimizer already initialized, ignoring.')
            return
        # backbone
        # optimizer_backbone = DistributedOptimizer(LARS(**optimizer_params))
        # optimizer_bn       = DistributedOptimizer(LARS(**optimizer_params), prefix='bn_')

        optimizer_backbone = DistributedOptimizer(SGD(**optimizer_params))
        self.backbone_module.init_optimizer(
            'local', optimizer_backbone, force_init=force_init)
        # optimizer_bn = DistributedOptimizer(SGD(**optimizer_params), prefix='bn_')
        # self.bn_module.init_optimizer(
        #     'local', optimizer_bn,       force_init=force_init)
        self.optimizer_initialized = True