def bn_post_process(model_file: str, save_model_file: str, data):
    """Recompute BatchNormalization statistics from `data` and save the model.

    Loads the network from `model_file`, builds mean / E[x^2] nodes for the
    input of every BatchNormalization opr, evaluates them over `data` via
    `_get_dataset_mean_var`, freezes the BN oprs with the resulting state,
    and writes a checkpoint to `save_model_file`.

    :param model_file: path of the serialized network to load.
    :param save_model_file: checkpoint path to write the updated network to.
    :param data: dataset consumed by `_get_dataset_mean_var` — presumably an
        iterable of minibatches; confirm against that helper's contract.
    """
    with TrainingEnv(name=model_file + "bn_post_proc", part_count=2) as env:
        # `with open(...)` fixes the original's leaked file handle.
        with open(model_file, "rb") as f:
            net = load_network(f)
        bn_oprs = [
            opr for opr in net.loss_visitor.all_oprs
            if isinstance(opr, BatchNormalization)
        ]
        bn_inputs = [opr.inputs[0] for opr in bn_oprs]
        mean_Esqr_nodes = []
        for inp in bn_inputs:
            if inp.partial_shape.ndim == 2:
                # Fully-connected input: statistics over the batch axis.
                mean = inp.mean(axis=0).reshape((1, -1))
                Esqr = (inp ** 2).mean(axis=0).reshape((1, -1))
            elif inp.partial_shape.ndim == 4:
                # Convolutional input (NCHW): reduce W, H, then N,
                # keeping per-channel statistics.
                mean = inp.mean(axis=3).mean(axis=2).mean(axis=0) \
                    .reshape((1, -1))
                Esqr = (inp ** 2).mean(axis=3).mean(axis=2).mean(axis=0) \
                    .reshape((1, -1))
            else:
                # Original code silently reused the previous layer's nodes
                # here; fail loudly instead.
                raise ValueError(
                    "unsupported BN input ndim: {}".format(
                        inp.partial_shape.ndim))
            # Sum-reduce across data-parallel parts so per-part means
            # combine into the global statistics.
            mean.vflags.data_parallel_reduce_method = 'sum'
            Esqr.vflags.data_parallel_reduce_method = 'sum'
            mean_Esqr_nodes.extend((mean, Esqr))
        func = Function().compile(mean_Esqr_nodes)
        for idx, opr in enumerate(bn_oprs):
            layer_mean, layer_var = _get_dataset_mean_var(data, func, idx)
            # Scalar statistics become 1-element vectors so
            # `layer_mean.shape[0]` below is always valid.
            if layer_mean.ndim == 0:
                layer_mean = layer_mean.reshape((1,))
            if layer_var.ndim == 0:
                layer_var = layer_var.reshape((1,))
            state = opr.State(channels=layer_mean.shape[0],
                              val=[layer_mean, layer_var, 1])
            state.owner_opr_type = type(opr)
            opr.set_opr_state(state)
            # Freeze so later training does not overwrite the recomputed
            # statistics.
            opr.freezed = True
        env.register_checkpoint_component("network", net)
        env.save_checkpoint(save_model_file)
def get_minibatch(p, size):
    """Collect `size` (img, label) pairs from queue-like `p` into arrays.

    NOTE(review): the `def` header was missing in the mangled source (the
    body began with a bare `return` at top level); restored to match the
    identical sibling definitions elsewhere in this file — confirm.

    :param p: object with a blocking `.get()` yielding (img, label) tuples.
    :param size: number of samples to pull.
    :return: dict with "data" (float32 ndarray) and "label" (ndarray).
    """
    data = []
    labels = []
    for _ in range(size):
        img, label = p.get()
        data.append(img)
        labels.append(label)
    return {
        "data": np.array(data).astype(np.float32),
        "label": np.array(labels),
    }


if __name__ == '__main__':
    with TrainingEnv(name="lyy.{}.test".format(net_name),
                     part_count=2) as env:
        net = make_network(minibatch_size=minibatch_size)
        # Keep the raw loss around before wrapping it with weight decay.
        preloss = net.loss_var
        net.loss_var = WeightDecay(net.loss_var, {
            "*conv*:W": 1e-4,
            "*fc*:W": 1e-4,
            "*bnaff*:k": 1e-4,
            "*offset*": 1e-4,
        })
        train_func = env.make_func_from_loss_var(net.loss_var, "train",
                                                 train_state=True)
        valid_func = env.make_func_from_loss_var(net.loss_var, "val",
                                                 train_state=False)
def get_minibatch(p, size):
    """Collect `size` (img, label) pairs from queue-like `p` into arrays.

    :param p: object with a blocking `.get()` yielding (img, label) tuples.
    :param size: number of samples to pull.
    :return: dict with "data" (float32 ndarray) and "label" (ndarray).
    """
    data = []
    labels = []
    for _ in range(size):
        img, label = p.get()
        data.append(img)
        labels.append(label)
    return {
        "data": np.array(data).astype(np.float32),
        "label": np.array(labels),
    }


if __name__ == '__main__':
    import shutil

    parser = argparse.ArgumentParser()
    # Recreate the TensorBoard data dir from scratch. shutil.rmtree replaces
    # os.system("rm -r tbdata/"): no shell, portable, and no error spew when
    # the directory does not exist yet.
    shutil.rmtree("tbdata/", ignore_errors=True)
    tb = TB("tbdata/")
    with TrainingEnv(name="lyy.{}.test".format(net_name), part_count=2,
                     custom_parser=parser) as env:
        args = parser.parse_args()
        # Scale minibatch and learning rate linearly with the GPU count.
        num_GPU = len(args.devices.split(','))
        minibatch_size *= num_GPU
        net = make_network(minibatch_size=minibatch_size)
        preloss = net.loss_var
        net.loss_var = WeightDecay(net.loss_var, {
            "*conv*": 1e-4,
            "*fc*": 1e-4,
            "*bnaff*:k": 1e-4,
            "*offset*": 1e-4,
        })
        train_func = env.make_func_from_loss_var(net.loss_var, "train",
                                                 train_state=True)
        lr = 0.1 * num_GPU
        optimizer = Momentum(lr, 0.9)
        optimizer(train_func)
def get_minibatch(p, size):
    """Collect `size` (img, label) pairs from queue-like `p` into arrays.

    :param p: object with a blocking `.get()` yielding (img, label) tuples.
    :param size: number of samples to pull.
    :return: dict with "data" (float32 ndarray) and "label" (ndarray).
    """
    data = []
    labels = []
    for _ in range(size):
        img, label = p.get()
        data.append(img)
        labels.append(label)
    return {
        "data": np.array(data).astype(np.float32),
        "label": np.array(labels),
    }


# Guard added: the sibling scripts in this file all guard their entry point;
# without it this training setup ran (and failed) on mere import.
if __name__ == '__main__':
    with TrainingEnv(name="lyy.resnet20.test", part_count=2) as env:
        net = make_network(minibatch_size=minibatch_size)
        # Keep the raw loss around before wrapping it with weight decay.
        preloss = net.loss_var
        net.loss_var = WeightDecay(net.loss_var, {
            "*conv*:W": 1e-4,
            "*fc*:W": 1e-4,
            "*bnaff*:k": 1e-4,
        })
        train_func = env.make_func_from_loss_var(net.loss_var, "train",
                                                 train_state=True)
        valid_func = env.make_func_from_loss_var(net.loss_var, "val",
                                                 train_state=False)
def get_minibatch(p, size):
    """Drain `size` (img, label) pairs from queue-like `p` into arrays.

    :param p: object with a blocking `.get()` yielding (img, label) tuples.
    :param size: number of samples to pull.
    :return: dict with "data" (float32 ndarray) and "label" (ndarray).
    """
    pairs = [p.get() for _ in range(size)]
    imgs = [img for img, _ in pairs]
    labels = [lbl for _, lbl in pairs]
    return {
        "data": np.array(imgs).astype(np.float32),
        "label": np.array(labels),
    }


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    with TrainingEnv(name="lyy.resnet20.test", part_count=2,
                     custom_parser=parser) as env:
        net = make_network(minibatch_size=minibatch_size)
        # Raw loss kept before the weight-decay wrapper is applied.
        preloss = net.loss_var
        decay_rules = {"*conv*": 1e-4, "*fc*": 1e-4}
        net.loss_var = WeightDecay(net.loss_var, decay_rules)
        train_func = env.make_func_from_loss_var(net.loss_var, "train",
                                                 train_state=True)
        valid_func = env.make_func_from_loss_var(net.loss_var, "val",
                                                 train_state=False)
        # Warmup toggle: lower initial learning rate when enabled.
        warmup = False
        lr = 0.01 if warmup else 0.1
        optimizer = megskull.optimizer.Momentum(lr, 0.9)
        optimizer(train_func)