def load_parameters(uuid_map=None,
                        parameters_file=None,
                        variable_constants_file=None,
                        mxnet_constants_file=None,
                        context=None,
                        dtype=None,
                        current_params=None):
        """
        Loads back a set of InferenceParameters from files.

        :param uuid_map: Optional mapping from the UUIDs stored in the files to the UUIDs of the
            new model. When None, the stored UUIDs are used unchanged.
        :type uuid_map: dict of {old_uuid: new_uuid} or None
        :param parameters_file: These are the parameters of the previous inference algorithm.
            These are in a {uuid: mx.nd.array} mapping.
        :type parameters_file: file saved down with mx.nd.save(), so a {uuid: mx.nd.array} mapping saved
            in a binary format.
        :param variable_constants_file: These are the constants in primitive format from the previous
            inference algorithm.
        :type variable_constants_file: json dict of {uuid: constant_primitive}
        :param mxnet_constants_file: These are the constants in mxnet format from the previous inference
            algorithm. These are in a {uuid: mx.nd.array} mapping.
        :type mxnet_constants_file: file saved down with mx.nd.save(), so a {uuid: mx.nd.array} mapping saved
            in a binary format.
        :param context: MXNet context passed to the new InferenceParameters.
        :param dtype: data type passed to the new InferenceParameters.
        :param current_params: Optional ParameterDict whose entries are copied into the new parameter
            dict before the loaded values are applied. Only used when ``parameters_file`` is given.
        :returns: the reconstructed inference parameters.
        :rtype: InferenceParameters
        """
        def with_uuid_map(item, uuid_map):
            # Translate an old UUID into the new model's UUID when a map is supplied.
            if uuid_map is not None:
                return uuid_map[item]
            else:
                return item

        ip = InferenceParameters(context=context, dtype=dtype)

        if parameters_file is not None:
            old_params = ndarray.load(parameters_file)
            mapped_params = {
                with_uuid_map(k, uuid_map): v
                for k, v in old_params.items()
            }

            new_paramdict = ParameterDict()
            if current_params is not None:
                new_paramdict.update(current_params)

            # Do this because we need to map the uuids to the new Model
            # before loading them into the ParamDict
            for name, mapped_param in mapped_params.items():
                new_paramdict[name]._load_init(mapped_param, ip.mxnet_context)
            ip._params = new_paramdict

        new_variable_constants = {}
        if variable_constants_file is not None:
            import json
            with open(variable_constants_file) as f:
                old_constants = json.load(f)
            new_variable_constants = {
                with_uuid_map(k, uuid_map): v
                for k, v in old_constants.items()
            }

        new_mxnet_constants = {}
        if mxnet_constants_file is not None:
            mxnet_constants = ndarray.load(mxnet_constants_file)
            # ndarray.load may return a list instead of a dict (e.g. when the
            # saved mapping was empty); only a dict can be re-keyed by UUID.
            if isinstance(mxnet_constants, dict):
                new_mxnet_constants = {
                    with_uuid_map(k, uuid_map): v
                    for k, v in mxnet_constants.items()
                }

        # Primitive constants first, then MXNet constants (the latter win on a clash).
        ip._constants = {}
        ip._constants.update(new_variable_constants)
        ip._constants.update(new_mxnet_constants)
        return ip
    def collect_internal_parameters(self):
        """
        Return the parameters of the MXNet Gluon block that have *not* been set a prior distribution.

        An entry of ``self._gluon_parameters`` is included when its variable's
        ``type`` is ``VariableType.PARAMETER``; the matching Parameter object is
        looked up by name in ``self.block.collect_params()``.

        :returns: the parameters of the MXNet Gluon block without a prior distribution.
        :rtype: MXNet.gluon.ParameterDict
        """
        result = ParameterDict()
        block_params = self.block.collect_params()
        chosen = {}
        for param_name, variable in self._gluon_parameters.items():
            if variable.type == VariableType.PARAMETER:
                chosen[param_name] = block_params[param_name]
        result.update(chosen)
        return result
Beispiel #3
0
 def __init__(self, input_channels, context, RF_in_units, conv_input_shape=(96,96), train_RF=False) -> None:
     """
     Build the loss, the RF mapper, the network and an Adam trainer.

     :param input_channels: passed as the second argument to ``Network``.
     :param context: MXNet context forwarded to ``Lossfun``.
     :param RF_in_units: input units forwarded to ``RFLayer``.
     :param conv_input_shape: shape forwarded to ``RFLayer`` (default ``(96, 96)``).
     :param train_RF: when True, the RF mapper's parameters are optimised too.
     """
     self._train_RF = train_RF
     self._lossfun = Lossfun(alpha=1, beta_vgg=100, beta_pix=1, context=context)
     self._rf_mapper = RFLayer(RF_in_units, conv_input_shape)
     self._network = Network(3, input_channels)

     trainable = ParameterDict()
     # The RF mapper is optimised only on request; the network always is.
     if train_RF:
         trainable.update(self._rf_mapper.collect_params())
     trainable.update(self._network.collect_params())

     adam_settings = {"beta1": 0.5, "learning_rate": 0.0002}
     self._trainer = Trainer(trainable, "adam", adam_settings)
    def load_parameters(uuid_map=None,
                        mxnet_parameters=None,
                        variable_constants=None,
                        mxnet_constants=None,
                        context=None, dtype=None,
                        current_params=None):
        """
        Loads back a set of InferenceParameters from in-memory dictionaries.

        Every mapping argument is optional; ``None`` means "nothing to load" for
        that part (previously it raised ``AttributeError``).

        :param uuid_map: Optional mapping from stored UUIDs to the new model's UUIDs.
            When None, the stored UUIDs are used unchanged.
        :type uuid_map: dict of {old_uuid: new_uuid} or None
        :param mxnet_parameters: These are the parameters of
                                     the previous inference algorithm.
            These are in a {uuid: mx.nd.array} mapping.
        :type mxnet_parameters: Dict of {uuid: mx.nd.array}
        :param variable_constants: These are the constants in
                                       primitive format from the previous
            inference algorithm.
        :type variable_constants: dict of {uuid: constant primitive}
        :param mxnet_constants: These are the constants in mxnet format
                                    from the previous inference algorithm.
            These are in a {uuid: mx.nd.array} mapping.
        :type mxnet_constants:  Dict of {uuid: mx.nd.array}
        :param context: MXNet context passed to the new InferenceParameters.
        :param dtype: data type passed to the new InferenceParameters.
        :param current_params: Optional ParameterDict copied into the new parameter
            dict before the loaded values are applied. Only used when
            ``mxnet_parameters`` is given.
        :returns: the reconstructed inference parameters.
        :rtype: InferenceParameters
        """
        def with_uuid_map(item, uuid_map):
            # Translate an old UUID into the new model's UUID when a map is supplied.
            if uuid_map is not None:
                return uuid_map[item]
            else:
                return item

        ip = InferenceParameters(context=context, dtype=dtype)

        if mxnet_parameters is not None:
            mapped_params = {with_uuid_map(k, uuid_map): v
                             for k, v in mxnet_parameters.items()}

            new_paramdict = ParameterDict()
            if current_params is not None:
                new_paramdict.update(current_params)

            # Do this because we need to map the uuids to the new Model
            # before loading them into the ParamDict
            for name, mapped_param in mapped_params.items():
                new_paramdict[name]._load_init(mapped_param, ip.mxnet_context)
            ip._params = new_paramdict

        new_variable_constants = {}
        if variable_constants is not None:
            new_variable_constants = {with_uuid_map(k, uuid_map): v
                                      for k, v in variable_constants.items()}

        new_mxnet_constants = {}
        if mxnet_constants is not None:
            new_mxnet_constants = {with_uuid_map(k, uuid_map): v
                                   for k, v in mxnet_constants.items()}

        # Primitive constants first, then MXNet constants (the latter win on a clash).
        ip._constants = {}
        ip._constants.update(new_variable_constants)
        ip._constants.update(new_mxnet_constants)
        return ip
 def _get_decoder(self):
     """
     Build the output projection.

     Returns a ``nn.HybridSequential`` holding a single ``nn.Dense`` layer with
     ``self._vocab_size`` units and ``flatten=False``. When ``self._tie_weights``
     is set, that layer reuses parameters: either a fresh ParameterDict (with the
     embedding's prefix) filled from ``self._shared_params``, or, when no shared
     params are given, ``self.embedding[0].params`` directly.
     """
     decoder = nn.HybridSequential()
     with decoder.name_scope():
         if not self._tie_weights:
             decoder.add(nn.Dense(self._vocab_size, flatten=False))
         elif self._shared_params is None:
             decoder.add(nn.Dense(self._vocab_size, flatten=False,
                                  params=self.embedding[0].params))
         else:
             # self.embedding[0].params does not contain the bias, which may
             # leave the decoder bias uninitialized. Build a new ParameterDict
             # under the embedding's prefix and copy every shared param into it.
             embedding_prefix = self.embedding[0].params.prefix
             tied_params = ParameterDict(embedding_prefix)
             tied_params.update(self._shared_params)
             decoder.add(nn.Dense(self._vocab_size, flatten=False,
                                  params=tied_params))
     return decoder
Beispiel #6
0
 def collect_pparams(self, select=None):
     """
     Collect ``pparams`` from this block and, recursively, from its children.

     :param select: optional regular expression. When given, only entries whose
         name matches it (``re.match``, i.e. anchored at the start of the name)
         are kept. The same filter is passed on to every child.
     :returns: a ParameterDict using ``self.pparams.prefix`` as its prefix.
     """
     self._check_container_with_block()
     ret = ParameterDict(self.pparams.prefix)
     if not select:
         ret.update(self.pparams)
     else:
         pattern = re.compile(select)
         ret.update({
             name: value
             for name, value in self.pparams.items() if pattern.match(name)
         })
     for cld in self._children.values():
         # Children without a collect_pparams method are skipped. Looking the
         # method up explicitly (rather than wrapping the call in
         # ``except AttributeError``) keeps real AttributeErrors raised inside a
         # child's collect_pparams from being silently swallowed.
         child_collect = getattr(cld, 'collect_pparams', None)
         if child_collect is None:
             continue
         ret.update(child_collect(select=select))
     return ret
Beispiel #7
0
class InferenceParameters(object):
    """
    The parameters and outcomes of an inference method.

    InferenceParameters is a pool of memory that contains a mapping from uuid to two types of memories
    (MXNet ParameterDict and Constants).

    :param constants: Specify a list of model variables as constants
    :type constants: {ModelComponent.uuid : mxnet.ndarray}
    :param dtype: data type for internal numerical representation
    :type dtype: {numpy.float64, numpy.float32, 'float64', 'float32'}
    :param context: The MXNet context
    :type context: {mxnet.cpu or mxnet.gpu}
    """
    def __init__(self, constants=None, dtype=None, context=None):
        self.dtype = dtype if dtype is not None else get_default_dtype()
        self.mxnet_context = context if context is not None else get_default_device()
        self._constants = {}
        self._var_ties = {}
        if constants is not None:
            # Keys may be ModelComponents or raw UUIDs; store everything by UUID.
            constant_uuids = {(k.uuid if isinstance(k, ModelComponent) else k): v
                              for k, v in constants.items()}
            self._constants.update(constant_uuids)
        self._params = ParameterDict()

    def update_constants(self, constants):
        """
        Update the constants.

        :param constants: The constants to be updated.
        :type constants: {Variable: float or MXNet NDArray}
        """
        self.constants.update({
            (k.uuid if isinstance(k, ModelComponent) else k): v
            for k, v in constants.items()
        })

    def initialize_params(self, graphs, observed_uuid):
        """
        Create and initialize the MXNet parameters needed by ``graphs``.

        For each graph this collects the Gluon parameters of GluonFunctionEvaluation
        functions, records the graph's constants, declares a Parameter for every
        parameter variable that is neither constant nor observed, and lets each
        module add its hidden parameters. The resulting ParameterDict is then
        initialized on ``self.mxnet_context``, replacing any previous one.

        :param graphs: a list of graphs in which the parameters will be optimized.
        :type graphs: a list of FactorGraph
        :param observed_uuid: Parameter Variables that are passed in directly as data, not to be inferred.
        :type observed_uuid: list, set
        """
        # __init__ always creates a ParameterDict, so a bare ``is not None`` check
        # warned on every call. Only warn when existing parameters are discarded.
        if self._params is not None and len(self._params.keys()) > 0:
            warnings.warn(
                "InferenceParameters has already been initialized.  The existing one will be overwritten."
            )

        self._params = ParameterDict()
        for g in graphs:
            # load in parameterdict from external gluon blocks.
            for f in g.functions.values():
                if isinstance(f, GluonFunctionEvaluation):
                    self._params.update(f.function.collect_gluon_parameters())

            for var in g.get_constants():
                self._constants[var.uuid] = var.constant

            excluded = set(self._constants.keys()).union(observed_uuid)
            for var in g.get_parameters(excluded=excluded,
                                        include_inherited=False):
                var_shape = realize_shape(var.shape, self._constants)
                init = initializer.Constant(var.initial_value_before_transformation) \
                    if var.initial_value is not None else None

                self._params.get(name=var.uuid,
                                 shape=var_shape,
                                 dtype=self.dtype,
                                 allow_deferred_init=True,
                                 init=init)
            for m in g.modules.values():
                m.initialize_hidden_parameters(self._params, excluded,
                                               self._constants)

        self._params.initialize(ctx=self.mxnet_context)

    def initialize_with_carryover_params(self, graphs, observed_uuid, var_ties,
                                         carryover_params):
        """
        :param graphs: a list of graphs in which the parameters will be optimized.
        :type graphs: a list of FactorGraph
        :param observed_uuid: Parameter Variables that are passed in directly as data, not to be inferred.
        :type observed_uuid: {UUID : mx.ndarray}
        :param var_ties: A dictionary of variable maps that are tied together and use the MXNet Parameter of the dict
        value's uuid.
        :type var_ties: { UUID to tie from : UUID to tie to }
        :param carryover_params: list of InferenceParameters containing the outcomes of previous inference algorithms.
        :type carryover_params: [InferenceParameters]
        """
        # TODO: var_ties is discarded at the moment.

        # Every UUID that may own a parameter: graph variables plus module hidden parameters.
        var_uuid = set()
        for g in graphs:
            var_uuid = var_uuid.union(set(g.variables.keys()))
            for m in g.modules.values():
                var_uuid = var_uuid.union(set(m.hidden_parameters))

        # Later carryover sets override earlier ones for the same UUID.
        carryover_pairs = {}
        for carryover in carryover_params:
            for uuid, v in carryover.param_dict.items():
                if uuid in var_uuid:
                    if uuid in carryover_pairs:
                        warnings.warn(
                            'The variable with UUID ' + uuid +
                            ' exists in multiple carryover parameter sets.')
                    carryover_pairs[uuid] = v

        # self._var_ties = var_ties.copy()
        # for g in graphs:
        #     # TODO: check the behavior of var_ties in graph
        #     self._var_ties.update(g.var_ties)
        # for v_uuid in self.constants:
        #     if v_uuid in self._var_ties:
        #         del self._var_ties[v_uuid]

        # Carried-over parameters are treated as observed so they are not re-created.
        observed_uuid = set(observed_uuid).union(carryover_pairs.keys())
        self.initialize_params(graphs, observed_uuid)

        # carryover_pairs = {
        #     to_var_uuid: carryover.param_dict[to_var_uuid]
        #     for from_var_uuid, to_var_uuid in self._var_ties.items()
        #     for carryover in carryover_params
        #     if to_var_uuid in carryover.param_dict}
        self._params.update(carryover_pairs)

    @property
    def param_dict(self):
        return self._params

    @property
    def constants(self):
        return self._constants

    @property
    def var_ties(self):
        return self._var_ties

    def __getitem__(self, key, ctx=None):
        if not isinstance(key, Variable):
            raise KeyError(
                "The access key of inference parameter needs to be Variable, but got "
                + str(type(key)) + ".")
        pkey = key.inherited_name if key.isInherited else key.uuid
        val = self._params.get(pkey).data(ctx)
        if key.transformation is not None:
            val = key.transformation.transform(val)
        return val

    def __setitem__(self, key, item):
        if not isinstance(key, Variable):
            raise KeyError(
                "The access key of inference parameter needs to be Variable, but get "
                + str(type(key)) + ".")

        if key.type == VariableType.PARAMETER:
            # Parameters are stored untransformed; undo the transformation first.
            if key.transformation is not None:
                item = key.transformation.inverseTransform(item)
            self._params.get(key.uuid).set_data(item)
        elif key.type == VariableType.CONSTANT:
            # NOTE(review): elsewhere in this class constants live in
            # self._constants, not in the ParameterDict; confirm this is intended.
            self._params.get(key.uuid)._value = item

    # Override contains so that it doesn't use the __getitem__ method.
    def __contains__(self, k):
        return k in self.__dict__

    @staticmethod
    def load_parameters(uuid_map=None,
                        parameters_file=None,
                        variable_constants_file=None,
                        mxnet_constants_file=None,
                        context=None,
                        dtype=None,
                        current_params=None):
        """
        Loads back a set of InferenceParameters from files.

        :param uuid_map: Optional mapping from stored UUIDs to the new model's UUIDs.
        When None, the stored UUIDs are used unchanged.
        :type uuid_map: dict of {old_uuid: new_uuid} or None
        :param parameters_file: These are the parameters of the previous inference algorithm.
        These are in a {uuid: mx.nd.array} mapping.
        :type parameters_file: file saved down with mx.nd.save(), so a {uuid: mx.nd.array} mapping saved
        in a binary format.
        :param mxnet_constants_file: These are the constants in mxnet format from the previous inference algorithm.
        These are in a {uuid: mx.nd.array} mapping.
        :type mxnet_constants_file: file saved down with mx.nd.save(), so a {uuid: mx.nd.array} mapping saved
        in a binary format.
        :param variable_constants_file: These are the constants in primitive format from the previous
        inference algorithm.
        :type variable_constants_file: json dict of {uuid: constant_primitive}
        :param context: MXNet context passed to the new InferenceParameters.
        :param dtype: data type passed to the new InferenceParameters.
        :param current_params: Optional ParameterDict copied into the new parameter dict before the
        loaded values are applied. Only used when ``parameters_file`` is given.
        :returns: the reconstructed inference parameters.
        :rtype: InferenceParameters
        """
        def with_uuid_map(item, uuid_map):
            if uuid_map is not None:
                return uuid_map[item]
            else:
                return item

        ip = InferenceParameters(context=context, dtype=dtype)

        if parameters_file is not None:
            old_params = ndarray.load(parameters_file)
            mapped_params = {
                with_uuid_map(k, uuid_map): v
                for k, v in old_params.items()
            }

            new_paramdict = ParameterDict()
            if current_params is not None:
                new_paramdict.update(current_params)

            # Do this because we need to map the uuids to the new Model
            # before loading them into the ParamDict
            for name, mapped_param in mapped_params.items():
                new_paramdict[name]._load_init(mapped_param, ip.mxnet_context)
            ip._params = new_paramdict

        new_mxnet_constants = {}
        new_variable_constants = {}
        if variable_constants_file is not None:
            import json
            with open(variable_constants_file) as f:
                old_constants = json.load(f)
                new_variable_constants = {
                    with_uuid_map(k, uuid_map): v
                    for k, v in old_constants.items()
                }
        if mxnet_constants_file is not None:
            mxnet_constants = ndarray.load(mxnet_constants_file)
            # ndarray.load may hand back a list (e.g. for an empty saved mapping).
            if isinstance(mxnet_constants, dict):
                new_mxnet_constants = {
                    with_uuid_map(k, uuid_map): v
                    for k, v in mxnet_constants.items()
                }
            else:
                new_mxnet_constants = {}
        ip._constants = {}
        ip._constants.update(new_variable_constants)
        ip._constants.update(new_mxnet_constants)
        return ip

    def save(self, prefix):
        """
        Saves the parameters and constants down as maps from {uuid : value}.

        Three files are written: prefix+["_params.json", "_variable_constants.json", "_mxnet_constants.json"].
        Parameters and NDArray-valued constants are written with ``ndarray.save`` (a binary
        format, despite the ``.json`` suffix); all remaining constants are written with ``json.dump``.

        :param prefix: The directory and any appending tag for the files to save this Inference as.
        :type prefix: str , ex. "../saved_inferences/experiment_1"
        """
        param_file = prefix + "_params.json"
        variable_constants_file = prefix + "_variable_constants.json"
        mxnet_constants_file = prefix + "_mxnet_constants.json"
        to_save = {key: value._reduce() for key, value in self._params.items()}
        ndarray.save(param_file, to_save)

        mxnet_constants = {
            uuid: value
            for uuid, value in self._constants.items()
            if isinstance(value, mx.ndarray.ndarray.NDArray)
        }
        ndarray.save(mxnet_constants_file, mxnet_constants)

        variable_constants = {
            uuid: value
            for uuid, value in self._constants.items()
            if uuid not in mxnet_constants
        }
        import json
        with open(variable_constants_file, 'w') as f:
            json.dump(variable_constants, f, ensure_ascii=False)
Beispiel #8
0
def train_pcbrpp(cfg, logprint=print):
    """
    Train a PCB/RPP re-identification network as configured by ``cfg``.

    Builds the train/query/gallery data pipelines, the ``PCBRPPNet`` model, one
    ``Trainer`` per enabled parameter group (base, tail, RPP), a multi-step
    learning-rate scheduler, metrics and a best-model saver, then runs
    ``EpochProcessor.train`` for ``cfg.max_epochs`` epochs. On epochs that are a
    multiple of ``cfg.val_epochs`` the query+gallery sets are evaluated (CMC/mAP);
    if the epoch is also a multiple of ``cfg.snap_epochs`` the model is handed to
    the saver with CMC rank-1 as its score.

    :param cfg: configuration object. ``cfg.galleryIMpath`` is optional and falls
        back to ``cfg.queryIMpath``; ``cfg.ctx`` is overwritten.
    :param logprint: callable used for all progress output (default ``print``).
    :raises ValueError: if the configuration enables no parameter group to train.
    """
    cfg.ctx = mx.Context(cfg.device_type, cfg.device_id)

    # ==========================================================================
    # define train dataset, query dataset and test dataset
    # ==========================================================================
    traintransformer = ListTransformer(datasetroot=cfg.trainIMpath,
                                       resize_size=cfg.resize_size,
                                       crop_size=cfg.crop_size,
                                       istrain=True)
    querytransformer = Market1501_Transformer(datasetroot=cfg.queryIMpath,
                                              resize_size=cfg.resize_size,
                                              crop_size=cfg.crop_size,
                                              istrain=False)
    # The gallery transformer was rooted at cfg.queryIMpath, which looks like a
    # copy-paste slip. Use a dedicated gallery root when the config provides
    # one; otherwise keep the previous behaviour.
    gallery_root = getattr(cfg, 'galleryIMpath', None)
    if gallery_root is None:
        gallery_root = cfg.queryIMpath
    gallerytransformer = Market1501_Transformer(datasetroot=gallery_root,
                                                resize_size=cfg.resize_size,
                                                crop_size=cfg.crop_size,
                                                istrain=False)
    traindataset = TextDataset(txtfilepath=cfg.trainList,
                               transform=traintransformer)
    querydataset = TextDataset(txtfilepath=cfg.queryList,
                               transform=querytransformer)
    gallerydataset = TextDataset(txtfilepath=cfg.galleryList,
                                 transform=gallerytransformer)
    train_iterator = DataLoader(traindataset,
                                num_workers=1,
                                shuffle=True,
                                last_batch='discard',
                                batch_size=cfg.batchsize)
    query_iterator = DataLoader(querydataset,
                                num_workers=1,
                                shuffle=True,
                                last_batch='keep',
                                batch_size=cfg.batchsize)
    gallery_iterator = DataLoader(gallerydataset,
                                  num_workers=1,
                                  shuffle=True,
                                  last_batch='keep',
                                  batch_size=cfg.batchsize)

    def test_iterator():
        """Yield all query batches, then all gallery batches, each tagged with its split name."""
        for tag, split_iterator in (('query', query_iterator),
                                    ('gallery', gallery_iterator)):
            for data in tqdm(split_iterator, ncols=80):
                if isinstance(data, (tuple, list)):
                    # Build a new list: tuples have no .append().
                    data = list(data) + [tag]
                else:
                    data = (data, tag)
                yield data

    # ==========================================================================

    # ==========================================================================
    # define model and trainer list, lr_scheduler
    # ==========================================================================
    Net = PCBRPPNet(basenetwork=cfg.basenet,
                    pretrained=cfg.base_pretrained,
                    feature_channels=cfg.feature_channels,
                    classes=cfg.classes_num,
                    laststride=cfg.laststride,
                    withpcb=cfg.withpcb,
                    partnum=cfg.partnum,
                    feature_weight_share=cfg.feature_weight_share,
                    withrpp=cfg.withrpp)
    if cfg.pretrain_path is not None:
        Net.load_params(cfg.pretrain_path,
                        ctx=mx.cpu(),
                        allow_missing=True,
                        ignore_extra=True)
    Net.collect_params().reset_ctx(cfg.ctx)

    trainers = []
    if cfg.base_train:
        base_params = Net.conv.collect_params()
        base_optimizer_params = {
            'learning_rate': cfg.base_learning_rate,
            'wd': cfg.weight_decay,
            'momentum': cfg.momentum,
            'multi_precision': True
        }
        basetrainer = Trainer(base_params,
                              optimizer=cfg.optim,
                              optimizer_params=base_optimizer_params)
        trainers.append(basetrainer)
    if cfg.tail_train:
        tail_params = ParameterDict()
        if (not cfg.withpcb) or cfg.feature_weight_share:
            tail_params.update(Net.feature.collect_params())
            # tail_params.update(Net.feature_.collect_params())
            tail_params.update(Net.classifier.collect_params())
        else:
            for pn in range(cfg.partnum):
                tail_params.update(
                    getattr(Net, 'feature%d' % (pn + 1)).collect_params())
                # tail_params.update(
                #     getattr(Net, 'feature%d_' % (pn+1)).collect_params())
                tail_params.update(
                    getattr(Net, 'classifier%d' % (pn + 1)).collect_params())
        tail_optimizer_params = {
            'learning_rate': cfg.tail_learning_rate,
            'wd': cfg.weight_decay,
            'momentum': cfg.momentum,
            'multi_precision': True
        }
        tailtrainer = Trainer(tail_params,
                              optimizer=cfg.optim,
                              optimizer_params=tail_optimizer_params)
        trainers.append(tailtrainer)
    if cfg.withrpp and cfg.rpp_train:
        rpp_params = Net.rppscore.collect_params()
        rpp_optimizer_params = {
            'learning_rate': cfg.rpp_learning_rate,
            'wd': cfg.weight_decay,
            'momentum': cfg.momentum,
            'multi_precision': True
        }
        rpptrainer = Trainer(rpp_params,
                             optimizer=cfg.optim,
                             optimizer_params=rpp_optimizer_params)
        trainers.append(rpptrainer)
    if len(trainers) == 0:
        # Raising a plain string is a TypeError in Python 3; raise a real exception.
        raise ValueError("There is no params for training.")
    lr_scheduler = MultiStepListScheduler(trainers,
                                          milestones=cfg.milestones,
                                          gamma=cfg.gamma)
    # ==========================================================================

    # ==========================================================================
    # metric, loss, saver define
    # ==========================================================================
    softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
    loss_metric = Loss()
    if cfg.partnum is not None:
        train_accuracy_metrics = [Accuracy() for _ in range(cfg.partnum)]
    else:
        train_accuracy_metric = Accuracy()
    reid_metric = ReID_Metric(isnorm=True)

    save_name = ""
    if not cfg.withpcb:
        save_name = "IDE"
    elif not cfg.withrpp:
        save_name = "NORPP_%dPart" % (cfg.partnum)
    else:
        if not cfg.tail_train and not cfg.base_train:
            save_name = "WITHRPP_%dPart" % (cfg.partnum)
    if cfg.withpcb and cfg.feature_weight_share:
        save_name += "_FEASHARE"
    net_saver = Best_Evaluation_Saver(save_dir=cfg.snapdir,
                                      save_name=save_name,
                                      reverse=False)
    # ==========================================================================
    logprint(Net)

    # ==========================================================================
    # process functions
    # ==========================================================================

    def reset_metrics():
        loss_metric.reset()
        if cfg.partnum is not None:
            for metric in train_accuracy_metrics:
                metric.reset()
        else:
            train_accuracy_metric.reset()
        reid_metric.reset()

    def on_start(state):
        # Keep the raw training iterator so each epoch can wrap it in a fresh tqdm.
        if state['train']:
            state['store_iterator'] = state['iterator']

    def on_start_epoch(state):
        lr_scheduler.step()
        reset_metrics()
        if state['train']:
            state['iterator'] = tqdm(state['store_iterator'], ncols=80)

    def on_sample(state):
        pass

    def test_process(sample):
        # Features of the image and its horizontal flip are summed.
        img, cam, label, ds = sample
        img = img.as_in_context(cfg.ctx)
        ID1, Fea1 = Net(img)
        if cfg.partnum is not None:
            Fea1 = ndarray.concat(*Fea1, dim=-1)
        img = img.flip(axis=3)
        ID2, Fea2 = Net(img)
        if cfg.partnum is not None:
            Fea2 = ndarray.concat(*Fea2, dim=-1)
        return None, Fea1 + Fea2

    def train_process(sample):
        data, label = sample
        data = data.as_in_context(cfg.ctx)
        label = label.as_in_context(cfg.ctx)
        with autograd.record():
            # Each stage runs in train mode only if its group is being trained.
            with autograd.record(train_mode=cfg.base_train):
                x = Net.base_forward(data)
            if cfg.withpcb:
                with autograd.record(train_mode=cfg.rpp_train):
                    x = Net.split_forward(x)
            with autograd.record(train_mode=cfg.tail_train):
                ID, Fea = Net.tail_forward(x)
            # ID, Fea = Net(data)
            if isinstance(ID, list):
                losses = [softmax_cross_entropy(id_, label) for id_ in ID]
                loss = ndarray.stack(*losses, axis=0).mean(axis=0)
            else:
                loss = softmax_cross_entropy(ID, label)
        loss.backward()
        for trainer in trainers:
            trainer.step(data.shape[0])
        return loss, ID

    def on_forward(state):
        if state['train']:
            img, label = state['sample']
            loss_metric.update(None, state['loss'])
            if cfg.partnum is not None:
                for metric, id_ in zip(train_accuracy_metrics,
                                       state['output']):
                    metric.update(preds=id_, labels=label)
            else:
                train_accuracy_metric.update(preds=state['output'],
                                             labels=label)
        else:
            img, cam, label, ds = state['sample']
            if cfg.feature_norm:
                fnorm = ndarray.power(state['output'], 2)
                fnorm = ndarray.sqrt(ndarray.sum(fnorm, axis=-1,
                                                 keepdims=True))
                state['output'] = state['output'] / fnorm
            reid_metric.update(state['output'], cam, label, ds)

    def on_end_iter(state):
        pass

    def on_end_epoch(state):
        if state['train']:
            logprint("[Epoch %d] train loss: %.6f" %
                     (state['epoch'], loss_metric.get()[1]))
            if cfg.partnum is not None:
                for idx, metric in enumerate(train_accuracy_metrics):
                    logprint("[Epoch %d] part No.%d train accuracy: %.2f%%" %
                             (state['epoch'], idx + 1, metric.get()[1] * 100))
            else:
                logprint(
                    "[Epoch %d] train accuracy: %.2f%%" %
                    (state['epoch'], train_accuracy_metric.get()[1] * 100))
            if state['epoch'] % cfg.val_epochs == 0:
                reset_metrics()
                processor.test(test_process, test_iterator())
                CMC, mAP = reid_metric.get()[1]
                logprint(
                    "[Epoch %d] CMC1: %.2f%% CMC5: %.2f%% CMC10: %.2f%% CMC20: %.2f%% mAP: %.2f%%"
                    % (state['epoch'], CMC[0] * 100, CMC[4] * 100,
                       CMC[9] * 100, CMC[19] * 100, mAP * 100))
                if state['epoch'] % cfg.snap_epochs == 0:
                    net_saver.save(Net, CMC[0])

    def on_end(state):
        pass

    processor = EpochProcessor()
    processor.hooks['on_start'] = on_start
    processor.hooks['on_start_epoch'] = on_start_epoch
    processor.hooks['on_sample'] = on_sample
    processor.hooks['on_forward'] = on_forward
    processor.hooks['on_end_iter'] = on_end_iter
    processor.hooks['on_end_epoch'] = on_end_epoch
    processor.hooks['on_end'] = on_end

    processor.train(train_process, train_iterator, cfg.max_epochs)
class InferenceParameters(object):
    """
    The parameters and outcomes of an inference method.

    InferenceParameters is a pool of memory that contains a mapping from uuid to two types of memories
    (MXNet ParameterDict and Constants).

    :param constants: Specify a list of model variables as constants. Keys may be
        ModelComponents (re-keyed to their uuid) or uuids.
    :type constants: {ModelComponent.uuid : mxnet.ndarray}
    :param dtype: data type for internal numerical representation
    :type dtype: {numpy.float64, numpy.float32, 'float64', 'float32'}
    :param context: The MXNet context
    :type context: {mxnet.cpu or mxnet.gpu}
    """
    def __init__(self, constants=None, dtype=None, context=None):
        self.dtype = dtype if dtype is not None else get_default_dtype()
        self.mxnet_context = context if context is not None else get_default_device()
        self._constants = {}
        self._var_ties = {}
        if constants is not None:
            self._constants.update(self._to_uuid_keys(constants))
        self._params = ParameterDict()

    @staticmethod
    def _to_uuid_keys(mapping):
        """
        Re-key a mapping so that ModelComponent keys are replaced by their uuid.

        :param mapping: dict whose keys are ModelComponents or already uuids.
        :returns: a new dict holding the same values keyed by uuid.
        """
        return {(k.uuid if isinstance(k, ModelComponent) else k): v
                for k, v in mapping.items()}

    def update_constants(self, constants):
        """
        Update the constants.

        :param constants: The constants to be updated. ModelComponent keys are
            stored under their uuid.
        :type constants: {Variable: float or MXNet NDArray}
        """
        self.constants.update(self._to_uuid_keys(constants))

    def initialize_params(self, graphs, observed_uuid):
        """
        Create and initialize an MXNet Parameter for every free parameter of
        the given graphs, recording each graph's constants along the way.

        Any previously held parameters are discarded.

        :param graphs: a list of graphs in which the parameters will be optimized.
        :type graphs: a list of FactorGraph
        :param observed_uuid: Parameter Variables that are passed in directly as data, not to be inferred.
        :type observed_uuid: list, set
        """
        # __init__ always creates an (empty) ParameterDict, so a plain
        # ``is not None`` test would warn on every first initialization.
        # Only warn when existing parameters are really being thrown away.
        if self._params is not None and len(self._params.keys()) > 0:
            warnings.warn("InferenceParameters has already been initialized.  The existing one will be overwritten.")

        self._params = ParameterDict()
        for g in graphs:
            # Constants are recorded first so that they are excluded from the
            # parameter set and usable for shape realization below.
            for var in g.get_constants():
                self._constants[var.uuid] = var.constant

            excluded = set(self._constants.keys()).union(observed_uuid)
            for var in g.get_parameters(excluded=excluded):
                var_shape = realize_shape(var.shape, self._constants)
                init = initializer.Constant(var.initial_value_before_transformation) \
                    if var.initial_value is not None else None

                self._params.get(name=var.uuid, shape=var_shape,
                                 dtype=self.dtype,
                                 allow_deferred_init=True, init=init)
            for m in g.modules.values():
                m.initialize_hidden_parameters(self._params, excluded, self._constants)

        self._params.initialize(ctx=self.mxnet_context)

    def initialize_with_carryover_params(self, graphs, observed_uuid, var_ties,
                                         carryover_params):
        """
        Initialize the parameters of ``graphs``, reusing the Parameters of
        previous inference runs for any variable (or module hidden parameter)
        they already contain.

        :param graphs: a list of graphs in which the parameters will be optimized.
        :type graphs: a list of FactorGraph
        :param observed_uuid: Parameter Variables that are passed in directly as data, not to be inferred.
        :type observed_uuid: {UUID : mx.ndarray}
        :param var_ties: A dictionary of variable maps that are tied together and use the MXNet Parameter of the dict
        value's uuid. Currently ignored.
        :type var_ties: { UUID to tie from : UUID to tie to }
        :param carryover_params: list of InferenceParameters containing the outcomes of previous inference algorithms.
        :type carryover_params: [InferenceParameters]
        """
        # TODO: var_ties is discarded at the moment.

        var_uuid = set()
        for g in graphs:
            var_uuid = var_uuid.union(set(g.variables.keys()))
            for m in g.modules.values():
                var_uuid = var_uuid.union(set(m.hidden_parameters))

        # Collect Parameters to reuse; a later carryover set wins on clashes.
        carryover_pairs = {}
        for carryover in carryover_params:
            for uuid, v in carryover.param_dict.items():
                if uuid in var_uuid:
                    if uuid in carryover_pairs:
                        warnings.warn('The variable with UUID ' + str(uuid) +
                                      ' exists in multiple carryover parameter sets.')
                    carryover_pairs[uuid] = v

        # self._var_ties = var_ties.copy()
        # for g in graphs:
        #     # TODO: check the behavior of var_ties in graph
        #     self._var_ties.update(g.var_ties)
        # for v_uuid in self.constants:
        #     if v_uuid in self._var_ties:
        #         del self._var_ties[v_uuid]

        # Treat carried-over variables as observed so no fresh Parameter is
        # created for them; the carried-over Parameter is inserted instead.
        observed_uuid = set(observed_uuid).union(carryover_pairs.keys())
        self.initialize_params(graphs, observed_uuid)

        # carryover_pairs = {
        #     to_var_uuid: carryover.param_dict[to_var_uuid]
        #     for from_var_uuid, to_var_uuid in self._var_ties.items()
        #     for carryover in carryover_params
        #     if to_var_uuid in carryover.param_dict}
        self._params.update(carryover_pairs)

    def fix_all(self):
        """Disable gradient computation for every held Parameter."""
        for p in self.param_dict.values():
            p.grad_req = 'null'

    @property
    def param_dict(self):
        """The underlying MXNet ParameterDict keyed by uuid."""
        return self._params

    @property
    def constants(self):
        """The constants dict keyed by uuid."""
        return self._constants

    @property
    def var_ties(self):
        """The variable-tie mapping (currently never populated)."""
        return self._var_ties

    def __getitem__(self, key, ctx=None):
        """
        Return the value of the parameter Variable ``key``, with the
        variable's transformation applied if it has one.

        :raises KeyError: if ``key`` is not a Variable.
        """
        if not isinstance(key, Variable):
            raise KeyError("The access key of inference parameter needs to be Variable, but got "+str(type(key))+".")
        val = self._params.get(key.uuid).data(ctx)
        if key.transformation is not None:
            val = key.transformation.transform(val)
        return val

    def __setitem__(self, key, item):
        """
        Set the value for Variable ``key``.

        Parameters are stored untransformed (the inverse transformation is
        applied here, mirroring ``__getitem__``); constants are stored in
        ``self._constants``.

        :raises KeyError: if ``key`` is not a Variable.
        """
        if not isinstance(key, Variable):
            raise KeyError("The access key of inference parameter needs to be Variable, but get "+str(type(key))+".")

        if key.type == VariableType.PARAMETER:
            if key.transformation is not None:
                item = key.transformation.inverseTransform(item)
            self._params.get(key.uuid).set_data(item)
        elif key.type == VariableType.CONSTANT:
            # Constants are kept in self._constants (see initialize_params).
            # Going through self._params.get() would create a spurious,
            # uninitialized Parameter under the constant's uuid.
            self._constants[key.uuid] = item

    # Override contains so that it doesn't use the __getitem__ method.
    # NOTE(review): this tests attribute names in the instance __dict__, not
    # parameter/constant membership — confirm this is what callers expect.
    def __contains__(self, k):
        return k in self.__dict__

    @staticmethod
    def load_parameters(uuid_map=None,
                        mxnet_parameters=None,
                        variable_constants=None,
                        mxnet_constants=None,
                        context=None, dtype=None,
                        current_params=None):
        """
        Loads back a set of InferenceParameters from the dicts produced by
        ``get_serializable``.

        :param uuid_map: Optional {old uuid: new uuid} mapping applied to every
            key, so saved values can be attached to a re-created model. If None,
            keys are used unchanged.
        :param mxnet_parameters: The parameters of the previous inference
            algorithm. None is treated as empty.
        :type mxnet_parameters: Dict of {uuid: mx.nd.array}
        :param variable_constants: The constants in primitive format from the
            previous inference algorithm. None is treated as empty.
        :type variable_constants: dict of {uuid: constant primitive}
        :param mxnet_constants: The constants in mxnet format from the previous
            inference algorithm. None is treated as empty.
        :type mxnet_constants: Dict of {uuid: mx.nd.array}
        :param context: The MXNet context for the loaded values.
        :param dtype: data type for internal numerical representation.
        :param current_params: ParameterDict whose Parameters receive the loaded
            values; they are looked up in it by (mapped) uuid.
        :returns: the reconstructed InferenceParameters.
        :rtype: InferenceParameters
        """
        def with_uuid_map(item):
            return uuid_map[item] if uuid_map is not None else item

        def remap(mapping):
            # A missing dict is treated as empty so any of the inputs may be omitted.
            if mapping is None:
                return {}
            return {with_uuid_map(k): v for k, v in mapping.items()}

        ip = InferenceParameters(context=context, dtype=dtype)

        new_paramdict = ParameterDict()
        if current_params is not None:
            new_paramdict.update(current_params)

        # Do this because we need to map the uuids to the new Model
        # before loading them into the ParamDict
        for name, mapped_param in remap(mxnet_parameters).items():
            new_paramdict[name]._load_init(mapped_param, ip.mxnet_context)
        ip._params = new_paramdict

        ip._constants = {}
        ip._constants.update(remap(variable_constants))
        ip._constants.update(remap(mxnet_constants))
        return ip

    def get_serializable(self):
        """
        Returns three dicts:
         1. MXNet parameters {uuid: mxnet parameters, mx.nd.array}.
         2. MXNet constants {uuid: mxnet parameter (only constant types), mx.nd.array}
         3. Other constants {uuid: primitive numeric types (int, float)}
         :returns: Three dictionaries: MXNet parameters, MXNet constants, and other constants (in that order)
         :rtype: {uuid: mx.nd.array}, {uuid: mx.nd.array}, {uuid: primitive (int/float)}
        """

        mxnet_parameters = {key: value._reduce() for key, value in self._params.items()}

        # Constants are split by type: NDArray constants vs everything else.
        mxnet_constants = {uuid: value
                           for uuid, value in self._constants.items()
                           if isinstance(value, mx.ndarray.ndarray.NDArray)}

        variable_constants = {uuid: value
                              for uuid, value in self._constants.items()
                              if uuid not in mxnet_constants}

        return mxnet_parameters, mxnet_constants, variable_constants