Example #1
0
def create_mxnet_dict(layer_mapping, var_dict):
    # Translate our TF variable names ("layer/sublayer/param") into the
    # parameter names used inside the pickled mxnet model.
    # batchnorm parameter name -> mxnet postfix
    bn_postfixes = {
        "beta": "_beta",
        "gamma": "_gamma",
        "mean_ema": "_moving_mean",
        "var_ema": "_moving_var",
    }
    # sublayer name -> mxnet branch infix
    branches = {
        "W0": "_branch1",
        "W1": "_branch2a",
        "W2": "_branch2b1",
        "W3": "_branch2b2",
        "W": "",
        "bn0": "_branch2a",
        "bn2": "_branch2b1",
        "bn3": "_branch2b2",
        # for collapse
        "bn": "",
        "b": "",
    }
    mxnet_dict = {}
    for var_name in var_dict:
        parts = var_name.split("/")
        if parts[0] not in layer_mapping:
            logger.info("warning,", var_name, "not found in mxnet model")
            continue
        layer = layer_mapping[parts[0]]
        if "bn" in parts[1]:
            # a list entry holds [bn_layer, conv_layer]; take the bn one
            if isinstance(layer, list):
                layer = layer[0]
            layer = layer.replace("res", "bn")
            assert parts[2] in bn_postfixes, parts
            postfix = bn_postfixes[parts[2]]
        else:
            if isinstance(layer, list):
                layer = layer[1]
            postfix = "_weight"

        # EMA statistics live in mxnet's aux params, everything else in arg
        prefix = "aux:" if "ema" in var_name else "arg:"

        assert parts[1] in branches, parts
        branch = branches[parts[1]]
        if parts[1] == "b":
            postfix = "_bias"

        mxnet_dict[var_name] = prefix + layer + branch + postfix
    return mxnet_dict
Example #2
0
 def eval(self):
     # Run one validation pass over the whole validation set and log the
     # resulting score, error string, and wall-clock time.
     t0 = time.time()
     valid_loss, measures = self.run_epoch(self.trainer.validation_step,
                                           self.valid_data, 0)
     elapsed = time.time() - t0
     err_str = Measures.get_error_string(measures, "valid")
     logger.info(log.v1, "eval finished. elapsed:", elapsed, "valid_score:",
                 valid_loss, err_str)
Example #3
0
    def try_load_weights(self):
        # Restore model weights with the following precedence:
        #   1. an explicit "load" path from the config,
        #   2. the newest checkpoint found in this model's directory,
        #   3. the "load_init" initialization model (mxnet pickle or TF ckpt).
        fn = None
        if self.load != "":
            # strip ".index" so fn is the checkpoint prefix tf.Saver expects
            fn = self.load.replace(".index", "")
        else:
            # checkpoints are named "<model>-<step>.index"; the saver pads
            # step numbers, so lexicographic sort picks the latest
            files = sorted(glob.glob(self.model_dir + self.model + "-*.index"))
            if len(files) > 0:
                fn = files[-1].replace(".index", "")

        if fn is not None:
            logger.info(log.v1, "loading model from", fn)
            self.saver.restore(self.session, fn)
            # only resume the epoch counter when the checkpoint lives in this
            # model's own directory (second-to-last path component)
            if self.model == fn.split("/")[-2]:
                self.start_epoch = int(fn.split("-")[-1])
                logger.info(log.v1, "starting from epoch",
                            self.start_epoch + 1)
        elif self.load_init != "":
            if self.load_init.endswith(".pickle"):
                # pickle => pretrained wider-or-deeper mxnet model
                logger.info(
                    log.v1,
                    "trying to initialize model from wider-or-deeper mxnet model",
                    self.load_init)
                load_wider_or_deeper_mxnet_model(self.load_init, self.session)
            else:
                fn = self.load_init
                logger.info(log.v1, "initializing model from", fn)
                # load_init_saver only covers variables present in that file
                assert self.load_init_saver is not None
                self.load_init_saver.restore(self.session, fn)
Example #4
0
 def _create_load_init_saver(self):
     # Build a tf.train.Saver restricted to the variables that actually
     # exist in the "load_init" checkpoint, so restore() does not fail on
     # variables the initialization model lacks. Returns None when no
     # TF-checkpoint init is configured (pickle init is handled elsewhere).
     if self.load_init == "" or self.load_init.endswith(".pickle"):
         return None
     # set instead of list: membership is tested once per model variable
     vars_file = {x[0] for x in list_variables(self.load_init)}
     vars_model = tf.global_variables()
     assert all(x.name.endswith(":0") for x in vars_model)
     # x.name[:-2] strips the ":0" output suffix
     vars_intersection = [
         x for x in vars_model if x.name[:-2] in vars_file
     ]
     vars_missing = [
         x for x in vars_model if x.name[:-2] not in vars_file
     ]
     if len(vars_missing) > 0:
         logger.info(
             log.v1,
             "the following variables will not be initialized since they are not present in the",
             " initialization model", [v.name for v in vars_missing])
     return tf.train.Saver(var_list=vars_intersection)
Example #5
0
 def train(self):
     # Main training loop: one train epoch plus one validation epoch per
     # iteration, then log timing/scores and optionally save a checkpoint.
     assert self.need_train
     logger.info(log.v1, "starting training")
     for epoch in range(self.start_epoch, self.num_epochs):
         start = time.time()
         train_loss, train_measures = self.run_epoch(
             self.trainer.train_step, self.train_data, epoch)
         valid_loss, valid_measures = self.run_epoch(
             self.trainer.validation_step, self.valid_data, epoch)
         end = time.time()
         elapsed = end - start
         train_error_string = Measures.get_error_string(
             train_measures, "train")
         valid_error_string = Measures.get_error_string(
             valid_measures, "valid")
         # valid_loss was previously logged unformatted, inconsistent with
         # elapsed and train_loss on the same line
         logger.info("epoch", epoch + 1, "finished. elapsed:",
                     "%.5f" % elapsed, "train_score:", "%.5f" % train_loss,
                     train_error_string, "valid_score:", "%.5f" % valid_loss,
                     valid_error_string)
         if self.save:
             # epoch + 1: checkpoints are numbered starting at 1
             self.save_model(epoch + 1)
Example #6
0
    def __init__(self, config):
        # Wire the whole experiment together from the config: datasets,
        # train/test networks (sharing variables), trainer, savers, then
        # restore weights. Ordering matters: variable initializers must run
        # before try_load_weights() so restored values are not overwritten.
        self.config = config
        self.dataset = config.unicode("dataset").lower()
        self.load_init = config.unicode("load_init", "")
        self.load = config.unicode("load", "")
        self.task = config.unicode("task", "train")
        self.use_partialflow = config.bool("use_partialflow", False)
        self.do_oneshot_or_online_or_offline = self.task in ("oneshot_forward",
                                                             "oneshot",
                                                             "online",
                                                             "offline")
        if self.do_oneshot_or_online_or_offline:
            # these tasks require evaluation batch size 1
            assert config.int("batch_size_eval", 1) == 1
        self.need_train = self.task == "train" or self.do_oneshot_or_online_or_offline or self.task == "forward_train"
        self.session = tf.InteractiveSession(config=tf.ConfigProto(
            allow_soft_placement=True))
        self.coordinator = tf.train.Coordinator()
        self.valid_data = load_dataset(config, "valid", self.session,
                                       self.coordinator)
        if self.need_train:
            # train_data is only created (and only exists) when training
            self.train_data = load_dataset(config, "train", self.session,
                                           self.coordinator)

        self.num_epochs = config.int("num_epochs", 1000)
        self.model = config.unicode("model")
        self.model_base_dir = config.dir("model_dir", "models")
        self.model_dir = self.model_base_dir + self.model + "/"
        self.save = config.bool("save", True)

        self.global_step = tf.Variable(0, name='global_step', trainable=False)
        self.start_epoch = 0
        reuse_variables = None
        if self.need_train:
            freeze_batchnorm = config.bool("freeze_batchnorm", False)
            self.train_network = Network(
                config,
                self.train_data,
                self.global_step,
                training=True,
                use_partialflow=self.use_partialflow,
                do_oneshot=self.do_oneshot_or_online_or_offline,
                freeze_batchnorm=freeze_batchnorm,
                name="trainnet")
            # the test network below will reuse the train network's variables
            reuse_variables = True
        else:
            self.train_network = None
        with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
            self.test_network = Network(
                config,
                self.valid_data,
                self.global_step,
                training=False,
                do_oneshot=self.do_oneshot_or_online_or_offline,
                use_partialflow=False,
                freeze_batchnorm=True,
                name="testnet")
        logger.info("number of parameters:",
                    "{:,}".format(self.test_network.n_params))
        self.trainer = Trainer(config, self.train_network, self.test_network,
                               self.global_step, self.session)
        # max_to_keep=0 keeps every checkpoint; pad_step_number=True keeps
        # checkpoint filenames lexicographically sorted by step
        self.saver = tf.train.Saver(max_to_keep=0, pad_step_number=True)
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()
        tf.train.start_queue_runners(self.session)
        # must be created after the networks so the variable list is complete
        self.load_init_saver = self._create_load_init_saver()
        if not self.do_oneshot_or_online_or_offline:
            # NOTE(review): one-shot/online/offline presumably load weights
            # via their own code path — confirm before relying on this
            self.try_load_weights()
Example #7
0
from logging_utils import logger

# Emit one sample message at each severity level.
for _level, _message in (("info", "This is an info"),
                         ("debug", "Enter a debug message"),
                         ("warning", "This is a warning message"),
                         ("error", "This is an error")):
    getattr(logger, _level)(_message)
Example #8
0
def load_wider_or_deeper_mxnet_model(model_path, session):
    # Initialize the current TF graph from a pickled "wider or deeper"
    # mxnet resnet model. The basename of model_path must start with
    # "ilsvrc", "ade" or "voc" to select the layer mapping.
    #
    # "rb" is required for pickle under Python 3 (and harmless under 2); the
    # context manager closes the handle, which the original code leaked.
    with open(model_path, "rb") as f:
        params = pickle.load(f)
    variables = tf.global_variables()

    # The residual backbone is identical across all three released models;
    # only the classifier size and the final conv/bn handling differ.
    base_mapping = {
        "res0": "res2a",
        "res1": "res2b1",
        "res2": "res2b2",
        "res3": "res3a",
        "res4": "res3b1",
        "res5": "res3b2",
        "res6": "res4a",
        "res7": "res4b1",
        "res8": "res4b2",
        "res9": "res4b3",
        "res10": "res4b4",
        "res11": "res4b5",
        "res12": "res5a",
        "res13": "res5b1",
        "res14": "res5b2",
        "res15": "res6a",
        "res16": "res7a",
        "conv0": "conv1a"
    }
    model_name = model_path.split("/")[-1]
    if model_name.startswith("ilsvrc"):
        layer_mapping = dict(base_mapping, output="linear1000",
                             collapse="bn7")
    elif model_name.startswith("ade"):
        layer_mapping = dict(base_mapping, output="linear150",
                             conv1=["bn7", "conv6a"])
    elif model_name.startswith("voc"):
        layer_mapping = dict(base_mapping, output="linear21",
                             conv1=["bn7", "conv6a"])
    else:
        assert False, model_name

    # from str (without :0) to var; optimizer slots ("Adam", "_power") and
    # the step counter are not part of the pretrained model
    var_dict = {
        v.name[:-2]: v
        for v in variables if "Adam" not in v.name and "_power" not in v.name
        and "global_step" not in v.name
    }

    # from our var name to mxnet var name
    mxnet_dict = create_mxnet_dict(layer_mapping, var_dict)

    for k, v in mxnet_dict.items():
        assert v in params, (k, v)

    # use a placeholder to avoid memory issues (one reusable feed instead of
    # embedding every weight as a constant in the graph)
    placeholder = tf.placeholder(tf.float32)
    for k, v in mxnet_dict.items():
        val = params[v]
        if val.ndim == 1:
            pass  # 1-d params (bn/bias): layout already matches
        elif val.ndim == 2:
            # fully connected weights: transpose between the two layouts
            val = numpy.swapaxes(val, 0, 1)
        elif val.ndim == 4:
            # conv weights: reorder axes from mxnet to TF layout
            # (presumably (out, in, h, w) -> (h, w, in, out) — TODO confirm)
            val = numpy.moveaxis(val, [0, 1, 2, 3], [3, 2, 0, 1])
        else:
            assert False, val.ndim
        var = var_dict[k]
        if var.get_shape() == val.shape:
            op = tf.assign(var, placeholder)
            session.run([op], feed_dict={placeholder: val})
        elif k.startswith("conv0"):
            # the input conv may have extra input channels; copy the
            # pretrained weights into the first 3 (assumed RGB) channels
            logger.info("warning, sizes for", k,
                        "do not match, initializing matching part assuming",
                        "the first 3 dimensions are RGB")
            val_new = session.run(var)
            val_new[..., :3, :] = val
            op = tf.assign(var, placeholder)
            session.run([op], feed_dict={placeholder: val_new})
        else:
            logger.info("skipping", k, "since the shapes do not match:",
                        var.get_shape(), "and", val.shape)