def create_mxnet_dict(layer_mapping, var_dict):
  mxnet_dict = {}
  for vn in var_dict:
    sp = vn.split("/")
    if sp[0] not in layer_mapping:
      logger.warning("{} not found in mxnet model".format(vn))
      continue
    layer = layer_mapping[sp[0]]
    if "bn" in sp[1]:
      if isinstance(layer, list):
        layer = layer[0]
      layer = layer.replace("res", "bn")
      if sp[2] == "beta":
        postfix = "_beta"
      elif sp[2] == "gamma":
        postfix = "_gamma"
      elif sp[2] == "mean_ema":
        postfix = "_moving_mean"
      elif sp[2] == "var_ema":
        postfix = "_moving_var"
      else:
        assert False, sp
    else:
      if isinstance(layer, list):
        layer = layer[1]
      postfix = "_weight"
    # EMA statistics are stored as auxiliary parameters in mxnet, everything else as regular arguments
    if "ema" in vn:
      layer = "aux:" + layer
    else:
      layer = "arg:" + layer
    if sp[1] == "W0":
      branch = "_branch1"
    elif sp[1] == "W1":
      branch = "_branch2a"
    elif sp[1] == "W2":
      branch = "_branch2b1"
    elif sp[1] == "W3":
      branch = "_branch2b2"
    elif sp[1] == "W":
      branch = ""
    elif sp[1] == "bn0":
      branch = "_branch2a"
    elif sp[1] == "bn2":
      branch = "_branch2b1"
    elif sp[1] == "bn3":
      branch = "_branch2b2"
    elif sp[1] == "bn":  # for collapse
      branch = ""
    elif sp[1] == "b":
      branch = ""
      postfix = "_bias"
    else:
      assert False, sp
    mxnet_dict[vn] = layer + branch + postfix
  return mxnet_dict
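# A worked example of the mapping above. The variable names are hypothetical but
# follow the naming scheme the function expects; create_mxnet_dict only iterates
# over the keys of var_dict, so None is enough as a value here:
layer_mapping_demo = {"res3": "res3a"}
var_dict_demo = {"res3/W1": None, "res3/bn0/gamma": None, "res3/bn0/mean_ema": None}
print(create_mxnet_dict(layer_mapping_demo, var_dict_demo))
# {'res3/W1': 'arg:res3a_branch2a_weight',
#  'res3/bn0/gamma': 'arg:bn3a_branch2a_gamma',
#  'res3/bn0/mean_ema': 'aux:bn3a_branch2a_moving_mean'}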
def eval(self):
  start = time.time()
  valid_loss, measures = self.run_epoch(self.trainer.validation_step, self.valid_data, 0)
  end = time.time()
  elapsed = end - start
  valid_error_string = Measures.get_error_string(measures, "valid")
  logger.info("eval finished. elapsed: {:.5f} valid_score: {:.5f} {}".format(
      elapsed, valid_loss, valid_error_string))
def try_load_weights(self):
  fn = None
  if self.load != "":
    fn = self.load.replace(".index", "")
  else:
    # pick the latest checkpoint of this model, if any exists
    files = sorted(glob.glob(self.model_dir + self.model + "-*.index"))
    if len(files) > 0:
      fn = files[-1].replace(".index", "")
  if fn is not None:
    logger.info("loading model from {}".format(fn))
    self.saver.restore(self.session, fn)
    # resume epoch counting only if the checkpoint belongs to the current model
    if self.model == fn.split("/")[-2]:
      self.start_epoch = int(fn.split("-")[-1])
      logger.info("starting from epoch {}".format(self.start_epoch + 1))
  elif self.load_init != "":
    if self.load_init.endswith(".pickle"):
      logger.info("trying to initialize model from wider-or-deeper mxnet model {}".format(self.load_init))
      load_wider_or_deeper_mxnet_model(self.load_init, self.session)
    else:
      fn = self.load_init
      logger.info("initializing model from {}".format(fn))
      assert self.load_init_saver is not None
      self.load_init_saver.restore(self.session, fn)
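# Illustration of the checkpoint naming convention try_load_weights relies on
# (paths are hypothetical): checkpoints live at "<model_dir><model>-<epoch>.index",
# the sorted glob picks the newest one, and the epoch to resume from is the part
# after the last "-", trusted only when the parent directory matches the model name:
fn = "models/my_model/my_model-00000012"  # restored checkpoint prefix
assert fn.split("/")[-2] == "my_model"
start_epoch = int(fn.split("-")[-1])      # -> 12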
def _create_load_init_saver(self):
  if self.load_init != "" and not self.load_init.endswith(".pickle"):
    vars_file = [x[0] for x in list_variables(self.load_init)]
    vars_model = tf.global_variables()
    assert all([x.name.endswith(":0") for x in vars_model])
    # restore only the variables that are actually present in the checkpoint
    vars_intersection = [x for x in vars_model if x.name[:-2] in vars_file]
    vars_missing = [x for x in vars_model if x.name[:-2] not in vars_file]
    if len(vars_missing) > 0:
      logger.info("the following variables will not be initialized since they are not present "
                  "in the initialization model: {}".format([v.name for v in vars_missing]))
    return tf.train.Saver(var_list=vars_intersection)
  else:
    return None
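# A minimal standalone sketch of the same partial-restore idea, assuming TF 1.x,
# where list_variables is available as tf.train.list_variables and returns
# (name, shape) pairs; the checkpoint path is hypothetical:
import tensorflow as tf

ckpt = "models/pretrained/pretrained-00000080"
names_in_ckpt = {name for name, _ in tf.train.list_variables(ckpt)}
restorable = [v for v in tf.global_variables() if v.name[:-2] in names_in_ckpt]
saver = tf.train.Saver(var_list=restorable)  # restores only the overlap, skips the rest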
def train(self):
  assert self.need_train
  logger.info("starting training")
  for epoch in range(self.start_epoch, self.num_epochs):
    start = time.time()
    train_loss, train_measures = self.run_epoch(self.trainer.train_step, self.train_data, epoch)
    valid_loss, valid_measures = self.run_epoch(self.trainer.validation_step, self.valid_data, epoch)
    end = time.time()
    elapsed = end - start
    train_error_string = Measures.get_error_string(train_measures, "train")
    valid_error_string = Measures.get_error_string(valid_measures, "valid")
    logger.info("epoch {} finished. elapsed: {:.5f} train_score: {:.5f} {} valid_score: {:.5f} {}".format(
        epoch + 1, elapsed, train_loss, train_error_string, valid_loss, valid_error_string))
    if self.save:
      self.save_model(epoch + 1)
def __init__(self, config):
  self.config = config
  self.dataset = config.unicode("dataset").lower()
  self.load_init = config.unicode("load_init", "")
  self.load = config.unicode("load", "")
  self.task = config.unicode("task", "train")
  self.use_partialflow = config.bool("use_partialflow", False)
  self.do_oneshot_or_online_or_offline = self.task in ("oneshot_forward", "oneshot", "online", "offline")
  if self.do_oneshot_or_online_or_offline:
    assert config.int("batch_size_eval", 1) == 1
  self.need_train = self.task == "train" or self.do_oneshot_or_online_or_offline or self.task == "forward_train"
  self.session = tf.InteractiveSession(config=tf.ConfigProto(allow_soft_placement=True))
  self.coordinator = tf.train.Coordinator()
  self.valid_data = load_dataset(config, "valid", self.session, self.coordinator)
  if self.need_train:
    self.train_data = load_dataset(config, "train", self.session, self.coordinator)
  self.num_epochs = config.int("num_epochs", 1000)
  self.model = config.unicode("model")
  self.model_base_dir = config.dir("model_dir", "models")
  self.model_dir = self.model_base_dir + self.model + "/"
  self.save = config.bool("save", True)
  self.global_step = tf.Variable(0, name="global_step", trainable=False)
  self.start_epoch = 0
  reuse_variables = None
  if self.need_train:
    freeze_batchnorm = config.bool("freeze_batchnorm", False)
    self.train_network = Network(config, self.train_data, self.global_step, training=True,
                                 use_partialflow=self.use_partialflow,
                                 do_oneshot=self.do_oneshot_or_online_or_offline,
                                 freeze_batchnorm=freeze_batchnorm, name="trainnet")
    # share variables between the train and test networks
    reuse_variables = True
  else:
    self.train_network = None
  with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
    self.test_network = Network(config, self.valid_data, self.global_step, training=False,
                                do_oneshot=self.do_oneshot_or_online_or_offline,
                                use_partialflow=False, freeze_batchnorm=True, name="testnet")
  logger.info("number of parameters: {:,}".format(self.test_network.n_params))
  self.trainer = Trainer(config, self.train_network, self.test_network, self.global_step, self.session)
  self.saver = tf.train.Saver(max_to_keep=0, pad_step_number=True)
  tf.global_variables_initializer().run()
  tf.local_variables_initializer().run()
  tf.train.start_queue_runners(self.session)
  self.load_init_saver = self._create_load_init_saver()
  if not self.do_oneshot_or_online_or_offline:
    self.try_load_weights()
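# Hedged usage sketch of how this constructor is typically driven. The class name
# Engine and the Config loader are assumptions about the surrounding code, and the
# config path is hypothetical:
config = Config("configs/my_experiment")
engine = Engine(config)
if engine.task == "train":
  engine.train()
else:
  engine.eval()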
from logging_utils import logger

logger.info("This is an info message")
logger.debug("This is a debug message")
logger.warning("This is a warning message")
logger.error("This is an error message")
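# logging_utils is project-specific; a minimal sketch of what it could contain
# (an assumption, not the actual module), built on the standard library:
import logging

logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s",
                    level=logging.DEBUG)
logger = logging.getLogger("main")  # hypothetical logger name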
def load_wider_or_deeper_mxnet_model(model_path, session):
  with open(model_path, "rb") as f:
    params = pickle.load(f)
  variables = tf.global_variables()
  model_name = model_path.split("/")[-1]
  if model_name.startswith("ilsvrc"):
    layer_mapping = {"res0": "res2a", "res1": "res2b1", "res2": "res2b2", "res3": "res3a",
                     "res4": "res3b1", "res5": "res3b2", "res6": "res4a", "res7": "res4b1",
                     "res8": "res4b2", "res9": "res4b3", "res10": "res4b4", "res11": "res4b5",
                     "res12": "res5a", "res13": "res5b1", "res14": "res5b2", "res15": "res6a",
                     "res16": "res7a", "output": "linear1000", "conv0": "conv1a", "collapse": "bn7"}
  elif model_name.startswith("ade"):
    layer_mapping = {"res0": "res2a", "res1": "res2b1", "res2": "res2b2", "res3": "res3a",
                     "res4": "res3b1", "res5": "res3b2", "res6": "res4a", "res7": "res4b1",
                     "res8": "res4b2", "res9": "res4b3", "res10": "res4b4", "res11": "res4b5",
                     "res12": "res5a", "res13": "res5b1", "res14": "res5b2", "res15": "res6a",
                     "res16": "res7a", "output": "linear150", "conv0": "conv1a",
                     "conv1": ["bn7", "conv6a"]}
  elif model_name.startswith("voc"):
    layer_mapping = {"res0": "res2a", "res1": "res2b1", "res2": "res2b2", "res3": "res3a",
                     "res4": "res3b1", "res5": "res3b2", "res6": "res4a", "res7": "res4b1",
                     "res8": "res4b2", "res9": "res4b3", "res10": "res4b4", "res11": "res4b5",
                     "res12": "res5a", "res13": "res5b1", "res14": "res5b2", "res15": "res6a",
                     "res16": "res7a", "output": "linear21", "conv0": "conv1a",
                     "conv1": ["bn7", "conv6a"]}
  else:
    assert False, model_name
  # from variable name (without the trailing :0) to variable,
  # excluding optimizer slots and the global step
  var_dict = {v.name[:-2]: v for v in variables
              if "Adam" not in v.name and "_power" not in v.name and "global_step" not in v.name}
  # from our variable name to the mxnet variable name
  mxnet_dict = create_mxnet_dict(layer_mapping, var_dict)
  for k, v in mxnet_dict.items():
    assert v in params, (k, v)
  # feed values through a placeholder instead of embedding them in the graph as constants,
  # to avoid memory issues
  placeholder = tf.placeholder(tf.float32)
  for k, v in mxnet_dict.items():
    val = params[v]
    if val.ndim == 1:
      pass
    elif val.ndim == 2:
      val = numpy.swapaxes(val, 0, 1)
    elif val.ndim == 4:
      # mxnet stores conv weights as (out, in, h, w), tensorflow as (h, w, in, out)
      val = numpy.moveaxis(val, [0, 1, 2, 3], [3, 2, 0, 1])
    else:
      assert False, val.ndim
    var = var_dict[k]
    if var.get_shape() == val.shape:
      op = tf.assign(var, placeholder)
      session.run([op], feed_dict={placeholder: val})
    elif k.startswith("conv0"):
      logger.warning("sizes for {} do not match, initializing the matching part assuming "
                     "the first 3 input channels are RGB".format(k))
      val_new = session.run(var)
      val_new[..., :3, :] = val
      op = tf.assign(var, placeholder)
      session.run([op], feed_dict={placeholder: val_new})
    else:
      logger.info("skipping {} since the shapes do not match: {} and {}".format(
          k, var.get_shape(), val.shape))
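# Quick check of the 4-D weight transposition used above: MXNet stores conv
# weights as (out_channels, in_channels, height, width), TensorFlow expects
# (height, width, in_channels, out_channels). The shape is a hypothetical
# 7x7 RGB stem convolution:
import numpy

w_mx = numpy.zeros((64, 3, 7, 7))
w_tf = numpy.moveaxis(w_mx, [0, 1, 2, 3], [3, 2, 0, 1])
assert w_tf.shape == (7, 7, 3, 64)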