class ParallModule(BaseModule):
    """Backbone module plus one classifier sub-module per device.

    The backbone (``symbol``) is wrapped in a single ``Module`` bound on all of
    ``context``.  In addition, one ``Module`` per entry of ``context`` is built
    from ``asymbol(args)`` and placed on ``mx.gpu(i)``.  Sub-module ``i`` is
    fed labels shifted by
    ``args.local_class_start + i * args.ctx_num_classes``.

    ``forward`` runs the backbone and pushes its output through every
    sub-module; ``backward`` combines the sub-module outputs into a single
    softmax, back-propagates through each sub-module and sums their input
    gradients into the gradient handed to the backbone.
    """

    def __init__(self, symbol, data_names, label_names,
                 logger=logging, context=ctx.cpu(), work_load_list=None,
                 asymbol=None, args=None):
        """Build the backbone module and the per-device sub-modules.

        Parameters
        ----------
        symbol : backbone symbol handed to ``Module``.
        data_names, label_names : forwarded to every ``Module``.
        logger : logger forwarded to ``BaseModule`` and every ``Module``.
        context : a context or a list of contexts for the backbone.
        work_load_list : forwarded to every ``Module``.
        asymbol : callable; ``asymbol(args)`` must return the symbol of one
            sub-module.  ``args._ctxid`` is set to the device index before
            each call.
        args : object providing ``batch_size``, ``verbose``,
            ``local_class_start`` and ``ctx_num_classes``.
        """
        super(ParallModule, self).__init__(logger=logger)
        # The code below needs len()/indexing, so wrap a bare context.
        if not isinstance(context, (list, tuple)):
            context = [context]
        self._symbol = symbol
        self._asymbol = asymbol
        self._data_names = data_names
        self._label_names = label_names
        self._context = context
        self._work_load_list = work_load_list
        self._num_classes = config.num_classes
        self._batch_size = args.batch_size
        self._verbose = args.verbose
        self._emb_size = config.emb_size
        self._local_class_start = args.local_class_start
        self._iter = 0
        self._num_workers = config.num_workers
        self._num_ctx = len(self._context)
        self._ctx_num_classes = args.ctx_num_classes
        self._nd_cache = {}
        self._ctx_cpu = mx.cpu()
        # Device on which cross-device reductions are accumulated.
        self._ctx_single_gpu = self._context[-1]
        self._fixed_param_names = None
        self._curr_module = Module(self._symbol, self._data_names,
                                   self._label_names, logger=self.logger,
                                   context=self._context,
                                   work_load_list=self._work_load_list,
                                   fixed_param_names=self._fixed_param_names)
        self._arcface_modules = []
        self._ctx_class_start = []
        for i in range(len(self._context)):
            # asymbol reads the device index from args.
            args._ctxid = i
            _module = Module(self._asymbol(args), self._data_names,
                             self._label_names, logger=self.logger,
                             context=mx.gpu(i),
                             work_load_list=self._work_load_list,
                             fixed_param_names=self._fixed_param_names)
            self._arcface_modules.append(_module)
            _c = args.local_class_start + i * args.ctx_num_classes
            self._ctx_class_start.append(_c)
        self._usekv = False
        if self._usekv:
            self._distkv = mx.kvstore.create('dist_sync')
            self._kvinit = {}

    def _reset_bind(self):
        """Mark the module as unbound so that ``bind`` runs again.

        The wrapped modules are kept: ``bind`` re-binds them in place, so
        dropping them here would make a forced rebind fail.
        """
        self.binded = False

    @property
    def data_names(self):
        """Names of the data inputs given at construction."""
        return self._data_names

    @property
    def output_names(self):
        """Output names of the backbone symbol."""
        return self._symbol.list_outputs()

    @property
    def data_shapes(self):
        """Data shapes of the bound backbone module."""
        assert self.binded
        return self._curr_module.data_shapes

    @property
    def label_shapes(self):
        """Label shapes of the bound backbone module."""
        assert self.binded
        return self._curr_module.label_shapes

    @property
    def output_shapes(self):
        """Output shapes of the bound backbone module."""
        assert self.binded
        return self._curr_module.output_shapes

    def get_export_params(self):
        """Return shallow copies of the backbone's (arg_params, aux_params)."""
        assert self.binded and self.params_initialized
        _g, _x = self._curr_module.get_params()
        return _g.copy(), _x.copy()

    def get_params(self):
        """Return (arg_params, aux_params) of the backbone and all sub-modules.

        The returned dicts are new objects; the sub-module entries are merged
        on top of the backbone entries.
        """
        assert self.binded and self.params_initialized
        _g, _x = self._curr_module.get_params()
        g = _g.copy()
        x = _x.copy()
        for _module in self._arcface_modules:
            _g, _x = _module.get_params()
            g.update(_g)
            x.update(_x)
        return g, x

    def set_params(self, arg_params, aux_params, allow_missing=False,
                   force_init=True, allow_extra=False):
        """Distribute parameters to the backbone and the sub-modules.

        Every argument whose name starts with ``'fc7'`` is routed to the
        sub-module whose index is encoded between the first and the last
        underscore of the name; everything else, together with
        ``aux_params``, goes to the backbone.  ``arg_params`` is not modified.

        ``allow_missing``, ``force_init`` and ``allow_extra`` are accepted for
        interface compatibility and are not forwarded.
        """
        backbone_args = {}
        for k, v in arg_params.items():
            if k.startswith('fc7'):
                p1 = k.find('_')
                p2 = k.rfind('_')
                _ctxid = int(k[p1 + 1:p2])
                self._arcface_modules[_ctxid].set_params({k: v}, {})
            else:
                backbone_args[k] = v
        self._curr_module.set_params(backbone_args, aux_params)

    def init_params(self, initializer=Uniform(0.01), arg_params=None,
                    aux_params=None, allow_missing=False, force_init=False,
                    allow_extra=False):
        """Initialise the backbone and all sub-modules.

        The backbone uses ``initializer``; every sub-module uses
        ``mx.init.Normal(0.01)``.  ``arg_params`` and ``aux_params`` are not
        forwarded: all modules are initialised from their initializer only.
        """
        if self.params_initialized and not force_init:
            return
        assert self.binded, 'call bind before initializing the parameters'
        # TODO init the same weights with all work nodes
        self._curr_module.init_params(initializer=initializer,
                                      arg_params=None, aux_params=None,
                                      allow_missing=allow_missing,
                                      force_init=force_init,
                                      allow_extra=allow_extra)
        for _module in self._arcface_modules:
            _initializer = mx.init.Normal(0.01)
            _module.init_params(initializer=_initializer,
                                arg_params=None, aux_params=None,
                                allow_missing=allow_missing,
                                force_init=force_init,
                                allow_extra=allow_extra)
        self.params_initialized = True

    def bind(self, data_shapes, label_shapes=None, for_training=True,
             inputs_need_grad=False, force_rebind=False, shared_module=None):
        """Bind the backbone and every sub-module.

        Each sub-module is bound with a ``'data'`` input of shape
        ``(batch * num_workers, emb_size)`` and a ``'softmax_label'`` input of
        shape ``(batch * num_workers,)``, where ``batch`` is the first
        dimension of the first entry of ``data_shapes``; input gradients are
        always requested for them.

        With ``force_rebind=True`` an already bound module is bound again and
        the current parameters are restored afterwards.  ``shared_module``
        must be ``None``.
        """
        print('in_bind', self.params_initialized, data_shapes, label_shapes)
        if self.params_initialized:
            # Snapshot the parameters so they survive the rebind.
            arg_params, aux_params = self.get_params()

        # force rebinding is typically used when one want to switch from
        # training to prediction phase.
        if force_rebind:
            self._reset_bind()

        if self.binded:
            self.logger.warning('Already binded, ignoring bind()')
            return

        assert shared_module is None, \
            'shared_module for MutableModule is not supported'
        self.for_training = for_training
        self.inputs_need_grad = inputs_need_grad
        self.binded = True
        # force_rebind is forwarded so that already bound sub-modules do not
        # silently ignore the request.
        self._curr_module.bind(data_shapes, label_shapes, for_training,
                               inputs_need_grad, force_rebind=force_rebind,
                               shared_module=None)
        _data_shape = data_shapes[0][1]
        print('_data_shape', _data_shape, label_shapes)
        _rows = _data_shape[0] * self._num_workers
        for _module in self._arcface_modules:
            _module.bind([('data', (_rows, self._emb_size))],
                         [('softmax_label', (_rows,))],
                         for_training, True, force_rebind=force_rebind,
                         shared_module=None)
        if self.params_initialized:
            self.set_params(arg_params, aux_params)

    def init_optimizer(self, kvstore='local', optimizer='sgd',
                       optimizer_params=(('learning_rate', 0.01),),
                       force_init=False):
        """Initialise the optimizer of the backbone and of every sub-module
        with the same arguments."""
        assert self.binded and self.params_initialized
        if self.optimizer_initialized and not force_init:
            self.logger.warning('optimizer already initialized, ignoring.')
            return

        self._curr_module.init_optimizer(kvstore, optimizer, optimizer_params,
                                         force_init=force_init)
        for _module in self._arcface_modules:
            _module.init_optimizer(kvstore, optimizer, optimizer_params,
                                   force_init=force_init)
        self.optimizer_initialized = True

    def kv_push(self, key, value):
        """Push ``value`` under ``key`` to the distributed kvstore.

        The key is initialised with zeros on first use.  Requires the kvstore
        created in ``__init__`` when ``_usekv`` is enabled.
        """
        if key not in self._kvinit:
            self._distkv.init(key, nd.zeros_like(value))
            self._kvinit[key] = 1
        self._distkv.push(key, value)

    def forward(self, data_batch, is_train=None):
        """Run the backbone and, when training, every sub-module.

        When ``is_train`` is truthy the iteration counter is incremented, the
        backbone label output is stored on the CPU as ``self.global_label``
        and every sub-module is run on the backbone's first output with the
        labels shifted by that sub-module's class offset.
        """
        assert self.binded and self.params_initialized
        self._curr_module.forward(data_batch, is_train=is_train)
        if is_train:
            self._iter += 1
            fc1, label = self._curr_module.get_outputs(
                merge_multi_context=True)
            global_fc1 = fc1
            self.global_label = label.as_in_context(self._ctx_cpu)

            for i, _module in enumerate(self._arcface_modules):
                _label = self.global_label - self._ctx_class_start[i]
                db_global_fc1 = io.DataBatch([global_fc1], [_label])
                _module.forward(db_global_fc1)  # fc7 with margin

    def get_ndarray(self, context, name, shape):
        """Return the cached zero-initialised NDArray for (name, context).

        The array is created on first request and re-used (contents
        untouched) afterwards.
        """
        key = "%s_%s" % (name, context)
        if key not in self._nd_cache:
            v = nd.zeros(shape=shape, ctx=context)
            self._nd_cache[key] = v
        else:
            v = self._nd_cache[key]
        return v

    def get_ndarray2(self, context, name, arr):
        """Copy ``arr`` into the cached NDArray for (name, context) and
        return that cached array."""
        key = "%s_%s" % (name, context)
        if key not in self._nd_cache:
            v = nd.zeros(shape=arr.shape, ctx=context)
            self._nd_cache[key] = v
        else:
            v = self._nd_cache[key]
        arr.copyto(v)
        return v

    def backward(self, out_grads=None):
        """Back-propagate a softmax loss computed across all sub-modules.

        The first output of every sub-module is treated as a slice of the
        logits.  The row-wise maximum and the sum of exponentials are reduced
        over all sub-modules, each slice is normalised, the one-hot label is
        subtracted and the result is used as the output gradient of that
        sub-module.  Every further sub-module output receives a gradient of
        ones.  The summed input gradients of the sub-modules are finally
        back-propagated through the backbone.

        ``out_grads`` is ignored.
        """
        assert self.binded and self.params_initialized
        tmp_ctx = self._ctx_single_gpu
        fc7_outs = []
        ctx_fc7_max = self.get_ndarray(
            tmp_ctx, 'ctx_fc7_max', (self._batch_size, len(self._context)))
        arcface_module_outputs = []
        for i, _module in enumerate(self._arcface_modules):
            out = _module.get_outputs(merge_multi_context=True)
            arcface_module_outputs.append(out)
            _fc7 = out[0]
            fc7_outs.append(_fc7)
            _fc7_max = nd.max(_fc7, axis=1).as_in_context(tmp_ctx)
            ctx_fc7_max[:, i] = _fc7_max

        # Row-wise maximum over all slices, for a numerically stable exp().
        local_fc7_max = self.get_ndarray(
            tmp_ctx, 'local_fc7_max', (self._batch_size, 1))
        nd.max(ctx_fc7_max, axis=1, keepdims=True, out=local_fc7_max)
        global_fc7_max = local_fc7_max

        local_fc7_sum = self.get_ndarray(
            tmp_ctx, 'local_fc7_sum', (self._batch_size, 1))
        local_fc7_sum[:, :] = 0.0
        for i, _module in enumerate(self._arcface_modules):
            _max = self.get_ndarray2(
                fc7_outs[i].context, 'fc7_max', global_fc7_max)
            fc7_outs[i] = nd.broadcast_sub(fc7_outs[i], _max)
            fc7_outs[i] = nd.exp(fc7_outs[i])
            _sum = nd.sum(fc7_outs[i], axis=1,
                          keepdims=True).as_in_context(tmp_ctx)
            local_fc7_sum += _sum
        global_fc7_sum = local_fc7_sum

        if self._iter % self._verbose == 0:
            # argmax is unaffected by the missing normalisation, so the
            # accuracy can be reported from the un-normalised values.
            _ctx = self._ctx_cpu
            _probs = []
            for i, _module in enumerate(self._arcface_modules):
                _prob = self.get_ndarray2(
                    _ctx, '_fc7_prob_%d' % i, fc7_outs[i])
                _probs.append(_prob)
            fc7_prob = self.get_ndarray(
                _ctx, 'test_fc7_prob',
                (self._batch_size,
                 self._ctx_num_classes * len(self._context)))
            nd.concat(*_probs, dim=1, out=fc7_prob)
            fc7_pred = nd.argmax(fc7_prob, axis=1)
            local_label = self.global_label - self._local_class_start
            _pred = nd.equal(fc7_pred, local_label)
            print('{fc7_acc}', self._iter, nd.mean(_pred).asnumpy()[0])

        fc1_grad_ctx = self._ctx_single_gpu
        local_fc1_grad = self.get_ndarray(
            fc1_grad_ctx, 'local_fc1_grad',
            (self._batch_size, self._emb_size))
        local_fc1_grad[:, :] = 0.0
        total_eloss = []
        celoss_verbose = 1000
        if self._iter % celoss_verbose == 0:
            fc7_celoss = self.get_ndarray(
                tmp_ctx, 'test_fc7_celoss', (self._batch_size,))
            fc7_celoss[:] = 0.0

        for i, _module in enumerate(self._arcface_modules):
            _sum = self.get_ndarray2(
                fc7_outs[i].context, 'fc7_sum', global_fc7_sum)
            fc7_outs[i] = nd.broadcast_div(fc7_outs[i], _sum)
            _label = self.global_label - self._ctx_class_start[i]
            _label = self.get_ndarray2(fc7_outs[i].context, 'label', _label)
            onehot_label = self.get_ndarray(
                fc7_outs[i].context, 'label_onehot',
                (self._batch_size, self._ctx_num_classes))
            nd.one_hot(_label, depth=self._ctx_num_classes,
                       on_value=1.0, off_value=0.0, out=onehot_label)

            if self._iter % celoss_verbose == 0:
                # Probability assigned to the true class, if it lives here.
                _ce_loss = fc7_outs[i] * onehot_label
                _ce_loss = nd.sum(_ce_loss, axis=1)
                fc7_celoss += _ce_loss.as_in_context(tmp_ctx)
            # softmax - onehot: gradient of the cross-entropy w.r.t. logits.
            fc7_outs[i] -= onehot_label

            out = arcface_module_outputs[i]
            out_grads = [fc7_outs[i]]
            for j in range(1, len(out)):
                eloss = out[j]
                egrad_shape = eloss.shape
                egrad = self.get_ndarray(
                    fc7_outs[i].context, 'egrad%d' % j, egrad_shape)
                egrad[:][:] = 1.0
                out_grads.append(egrad)
                if self._iter % self._verbose == 0:
                    total_eloss.append(np.mean(eloss.asnumpy()))

            _module.backward(out_grads=out_grads)
            ctx_fc1_grad = self.get_ndarray2(
                fc1_grad_ctx, 'ctx_fc1_grad_%d' % i,
                _module.get_input_grads()[0])
            local_fc1_grad += ctx_fc1_grad

        if self._iter % self._verbose == 0 and len(total_eloss) > 0:
            print('{eloss}', self._iter, np.mean(total_eloss))
        if self._iter % celoss_verbose == 0:
            ce_loss = nd.log(fc7_celoss) * -1.0
            ce_loss = nd.mean(ce_loss)
            # Index the 1-element array, as the '{fc7_acc}' print does.
            print('CELOSS,%d,%f' % (self._iter, ce_loss.asnumpy()[0]))

        global_fc1_grad = local_fc1_grad
        self._curr_module.backward(out_grads=[global_fc1_grad])

    def update(self):
        """Apply one optimizer step to the backbone and every sub-module,
        then wait for all pending operations."""
        assert self.binded and self.params_initialized \
            and self.optimizer_initialized
        self._curr_module.update()
        for _module in self._arcface_modules:
            _module.update()
        mx.nd.waitall()

    def get_outputs(self, merge_multi_context=True):
        """Return the outputs of the backbone module."""
        assert self.binded and self.params_initialized
        return self._curr_module.get_outputs(
            merge_multi_context=merge_multi_context)

    def get_input_grads(self, merge_multi_context=True):
        """Return the input gradients of the backbone module."""
        assert self.binded and self.params_initialized \
            and self.inputs_need_grad
        return self._curr_module.get_input_grads(
            merge_multi_context=merge_multi_context)

    def update_metric(self, eval_metric, labels):
        """Do nothing beyond checking state; metrics are not updated."""
        assert self.binded and self.params_initialized

    def install_monitor(self, mon):
        """Install monitor on the executors of the backbone module."""
        assert self.binded
        self._curr_module.install_monitor(mon)

    def forward_backward(self, data_batch):
        """A convenient function that calls both ``forward`` and ``backward``."""
        self.forward(data_batch, is_train=True)  # get fc1 and partial fc7
        self.backward()

    def fit(self, train_data, eval_data=None, eval_metric='acc',
            epoch_end_callback=None, batch_end_callback=None, kvstore='local',
            optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
            eval_end_callback=None,
            eval_batch_end_callback=None, initializer=Uniform(0.01),
            arg_params=None, aux_params=None, allow_missing=False,
            force_rebind=False, force_init=False, begin_epoch=0,
            num_epoch=None, validation_metric=None, monitor=None,
            sparse_row_id_fn=None):
        """Train the module.

        Binds the module to the shapes provided by ``train_data``,
        initialises parameters and optimizer, then for every epoch in
        ``range(begin_epoch, num_epoch)`` runs ``forward_backward`` and
        ``update`` on each batch.

        Parameters
        ----------
        train_data : DataIter
            Training iterator; reset at the end of every epoch.
        batch_end_callback : callable, optional
            Called after every batch with a ``BatchEndParam`` whose
            ``eval_metric`` is ``None`` and whose ``locals`` is the local
            namespace of this method.
        kvstore, optimizer, optimizer_params
            Forwarded to ``init_optimizer``.
        initializer, allow_missing, force_init
            Forwarded to ``init_params``.
        arg_params, aux_params
            Must both be ``None``.
        force_rebind : bool
            Forwarded to ``bind``.
        begin_epoch : int
            First epoch index.
        num_epoch : int
            End (exclusive) of the epoch range; required.
        monitor : Monitor, optional
            Installed after binding and ticked around every batch.
        sparse_row_id_fn : callable, optional
            Forwarded to ``prepare`` for the pre-fetched batch.

        ``eval_data``, ``epoch_end_callback``, ``eval_end_callback`` and
        ``eval_batch_end_callback`` are accepted for interface compatibility
        but not used; no evaluation is run and no metric is updated.
        """
        assert num_epoch is not None, 'please specify number of epochs'
        assert arg_params is None and aux_params is None

        self.bind(data_shapes=train_data.provide_data,
                  label_shapes=train_data.provide_label,
                  for_training=True, force_rebind=force_rebind)
        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(initializer=initializer, arg_params=arg_params,
                         aux_params=aux_params, allow_missing=allow_missing,
                         force_init=force_init)
        self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                            optimizer_params=optimizer_params)

        # NOTE: local names in this method are part of its behaviour because
        # locals() is handed to batch_end_callback; do not rename them.
        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)
        epoch_eval_metric = copy.deepcopy(eval_metric)

        ######################################################################
        # training loop
        ######################################################################
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            epoch_eval_metric.reset()
            nbatch = 0
            data_iter = iter(train_data)
            end_of_batch = False
            next_data_batch = next(data_iter)
            while not end_of_batch:
                data_batch = next_data_batch
                if monitor is not None:
                    monitor.tic()
                self.forward_backward(data_batch)
                self.update()
                assert not isinstance(data_batch, list)

                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                    self.prepare(next_data_batch,
                                 sparse_row_id_fn=sparse_row_id_fn)
                except StopIteration:
                    end_of_batch = True

                if monitor is not None:
                    monitor.toc_print()

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=None,
                                                     locals=locals())
                    batch_end_callback(batch_end_params)
                nbatch += 1

            # one epoch of training is finished
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            # end of 1 epoch, reset the data-iter for another epoch
            train_data.reset()
class MemoryModuleGlint(object):
    """Distributed trainer: a per-rank backbone plus per-head memory banks.

    Each rank runs the backbone on its own mini-batch.  Features and labels
    of all ranks are gathered through ``hvd`` all-reduce (presumably
    ``horovod.mxnet`` -- confirm against the file's imports).  For every
    head, class weights are sampled from ``memory_bank_list[head_id]``, a
    softmax loss is computed across ranks, the sampled weights are updated
    with ``memory_optimizer`` and the feature gradient is sent back through
    the backbone.
    """

    def __init__(self, symbol, bn_symbol, batch_size, fc7_model, size, rank,
                 local_rank, memory_bank_list, memory_optimizer,
                 backbone_grad_rescale, memory_lr_scale_list,
                 embedding_size=512, head_num=1, logger=logging):
        """Store the configuration and build the (unbound) backbone module.

        Parameters
        ----------
        symbol : backbone symbol; wrapped in a ``Module`` on
            ``mx.gpu(local_rank)`` with inputs ``'data'``/``'softmax_label'``.
        bn_symbol : accepted for interface compatibility; not used.
        batch_size : per-rank batch size.
        fc7_model : object whose ``forward(data, weight, mapping_label=...,
            depth=...)`` returns ``(fc7, one_hot)``.
        size, rank, local_rank : number of ranks, this rank's index and its
            local GPU index.
        memory_bank_list : list with one memory bank per head.
        memory_optimizer : provides ``lr_scheduler(num_update)`` and
            ``update(weight=..., grad=..., state=..., learning_rate=...)``.
        backbone_grad_rescale : stored; not used by the visible code.
        memory_lr_scale_list : per-head multiplier of the memory-bank
            learning rate.
        embedding_size : feature width per head.
        head_num : number of heads.
        logger : logger used for the backbone module and warnings.
        """
        # configure horovod
        self.memory_lr_scale_list = memory_lr_scale_list
        self.size = size
        self.rank = rank
        self.local_rank = local_rank
        self.gpu = mx.gpu(self.local_rank)
        self.cpu = mx.cpu()  # `device_id` is not needed for CPU.
        self.nd_cache = {}
        self.embedding_size = embedding_size
        self.batch_size = batch_size
        self.num_update = 0
        self.batch_end_param = namedtuple(
            'batch_end_param',
            ['loss_list', 'num_epoch_list', 'epoch', 'num_update'])

        self.symbol = symbol
        self.logger = logger
        self.backbone_module = Module(self.symbol, ['data'],
                                      ['softmax_label'],
                                      logger=self.logger, context=self.gpu)
        self.head_num = head_num
        self.memory_bank_list = memory_bank_list
        self.memory_optimizer = memory_optimizer
        self.memory_lr = None
        self.loss_cache = None
        self.grad_cache = None

        assert isinstance(self.memory_bank_list, list)

        # init
        self.fc7_model = fc7_model

        # fp16
        self.backbone_grad_rescale = backbone_grad_rescale

        self.binded = False
        self.for_training = False
        self.inputs_need_grad = False
        self.params_initialized = False
        self.optimizer_initialized = False
        self._total_exec_bytes = 0

        self.global_label = None

    def forward_backward(self, data_batch):
        """Run ``forward`` in training mode followed by ``backward_all``."""
        total_feature, total_label = self.forward(data_batch, is_train=True)
        self.backward_all(total_feature, total_label)

    @staticmethod
    def sync_aux_params(params):
        """Placeholder; does nothing."""
        pass

    @staticmethod
    def broadcast_parameters(params):
        """Return ``{name: hvd.broadcast(tensor, 0, name)}`` for ``params``,
        i.e. every tensor as broadcast from rank 0."""
        rank_0_dict = {}
        # Run broadcasts.
        for key, tensor in params.items():
            rank_0_dict[key] = hvd.broadcast(tensor, 0, key)
        return rank_0_dict

    @staticmethod
    def combine(data_batches):
        """Merge one batch per head into a single ``DataBatch``.

        Every batch is flattened per sample and the batches are concatenated
        along axis 1 before being reshaped back, so the samples of the heads
        end up interleaved: sample 0 of every head, then sample 1 of every
        head, and so on.  Labels are interleaved the same way.  All batches
        must have the same first dimension.
        """
        assert isinstance(data_batches, list), "data_batches must be a list."
        first = data_batches[0]
        data_shape = first.data[0].shape
        total_data = [first.data[0].reshape(0, -1)]
        total_label = [first.label[0].reshape(0, 1)]
        for batch in data_batches[1:]:
            assert batch.data[0].shape[0] == first.data[0].shape[0]
            total_data.append(batch.data[0].reshape(0, -1))
            total_label.append(batch.label[0].reshape(0, 1))
        # shuffle
        total_data = mx.nd.concat(*total_data, dim=1)
        total_data = total_data.reshape(
            -1, data_shape[1], data_shape[2], data_shape[3])
        total_label = mx.nd.concat(*total_label, dim=1)
        total_label = total_label.reshape(-1)
        return mx.io.DataBatch([total_data], [total_label])

    def fit(self, train_data_list, optimizer_params, batch_end_callback=None,
            kvstore='local', initializer=Uniform(0.01), arg_params=None,
            aux_params=None, allow_missing=False, force_rebind=False,
            force_init=False, begin_epoch=0, num_epoch=None):
        """Train on one data iterator per head.

        Binds and initialises the backbone, replaces its parameters with the
        ones broadcast from rank 0, then trains for the epochs in
        ``range(begin_epoch, num_epoch)``.  Within an epoch an exhausted
        iterator is reset and restarted until ``head_num`` exhaustion events
        have been counted; that event ends the epoch, and all iterators are
        reset at the start of the next one.

        Parameters
        ----------
        train_data_list : list of DataIter, one per head.
        optimizer_params : dict forwarded to ``init_optimizer``.
        batch_end_callback : callable, optional
            Called after every batch with a ``batch_end_param`` tuple
            (``loss_list``, ``num_epoch_list``, ``epoch``, ``num_update``).
        initializer, allow_missing, force_init
            Forwarded to ``init_params``.
        arg_params, aux_params
            Must both be ``None``.
        begin_epoch, num_epoch : epoch range; ``num_epoch`` is required.

        ``kvstore`` and ``force_rebind`` are accepted but not used.
        """
        assert num_epoch is not None, 'please specify number of epochs'
        assert arg_params is None and aux_params is None

        provide_data_list = [td.provide_data for td in train_data_list]
        provide_label_list = [td.provide_label for td in train_data_list]

        self.bind(data_shapes_list=provide_data_list,
                  label_shapes_list=provide_label_list,
                  for_training=True)
        self.init_params(initializer=initializer, arg_params=arg_params,
                         aux_params=aux_params, allow_missing=allow_missing,
                         force_init=force_init)
        self.init_optimizer(optimizer_params=optimizer_params)

        # Start every rank from rank 0's parameters.
        _arg_params, _aux_params = self.backbone_module.get_params()
        _arg_params_rank_0 = self.broadcast_parameters(_arg_params)
        _aux_params_rank_0 = self.broadcast_parameters(_aux_params)
        self.backbone_module.set_params(_arg_params_rank_0,
                                        _aux_params_rank_0)

        data_end_id = 0
        ######################################################################
        # training loop
        ######################################################################
        num_epoch_list = [0] * self.head_num
        for epoch in range(begin_epoch, num_epoch):
            nbatch = 0
            end_of_batch = False
            data_iter_list = []
            for i in range(self.head_num):
                train_data_list[i].reset()
                data_iter_list.append(iter(train_data_list[i]))
            next_data_batch_list = [
                next(data_iter_list[i]) for i in range(self.head_num)]

            while not end_of_batch:
                data_batch_list = next_data_batch_list
                data_batch = self.combine(data_batch_list)

                self.forward_backward(data_batch)
                self.update()
                assert not isinstance(data_batch, list)

                for i in range(self.head_num):
                    try:
                        next_data_batch_list[i] = next(data_iter_list[i])
                    except StopIteration:
                        num_epoch_list[i] += 1
                        data_end_id += 1
                        if data_end_id != self.head_num:
                            # Restart this head and keep the epoch going.
                            train_data_list[i].reset()
                            data_iter_list[i] = iter(train_data_list[i])
                            next_data_batch_list[i] = next(data_iter_list[i])
                            logging.info('reset dataset_%d', i)
                        else:
                            # head_num exhaustion events: the epoch is over.
                            # Without this the loop never terminated and the
                            # next step re-used a stale batch.
                            data_end_id = 0
                            end_of_batch = True
                            break
                    else:
                        self.prepare(next_data_batch_list[i],
                                     sparse_row_id_fn=None)

                if batch_end_callback is not None:
                    batch_end_params = self.batch_end_param(
                        loss_list=self.loss_cache,
                        epoch=epoch,
                        num_update=self.num_update,
                        num_epoch_list=num_epoch_list
                    )
                    batch_end_callback(batch_end_params)

                nbatch += 1

    def get_params(self):
        """Return shallow copies of the backbone's (arg_params, aux_params)."""
        _g, _x = self.backbone_module.get_params()
        return _g.copy(), _x.copy()

    def get_export_params(self):
        """Same as ``get_params`` but requires a bound, initialised module."""
        assert self.binded and self.params_initialized
        _g, _x = self.backbone_module.get_params()
        return _g.copy(), _x.copy()

    def get_ndarray2(self, context, name, arr):
        """Copy ``arr`` into the cached NDArray for (name, context) and
        return that cached array (created with ``arr``'s shape and dtype on
        first use)."""
        key = "%s_%s" % (name, context)
        if key not in self.nd_cache:
            v = nd.zeros(shape=arr.shape, ctx=context, dtype=arr.dtype)
            self.nd_cache[key] = v
        else:
            v = self.nd_cache[key]
        arr.copyto(v)
        return v

    def get_ndarray(self, context, name, shape, dtype='float32'):
        """Return the cached NDArray for (name, context), creating a zero
        array of ``shape``/``dtype`` on first use.  Contents of an existing
        array are left untouched."""
        key = "%s_%s" % (name, context)
        if key not in self.nd_cache:
            v = nd.zeros(shape=shape, ctx=context, dtype=dtype)
            self.nd_cache[key] = v
        else:
            v = self.nd_cache[key]
        return v

    def init_params(self, initializer=Uniform(0.01), arg_params=None,
                    aux_params=None, allow_missing=False, force_init=False,
                    allow_extra=False):
        """Initialise the backbone parameters.

        All arguments are forwarded to ``backbone_module.init_params``.  The
        backbone is initialised exactly once per call: a second pass with
        ``arg_params=None`` would, under ``force_init=True``, overwrite the
        values that were just loaded.
        """
        assert self.binded
        # backbone
        self.backbone_module.init_params(
            initializer=initializer, arg_params=arg_params,
            aux_params=aux_params, allow_missing=allow_missing,
            force_init=force_init, allow_extra=allow_extra)
        self.params_initialized = True

    def set_params(self, arg_params, aux_params, allow_missing=False,
                   force_init=True, allow_extra=False):
        """Load ``arg_params``/``aux_params`` into the backbone through
        ``init_params`` with no initializer."""
        self.init_params(
            initializer=None, arg_params=arg_params, aux_params=aux_params,
            allow_missing=allow_missing, force_init=force_init,
            allow_extra=allow_extra)

    def save_params(self, fname):
        """Save the backbone parameters to ``fname``.

        Keys are prefixed with ``'arg:'`` or ``'aux:'`` and every array is
        moved to the CPU first.
        """
        arg_params, aux_params = self.get_params()
        save_dict = {('arg:%s' % k): v.as_in_context(self.cpu)
                     for k, v in arg_params.items()}
        save_dict.update({('aux:%s' % k): v.as_in_context(self.cpu)
                          for k, v in aux_params.items()})
        nd.save(fname, save_dict)

    def load_params(self, fname):
        """Load parameters written by ``save_params`` and apply them.

        Raises
        ------
        ValueError
            If a key is prefixed with neither ``'arg'`` nor ``'aux'``.
        """
        save_dict = nd.load(fname)
        arg_params = {}
        aux_params = {}
        for k, value in save_dict.items():
            arg_type, name = k.split(':', 1)
            if arg_type == 'arg':
                arg_params[name] = value
            elif arg_type == 'aux':
                aux_params[name] = value
            else:
                raise ValueError("Invalid param file " + fname)
        self.set_params(arg_params, aux_params)

    def get_states(self, merge_multi_context=True):
        """Not supported."""
        raise NotImplementedError

    def set_states(self, states=None, value=None):
        """Not supported."""
        raise NotImplementedError

    def prepare(self, data_batch, sparse_row_id_fn=None):
        """Hook called on a pre-fetched batch; only warns when a
        ``sparse_row_id_fn`` is supplied, since it is never invoked."""
        if sparse_row_id_fn is not None:
            warnings.warn(UserWarning(
                "sparse_row_id_fn is not invoked for BaseModule."))

    def allgather(self, tensor, name, shape, dtype, context):
        """Implement in-place AllGather using AllReduce.

        ``tensor`` is written into rows
        ``[rank * batch_size, (rank + 1) * batch_size)`` of a zeroed cached
        buffer of ``shape``, which is then summed over all ranks in place.

        Returns the cached buffer; it is overwritten by the next call with
        the same ``name`` and ``context``.
        """
        assert isinstance(tensor, nd.NDArray), type(tensor)
        assert isinstance(name, str), type(name)
        assert isinstance(shape, tuple), type(shape)
        assert isinstance(dtype, str), type(dtype)
        assert isinstance(context, mx.context.Context), type(context)
        total_tensor = self.get_ndarray(
            context=context, name=name, shape=shape, dtype=dtype)
        total_tensor[:] = 0  # reset array before all-reduce is very important
        start = self.rank * self.batch_size
        total_tensor[start:start + self.batch_size] = tensor
        hvd.allreduce_(total_tensor, average=False)  # all-reduce in-place
        return total_tensor

    def forward(self, data_batch, is_train=None):
        """Run the backbone; when training also gather features and labels.

        Returns
        -------
        ``(total_features, total_labels)`` gathered over all ranks when
        ``is_train`` is truthy (``num_update`` is incremented in that case),
        otherwise ``None``.
        """
        assert self.binded and self.params_initialized
        self.backbone_module.forward(data_batch, is_train=is_train)
        if not is_train:
            return None

        self.num_update += 1
        fc1 = self.backbone_module.get_outputs()[0]
        label = data_batch.label[0]

        total_features = self.allgather(
            tensor=fc1,
            name='total_feature',
            shape=(self.batch_size * self.size, self.embedding_size),
            dtype='float32',
            context=self.gpu
        )
        total_labels = self.allgather(
            tensor=label,
            name='total_label',
            shape=(self.batch_size * self.size,),
            dtype='int32',
            context=self.cpu
        )
        return total_features, total_labels

    def backward_all(self, total_feature, total_label):
        """Compute loss and gradients for every head and back-propagate.

        The gathered features and labels are reshaped so that each row holds
        the ``head_num`` entries of one interleaved group (the inverse of
        ``combine``).  Per-head gradients are collected in ``grad_cache``,
        summed over ranks, and the rows belonging to this rank are
        back-propagated through the backbone.  Per-head losses are stored in
        ``loss_cache``.
        """
        # get memory bank learning rate
        self.memory_lr = self.memory_optimizer.lr_scheduler(self.num_update)

        # reverse shuffle bn
        total_feature = total_feature.reshape(
            -1, self.embedding_size * self.head_num)
        # global_label
        total_label = total_label.reshape(-1, self.head_num)

        self.grad_cache = self.get_ndarray(
            self.gpu, 'grad_cache', total_feature.shape)
        self.loss_cache = self.get_ndarray(
            self.gpu, 'loss_cache', [self.head_num])

        self.grad_cache[:] = 0
        self.loss_cache[:] = 0

        for head_id in range(self.head_num):
            col_begin = head_id * self.embedding_size
            col_end = col_begin + self.embedding_size
            _fc1_one_head = total_feature[:, col_begin:col_end]
            _label_one_head = total_label[:, head_id]

            grad, loss = self.backward(
                head_id, _fc1_one_head, _label_one_head)
            self.grad_cache[:, col_begin:col_end] = grad
            self.loss_cache[head_id] = loss

        total_feature_grad = self.grad_cache.reshape(-1, self.embedding_size)
        total_feature_grad = hvd.allreduce(total_feature_grad, average=False)

        row_begin = self.batch_size * self.rank
        fc1_grad = total_feature_grad[row_begin:row_begin + self.batch_size]
        self.backbone_module.backward(out_grads=[fc1_grad])

    def backward(self, head_id, fc1, label):
        """Loss, feature gradient and memory-bank update for one head.

        Parameters
        ----------
        head_id : index into ``memory_bank_list`` / ``memory_lr_scale_list``.
        fc1 : features of this head gathered over all ranks.
        label : labels of this head gathered over all ranks.

        Returns
        -------
        ``(grad, loss)``: the gradient w.r.t. ``fc1`` and the mean
        cross-entropy computed over all ranks.
        """
        memory_bank = self.memory_bank_list[head_id]
        this_rank_classes = int(memory_bank.num_sample)
        local_index, unique_sorted_global_label = memory_bank.sample(label)

        # Get local index: map every sampled class that occurs in the batch
        # to its position among the sampled classes, offset by this rank.
        _mapping_dict = {}
        local_sampled_class = local_index + self.rank * memory_bank.num_local
        global_label_set = set(unique_sorted_global_label)
        for idx, absolute_label in enumerate(local_sampled_class):
            if absolute_label in global_label_set:
                _mapping_dict[absolute_label] = \
                    idx + self.rank * memory_bank.num_sample

        # Labels whose class was not sampled on this rank become -1.
        mapping_label = [_mapping_dict.get(absolute_label, -1)
                         for absolute_label in label.asnumpy()]
        mapping_label = nd.array(mapping_label, dtype=np.int32)

        # Get weight
        local_index = nd.array(local_index)
        local_index = self.get_ndarray2(
            self.gpu, "local_index_%d" % head_id, local_index)
        sample_weight, sample_weight_mom = memory_bank.get(local_index)

        # Sync to gpu (the same copy is needed wherever the bank lives).
        _data = self.get_ndarray2(
            self.gpu, "data_%d_%d" % (self.rank, head_id), fc1)
        _weight = self.get_ndarray2(
            self.gpu, 'weight_%d_%d' % (self.rank, head_id), sample_weight)
        _weight_mom = self.get_ndarray2(
            self.gpu, 'weight_mom_%d_%d' % (self.rank, head_id),
            sample_weight_mom)

        # Attach grad
        _data.attach_grad()
        _weight.attach_grad()

        # Convert label
        _label = self.get_ndarray2(
            self.gpu, 'mapping_label_%d_%d' % (self.rank, head_id),
            mapping_label)
        _label = _label - int(self.rank * memory_bank.num_sample)
        _fc7, _one_hot = self.fc7_model.forward(
            _data, _weight, mapping_label=_label, depth=this_rank_classes)

        # Sync max: every rank fills its own column, the sum gathers them.
        max_fc7 = nd.max(_fc7, axis=1, keepdims=True)
        max_fc7 = nd.reshape(max_fc7, -1)

        total_max_fc7 = self.get_ndarray(
            context=self.gpu, name='total_max_fc7_%d' % head_id,
            shape=(max_fc7.shape[0], self.size), dtype='float32')
        total_max_fc7[:] = 0
        total_max_fc7[:, self.rank] = max_fc7
        hvd.allreduce_(total_max_fc7, average=False)

        global_max_fc7 = self.get_ndarray(
            context=self.gpu, name='global_max_fc7_%d' % head_id,
            shape=(max_fc7.shape[0], 1), dtype='float32')
        nd.max(total_max_fc7, axis=1, keepdims=True, out=global_max_fc7)

        # Calculate prob
        _fc7_grad = nd.broadcast_sub(_fc7, global_max_fc7)
        _fc7_grad = nd.exp(_fc7_grad)

        # Calculate sum
        sum_fc7 = nd.sum(_fc7_grad, axis=1, keepdims=True)
        global_sum_fc7 = hvd.allreduce(sum_fc7, average=False)

        # Calculate grad
        _fc7_grad = nd.broadcast_div(_fc7_grad, global_sum_fc7)

        # Calculate loss
        tmp = _fc7_grad * _one_hot
        tmp = nd.sum(tmp, axis=1, keepdims=True)
        tmp = self.get_ndarray2(self.gpu, 'ctx_loss_%d' % head_id, tmp)
        tmp = hvd.allreduce(tmp, average=False)
        # 1e-30 keeps log() finite when the probability is exactly zero.
        global_loss = -nd.mean(nd.log(tmp + 1e-30))

        _fc7_grad = _fc7_grad - _one_hot

        # Backward
        _fc7.backward(out_grad=_fc7_grad)

        # Update center
        _weight_grad = _weight.grad
        self.memory_optimizer.update(
            weight=_weight, grad=_weight_grad, state=_weight_mom,
            learning_rate=self.memory_lr * self.memory_lr_scale_list[head_id])
        if memory_bank.gpu:
            memory_bank.set(
                index=local_index,
                updated_weight=_weight,
                updated_weight_mom=_weight_mom)
        else:
            # Bank lives on the CPU: copy the updated tensors back first.
            memory_bank.set(
                index=local_index,
                updated_weight=self.get_ndarray2(
                    mx.cpu(), "cpu_weight_%d_%d" % (self.rank, head_id),
                    _weight),
                updated_weight_mom=self.get_ndarray2(
                    mx.cpu(), "cpu_weight_mom_%d_%d" % (self.rank, head_id),
                    _weight_mom))
        return _data.grad, global_loss

    def get_outputs(self, merge_multi_context=True):
        """Return the outputs of the backbone module."""
        return self.backbone_module.get_outputs(
            merge_multi_context=merge_multi_context)

    def update(self):
        """Apply one optimizer step to the backbone and wait for all pending
        operations."""
        self.backbone_module.update()
        mx.nd.waitall()

    def bind(self, data_shapes_list=None, label_shapes_list=None,
             for_training=True, inputs_need_grad=False):
        """Bind the backbone to the shapes of the first head's iterator.

        Name and trailing dimensions are taken from the first descriptor of
        ``data_shapes_list[0]`` / ``label_shapes_list[0]``; the leading
        dimension is replaced by ``self.batch_size``.
        """
        assert data_shapes_list is not None and label_shapes_list is not None

        if self.binded:
            self.logger.warning('Already binded, ignoring bind()')
            return

        data_name = data_shapes_list[0][0][0]
        data_shapes = data_shapes_list[0][0][1]
        label_name = label_shapes_list[0][0][0]
        label_shapes = label_shapes_list[0][0][1]

        self.for_training = for_training
        self.inputs_need_grad = inputs_need_grad
        self.binded = True

        _backbone_data_shapes = [
            (data_name, (self.batch_size,) + data_shapes[1:])]
        _backbone_label_shapes = [
            (label_name, (self.batch_size,) + label_shapes[1:])]

        self.backbone_module.bind(
            data_shapes=_backbone_data_shapes,
            label_shapes=_backbone_label_shapes,
            for_training=for_training,
            inputs_need_grad=inputs_need_grad)

    def init_optimizer(self, optimizer_params, force_init=False):
        """Give the backbone an ``SGD(**optimizer_params)`` optimizer wrapped
        in ``DistributedOptimizer``."""
        assert self.binded and self.params_initialized
        if self.optimizer_initialized and not force_init:
            self.logger.warning('optimizer already initialized, ignoring.')
            return

        # backbone
        optimizer_backbone = DistributedOptimizer(SGD(**optimizer_params))
        self.backbone_module.init_optimizer(
            'local', optimizer_backbone, force_init=force_init)
        self.optimizer_initialized = True