def build_network(*args, input_dim=3, p0_z=None, z_dim=128, beta=None,
                  skip_connection=True, variational=False, use_kl=False,
                  geo_initial=True):
    """Instantiate a ``Network`` and optionally apply geometric initialization.

    All positional arguments and every keyword argument except
    ``geo_initial`` are forwarded unchanged to ``Network``.

    Args:
        *args: Positional arguments passed straight through to ``Network``.
        input_dim, p0_z, z_dim, beta, skip_connection, variational, use_kl:
            Keyword arguments passed straight through to ``Network``.
        geo_initial: When true, print a notice and re-initialize, in place,
            every parameter whose name does not contain ``'encoder'``:

            * name contains ``'weight'``       -> N(0, sqrt(2) / sqrt(shape[0]))
            * name contains ``'bias'``         -> 0
            * name contains ``'l_out.weight'`` -> constant sqrt(pi) / sqrt(shape[1])
            * name contains ``'l_out.bias'``   -> constant -0.5

            NOTE(review): this looks like SAL/IGR-style geometric
            initialization of an implicit-surface decoder; confirm against
            the ``Network`` definition.

    Returns:
        The constructed ``Network`` instance.
    """
    network_kwargs = dict(
        input_dim=input_dim,
        p0_z=p0_z,
        z_dim=z_dim,
        beta=beta,
        skip_connection=skip_connection,
        variational=variational,
        use_kl=use_kl,
    )
    net = Network(*args, **network_kwargs)

    if not geo_initial:
        return net

    print("Perform geometric initialization!\n")
    for name, param in net.named_parameters():
        # Encoder parameters keep whatever initialization Network gave them.
        if 'encoder' in name:
            continue

        if 'weight' in name:
            # Assumes an nn.Linear-style (out_features, in_features) layout,
            # so shape[0] is the fan-out -- TODO confirm for all layers.
            nn.init.normal_(param, 0.0, np.sqrt(2) / np.sqrt(param.shape[0]))
        if 'bias' in name:
            nn.init.constant_(param, 0)

        # The output layer was already matched by the generic checks above.
        # These later checks overwrite those values. The normal_ draw above
        # still runs first, so the RNG stream is consumed identically.
        if 'l_out.weight' in name:
            nn.init.constant_(param, np.sqrt(np.pi) / np.sqrt(param.shape[1]))
        if 'l_out.bias' in name:
            nn.init.constant_(param, -0.5)

    return net