Beispiel #1
0
def create_network(input_shape, setup_dir):
    """Build the probabilistic U-Net sampling graph and export it.

    Resets the default TensorFlow graph, attaches a ``UNet`` and a prior
    ``Encoder`` to the ``raw`` placeholder, merges the U-Net feature maps
    with one latent sample of the prior through ``FComb`` and maps the
    result to 3 output channels with a 1x1x1 convolution followed by a
    sigmoid.  The meta graph is written to
    ``setup_dir + 'predict_net.meta'`` and the tensor names and shapes to
    ``setup_dir + 'predict_config.json'``.

    Args:
        input_shape: Shape tuple of the unbatched raw input.  It is
            concatenated to ``(1, 1)``, so it has to be a tuple.
        setup_dir: Prefix prepended verbatim to both output file names; no
            path separator is inserted.
    """
    latent_dims = 6  # shared by the encoder and the exported sample shape

    print("MKNET: PROB-UNET SAMPLE")
    print("")
    tf.reset_default_graph()

    # Unbatched placeholder (marked "for gp" by the original author) and
    # its (batch, channel, ...) view used by the TensorFlow layers.
    raw = tf.placeholder(tf.float32, shape=input_shape, name="raw")
    raw_batched = tf.reshape(raw, (1, 1) + input_shape, name="raw_batched")

    print("raw_batched: ", raw_batched.shape)
    print("")

    unet = UNet(
        fmaps_in=raw_batched,
        num_layers=3,
        base_channels=12,
        channel_inc_factor=3,
        resample_factors=[[2, 2, 2], [2, 2, 2], [2, 2, 2]],
        padding_type="valid",
        num_conv_passes=2,
        down_kernel_size=[3, 3, 3],
        up_kernel_size=[3, 3, 3],
        activation_type=tf.nn.relu,
        downsample_type="max_pool",
        upsample_type="conv_transpose",
        voxel_size=(1, 1, 1))
    unet.build()
    print("")

    prior = Encoder(
        fmaps_in=raw_batched,
        affmaps_in=None,
        num_layers=3,
        latent_dims=latent_dims,
        base_channels=12,
        channel_inc_factor=3,
        downsample_factors=[[2, 2, 2], [2, 2, 2], [2, 2, 2]],
        padding_type="valid",
        num_conv_passes=2,
        down_kernel_size=[3, 3, 3],
        activation_type=tf.nn.relu,
        downsample_type="max_pool",
        voxel_size=(1, 1, 1),
        name="prior")
    prior.build()
    print("")

    unet_fmaps = unet.get_fmaps()
    # Draw the latent sample exactly once, so that the tensor exported as
    # 'sample_z' below is the very one FComb consumes.  The previous code
    # called prior.sample() a second time for the export.
    # NOTE(review): assumes Encoder.sample() adds a fresh random op on every
    # call - confirm against the Encoder implementation.
    sample_z = prior.sample()

    f_comb = FComb(
        fmaps_in=unet_fmaps,
        sample_in=sample_z,
        num_1x1_convs=3,
        num_channels=12,
        padding_type='valid',
        activation_type=tf.nn.relu,
        voxel_size=(1, 1, 1))
    f_comb.build()
    print("")

    pred_logits = tf.layers.conv3d(
        inputs=f_comb.get_fmaps(),
        filters=3,
        kernel_size=1,
        padding='valid',
        data_format="channels_first",
        activation=None,
        name="affs")
    print("")

    pred_affs = tf.sigmoid(pred_logits)

    # Static shape of the batched logits; dropping the two leading entries
    # (batch and channel) leaves the shape stored in the config.
    output_shape_batched = pred_logits.get_shape().as_list()
    output_shape = output_shape_batched[2:]

    pred_logits = tf.squeeze(pred_logits, axis=[0], name="pred_logits")
    pred_affs = tf.squeeze(pred_affs, axis=[0], name="pred_affs")

    sample_z_batched = tf.reshape(
        sample_z, (1, 1, latent_dims), name="sample_z")
    print("sample_z", sample_z_batched.shape)

    print("pred_logits: ", pred_logits.shape)
    print("pred_affs: ", pred_affs.shape)

    print("input shape : %s" % (input_shape,))
    print("output shape: %s" % (output_shape,))

    tf.train.export_meta_graph(filename=setup_dir + 'predict_net.meta')

    config = {
        'raw': raw.name,
        'pred_affs': pred_affs.name,
        'input_shape': input_shape,
        'output_shape': output_shape,
        'sample_z': sample_z_batched.name
    }
    with open(setup_dir + 'predict_config.json', 'w') as f:
        json.dump(config, f)
def create_network(input_shape, setup_dir):
    """Assemble the probabilistic U-Net sampling graph and export it.

    The default graph is reset, then a ``UNet``, a prior ``Encoder``, an
    ``FComb`` and a final 1x1x1 convolution are created, each inside its
    own variable scope.  The meta graph goes to
    ``setup_dir + 'predict_net.meta'`` and the names of the ``raw`` and
    ``pred_affs`` tensors together with the input/output shapes go to
    ``setup_dir + 'predict_config.json'``.  No debug tensors are exported
    by this variant.

    Args:
        input_shape: Shape tuple of the unbatched raw input (it is
            concatenated to ``(1, 1)``).
        setup_dir: Prefix prepended verbatim to the output file names.
    """
    print("MKNET: PROB-UNET SAMPLE RUN")
    print("")
    tf.reset_default_graph()

    # Unbatched placeholder plus its (batch, channel, ...) view.
    raw = tf.placeholder(tf.float32, shape=input_shape)
    batched_shape = (1, 1) + input_shape
    raw_batched = tf.reshape(raw, batched_shape)

    print("raw_batched: ", raw_batched.shape)
    print("")

    with tf.variable_scope("unet"):
        unet = UNet(
            fmaps_in=raw_batched,
            num_layers=3,
            base_channels=12,
            channel_inc_factor=3,
            resample_factors=[[2, 2, 2], [2, 2, 2], [2, 2, 2]],
            padding_type="valid",
            num_conv_passes=2,
            down_kernel_size=[3, 3, 3],
            up_kernel_size=[3, 3, 3],
            activation_type=tf.nn.relu,
            downsample_type="max_pool",
            upsample_type="conv_transpose",
            voxel_size=(1, 1, 1),
        )
        unet.build()
        print("")

    with tf.variable_scope("prior"):
        prior = Encoder(
            fmaps_in=raw_batched,
            affmaps_in=None,
            num_layers=3,
            latent_dims=6,
            base_channels=12,
            channel_inc_factor=3,
            downsample_factors=[[2, 2, 2], [2, 2, 2], [2, 2, 2]],
            padding_type="valid",
            num_conv_passes=2,
            down_kernel_size=[3, 3, 3],
            activation_type=tf.nn.relu,
            downsample_type="max_pool",
            voxel_size=(1, 1, 1),
            name="prior",
        )
        prior.build()
        print("")

    latent_sample = prior.sample()
    # The reshaped sample is not referenced again here, but the op is still
    # added so the exported graph stays the same.
    tf.reshape(latent_sample, (1, 1, 6))

    with tf.variable_scope("f_comb"):
        f_comb = FComb(
            fmaps_in=unet.get_fmaps(),
            sample_in=latent_sample,
            num_1x1_convs=3,
            num_channels=12,
            padding_type='valid',
            activation_type=tf.nn.relu,
            voxel_size=(1, 1, 1),
        )
        f_comb.build()
        print("")

        print("sample_out_name: ", f_comb.get_fmaps().name)

    with tf.variable_scope("affs"):
        logits_batched = tf.layers.conv3d(
            inputs=f_comb.get_fmaps(),
            filters=3,
            kernel_size=1,
            padding='valid',
            data_format="channels_first",
            activation=None,
            name="affs",
        )
        print("")

    affs_batched = tf.sigmoid(logits_batched)

    # Static logits shape without its two leading entries (batch, channel).
    output_shape = logits_batched.get_shape().as_list()[2:]

    # Drop the batch axis from both outputs (logits first, then affinities).
    pred_logits = tf.squeeze(logits_batched, axis=[0])
    pred_affs = tf.squeeze(affs_batched, axis=[0])

    print("pred_logits: ", pred_logits.shape)
    print("pred_affs: ", pred_affs.shape)

    print("input shape : %s" % (input_shape, ))
    print("output shape: %s" % (output_shape, ))

    tf.train.export_meta_graph(filename=setup_dir + 'predict_net.meta')

    config = {
        'raw': raw.name,
        'pred_affs': pred_affs.name,
        'input_shape': input_shape,
        'output_shape': output_shape,
    }
    with open(setup_dir + 'predict_config.json', 'w') as config_file:
        json.dump(config, config_file)
Beispiel #3
0
def create_network(input_shape, name):
    """Assemble the probabilistic U-Net training graph and export it.

    The default graph is reset and a ``UNet``, a prior ``Encoder``, a
    posterior ``Encoder`` (which additionally receives the ground-truth
    affinities) and an ``FComb`` fed with a posterior sample are created,
    followed by a 1x1x1 convolution with 3 output channels.  Placeholders
    for the ground-truth output and the loss weights are added as well.
    The meta graph is written to ``name + '.meta'`` and the tensor names
    and shapes to ``name + '.json'``.

    NOTE: this variant builds neither a loss nor an optimizer; the
    commented-out KL / cross-entropy / MALIS experiments of the original
    have been dropped from the body.

    Args:
        input_shape: Shape tuple of the unbatched raw input (it is
            concatenated to ``(1, 1)``, ``(3,)`` and ``(1, 3)``).
        name: Path prefix of the two output files.
    """
    print("MKNET: PROB-UNET TRAIN")
    print("")
    tf.reset_default_graph()

    # Unbatched placeholders plus their (batch, channel, ...) views.
    raw = tf.placeholder(tf.float32, shape=input_shape, name="raw")
    raw_batched = tf.reshape(raw, (1, 1) + input_shape, name="raw_batched")
    gt_affs_in = tf.placeholder(
        tf.float32, shape=(3,) + input_shape, name="gt_affs_in")
    gt_affs_in_batched = tf.reshape(
        gt_affs_in, (1, 3) + input_shape, name="gt_affs_in_batched")

    print("raw_batched: ", raw_batched.shape)
    print("gt_affs_in_batched: ", gt_affs_in_batched.shape)
    print("")

    unet = UNet(
        fmaps_in=raw_batched,
        num_layers=3,
        base_channels=12,
        channel_inc_factor=3,
        resample_factors=[[2, 2, 2], [2, 2, 2], [2, 2, 2]],
        padding_type="valid",
        num_conv_passes=2,
        down_kernel_size=[3, 3, 3],
        up_kernel_size=[3, 3, 3],
        activation_type=tf.nn.relu,
        downsample_type="max_pool",
        upsample_type="conv_transpose",
        voxel_size=(1, 1, 1),
    )
    unet.build()
    print("")

    prior = Encoder(
        fmaps_in=raw_batched,
        affmaps_in=None,
        num_layers=3,
        latent_dims=6,
        base_channels=12,
        channel_inc_factor=3,
        downsample_factors=[[2, 2, 2], [2, 2, 2], [2, 2, 2]],
        padding_type="valid",
        num_conv_passes=2,
        down_kernel_size=[3, 3, 3],
        activation_type=tf.nn.relu,
        downsample_type="max_pool",
        voxel_size=(1, 1, 1),
        name="prior",
    )
    prior.build()
    print("")

    posterior = Encoder(
        fmaps_in=raw_batched,
        affmaps_in=gt_affs_in_batched,
        num_layers=3,
        latent_dims=6,
        base_channels=12,
        channel_inc_factor=3,
        downsample_factors=[[2, 2, 2], [2, 2, 2], [2, 2, 2]],
        padding_type="valid",
        num_conv_passes=2,
        down_kernel_size=[3, 3, 3],
        activation_type=tf.nn.relu,
        downsample_type="max_pool",
        voxel_size=(1, 1, 1),
        name="posterior",
    )
    posterior.build()
    print("")

    # Same evaluation order as before: U-Net features first, then the sample.
    unet_fmaps = unet.get_fmaps()
    posterior_sample = posterior.sample()
    f_comb = FComb(
        fmaps_in=unet_fmaps,
        sample_in=posterior_sample,
        num_1x1_convs=3,
        num_channels=12,
        padding_type='valid',
        activation_type=tf.nn.relu,
        voxel_size=(1, 1, 1),
    )
    f_comb.build()

    logits_batched = tf.layers.conv3d(
        inputs=f_comb.get_fmaps(),
        filters=3,
        kernel_size=1,
        padding='valid',
        data_format="channels_first",
        activation=None,
        name="affs",
    )
    print("")

    affs_batched = tf.sigmoid(logits_batched)

    # Static logits shape without the batch entry; used for the placeholders.
    unbatched_shape = logits_batched.get_shape().as_list()[1:]

    pred_logits = tf.squeeze(logits_batched, axis=[0], name="pred_logits")
    pred_affs = tf.squeeze(affs_batched, axis=[0], name="pred_affs")

    gt_affs_out = tf.placeholder(
        tf.float32, shape=unbatched_shape, name="gt_affs_out")
    pred_affs_loss_weights = tf.placeholder(
        tf.float32, shape=unbatched_shape, name="pred_affs_loss_weights")

    print("gt_affs_out: ", gt_affs_out.shape)
    print("pred_logits: ", pred_logits.shape)
    print("pred_affs: ", pred_affs.shape)
    print("")

    print("prior: ", prior.get_fmaps())

    # The shape stored in the config additionally drops the channel entry.
    output_shape = unbatched_shape[1:]
    print("input shape : %s" % (input_shape,))
    print("output shape: %s" % (output_shape,))

    tf.train.export_meta_graph(filename=name + '.meta')

    config = {
        'raw': raw.name,
        'pred_affs': pred_affs.name,
        'gt_affs_in': gt_affs_in.name,
        'gt_affs_out': gt_affs_out.name,
        'pred_affs_loss_weights': pred_affs_loss_weights.name,
        'input_shape': input_shape,
        'output_shape': output_shape,
        'prior': prior.get_fmaps().name,
        'posterior': posterior.get_fmaps().name,
        'latent_dims': 6,
    }
    with open(name + '.json', 'w') as config_file:
        json.dump(config, config_file)
def create_network(input_shape, setup_dir):
    """Build the probabilistic U-Net training graph with debug exports.

    Resets the default graph and creates, each in its own variable scope, a
    ``UNet``, a prior ``Encoder``, a posterior ``Encoder`` (which also gets
    the ground-truth affinities), an ``FComb`` fed with a posterior sample
    and a final 1x1x1 convolution with 3 output channels.  A sigmoid
    cross-entropy loss on the logits is minimised with a default
    ``AdamOptimizer``.  The meta graph is written to
    ``setup_dir + 'train_net.meta'`` and the names of all tensors/ops a
    caller needs (including several debug tensors) to
    ``setup_dir + 'train_config.json'``.

    Args:
        input_shape: Shape tuple of the unbatched raw input (it is
            concatenated to ``(1, 1)``, ``(3,)`` and ``(1, 3)``).
        setup_dir: Prefix prepended verbatim to the output file names; no
            path separator is inserted.
    """
    # Single source of truth for the latent size.  The config previously
    # recorded 12 although both encoders and both reshapes below use 6.
    # NOTE(review): confirm no consumer of 'latent_dims' relied on 12.
    latent_dims = 6

    print("MKNET: PROB-UNET TRAIN")
    print("")
    tf.reset_default_graph()

    # Unbatched placeholders plus their (batch, channel, ...) views.
    raw = tf.placeholder(tf.float32, shape=input_shape)
    raw_batched = tf.reshape(raw, (1, 1) + input_shape)
    gt_affs_in = tf.placeholder(tf.float32, shape=(3, ) + input_shape)
    gt_affs_in_batched = tf.reshape(gt_affs_in, (1, 3) + input_shape)

    print("raw_batched: ", raw_batched.shape)
    print("gt_affs_in_batched: ", gt_affs_in_batched.shape)
    print("")

    # Constant tensor that is only exported under the 'debug' config key.
    with tf.variable_scope("debug"):
        debug_batched = tf.constant([[[1, 2, 3, 4, 5]]])
        print('DEBUG:', debug_batched.name)

    with tf.variable_scope("unet"):
        unet = UNet(fmaps_in=raw_batched,
                    num_layers=3,
                    base_channels=12,
                    channel_inc_factor=3,
                    resample_factors=[[2, 2, 2], [2, 2, 2], [2, 2, 2]],
                    padding_type="valid",
                    num_conv_passes=2,
                    down_kernel_size=[3, 3, 3],
                    up_kernel_size=[3, 3, 3],
                    activation_type=tf.nn.relu,
                    downsample_type="max_pool",
                    upsample_type="conv_transpose",
                    voxel_size=(1, 1, 1))
        unet.build()
        print("")

    with tf.variable_scope("prior"):
        prior = Encoder(fmaps_in=raw_batched,
                        affmaps_in=None,
                        num_layers=3,
                        latent_dims=latent_dims,
                        base_channels=12,
                        channel_inc_factor=3,
                        downsample_factors=[[2, 2, 2], [2, 2, 2], [2, 2, 2]],
                        padding_type="valid",
                        num_conv_passes=2,
                        down_kernel_size=[3, 3, 3],
                        activation_type=tf.nn.relu,
                        downsample_type="max_pool",
                        voxel_size=(1, 1, 1),
                        name="prior")
        prior.build()
        print("")

    with tf.variable_scope("posterior"):
        posterior = Encoder(fmaps_in=raw_batched,
                            affmaps_in=gt_affs_in_batched,
                            num_layers=3,
                            latent_dims=latent_dims,
                            base_channels=12,
                            channel_inc_factor=3,
                            downsample_factors=[[2, 2, 2], [2, 2, 2],
                                                [2, 2, 2]],
                            padding_type="valid",
                            num_conv_passes=2,
                            down_kernel_size=[3, 3, 3],
                            activation_type=tf.nn.relu,
                            downsample_type="max_pool",
                            voxel_size=(1, 1, 1),
                            name="posterior")
        posterior.build()
        print("")

        # The sample is drawn once and shared by FComb and the debug export.
        sample_z = posterior.sample()
        sample_z_batched = tf.reshape(sample_z, (1, 1, latent_dims))

    with tf.variable_scope("f_comb"):
        f_comb = FComb(fmaps_in=unet.get_fmaps(),
                       sample_in=sample_z,
                       num_1x1_convs=3,
                       num_channels=12,
                       padding_type='valid',
                       activation_type=tf.nn.relu,
                       voxel_size=(1, 1, 1))
        f_comb.build()
        print("")

        # Intermediate FComb tensors exported for debugging.
        broadcast_sample = f_comb.broadcast_sample
        sample_out = f_comb.sample_out
        print("sample_out_name: ", f_comb.get_fmaps().name)
        sample_out_batched = tf.reshape(sample_out, (1, 1, latent_dims))

    with tf.variable_scope("affs"):
        pred_logits = tf.layers.conv3d(inputs=f_comb.get_fmaps(),
                                       filters=3,
                                       kernel_size=1,
                                       padding='valid',
                                       data_format="channels_first",
                                       activation=None,
                                       name="affs")
        print("")

    pred_affs = tf.sigmoid(pred_logits)

    output_shape_batched = pred_logits.get_shape().as_list()
    output_shape = output_shape_batched[1:]  # strip the batch dimension

    pred_logits = tf.squeeze(pred_logits, axis=[0])
    pred_affs = tf.squeeze(pred_affs, axis=[0])

    gt_affs_out = tf.placeholder(tf.float32, shape=output_shape)
    pred_affs_loss_weights = tf.placeholder(tf.float32, shape=output_shape)

    print("gt_affs_out: ", gt_affs_out.shape)
    print("pred_logits: ", pred_logits.shape)
    print("pred_affs: ", pred_affs.shape)
    print("")

    print("gt_affs_out: ", gt_affs_out)
    print("pred_logits: ", pred_logits)
    print("pred_affs: ", pred_affs)

    # NOTE(review): the loss is unweighted although the
    # 'pred_affs_loss_weights' placeholder is created and exported; left
    # unchanged because it may be a deliberate debugging choice - confirm.
    sce_loss = tf.losses.sigmoid_cross_entropy(gt_affs_out, pred_logits)

    summary = tf.summary.scalar('sce_loss', sce_loss)

    opt = tf.train.AdamOptimizer()
    optimizer = opt.minimize(sce_loss)

    output_shape = output_shape[1:]  # additionally strip the channel entry
    print("input shape : %s" % (input_shape, ))
    print("output shape: %s" % (output_shape, ))

    tf.train.export_meta_graph(filename=setup_dir + "train_net.meta")

    config = {
        'raw': raw.name,
        'pred_affs': pred_affs.name,
        'gt_affs_in': gt_affs_in.name,
        'gt_affs_out': gt_affs_out.name,
        'pred_affs_loss_weights': pred_affs_loss_weights.name,
        'loss': sce_loss.name,
        'optimizer': optimizer.name,
        'input_shape': input_shape,
        'output_shape': output_shape,
        'prior': prior.get_fmaps().name,
        'posterior': posterior.get_fmaps().name,
        'latent_dims': latent_dims,
        'summary': summary.name,
        'broadcast': broadcast_sample.name,
        'sample_z': sample_z_batched.name,
        'pred_logits': pred_logits.name,
        'sample_out': sample_out_batched.name,
        'debug': debug_batched.name
    }
    with open(setup_dir + 'train_config.json', 'w') as f:
        json.dump(config, f)
def create_network(input_shape, setup_dir):
    """Build the probabilistic U-Net training graph (MSE loss) and export it.

    Resets the default graph and creates a ``UNet``, a prior ``Encoder``, a
    posterior ``Encoder`` (which also gets the ground-truth affinities), an
    ``FComb`` fed with a posterior sample and a final 1x1x1 convolution with
    3 output channels.  A weighted mean-squared-error loss between the
    sigmoid output and the ground truth is minimised with a default
    ``AdamOptimizer``.  The meta graph is written to
    ``setup_dir + 'train_net.meta'`` and the tensor/op names and shapes to
    ``setup_dir + 'train_config.json'``.

    Args:
        input_shape: Shape tuple of the unbatched raw input (it is
            concatenated to ``(1, 1)``, ``(3,)`` and ``(1, 3)``).
        setup_dir: Prefix prepended verbatim to the output file names; no
            path separator is inserted.
    """
    latent_dims = 6  # shared by both encoders and recorded in the config

    print("MKNET: PROB-UNET TRAIN")
    print("")
    tf.reset_default_graph()

    # Unbatched placeholders plus their (batch, channel, ...) views.
    raw = tf.placeholder(tf.float32, shape=input_shape, name="raw")
    raw_batched = tf.reshape(raw, (1, 1) + input_shape,
                             name="raw_batched")
    gt_affs_in = tf.placeholder(tf.float32,
                                shape=(3, ) + input_shape,
                                name="gt_affs_full")
    gt_affs_in_batched = tf.reshape(gt_affs_in, (1, 3) + input_shape,
                                    name="gt_affs_full_batched")

    print("raw_batched: ", raw_batched.shape)
    print("gt_affs_in_batched: ", gt_affs_in_batched.shape)
    print("")

    unet = UNet(fmaps_in=raw_batched,
                num_layers=3,
                base_channels=12,
                channel_inc_factor=3,
                resample_factors=[[2, 2, 2], [2, 2, 2], [2, 2, 2]],
                padding_type="valid",
                num_conv_passes=2,
                down_kernel_size=[3, 3, 3],
                up_kernel_size=[3, 3, 3],
                activation_type=tf.nn.relu,
                downsample_type="max_pool",
                upsample_type="conv_transpose",
                voxel_size=(1, 1, 1))
    unet.build()
    print("")

    prior = Encoder(fmaps_in=raw_batched,
                    affmaps_in=None,
                    num_layers=3,
                    latent_dims=latent_dims,
                    base_channels=12,
                    channel_inc_factor=3,
                    downsample_factors=[[2, 2, 2], [2, 2, 2], [2, 2, 2]],
                    padding_type="valid",
                    num_conv_passes=2,
                    down_kernel_size=[3, 3, 3],
                    activation_type=tf.nn.relu,
                    downsample_type="max_pool",
                    voxel_size=(1, 1, 1),
                    name="prior")
    prior.build()
    print("")

    posterior = Encoder(fmaps_in=raw_batched,
                        affmaps_in=gt_affs_in_batched,
                        num_layers=3,
                        latent_dims=latent_dims,
                        base_channels=12,
                        channel_inc_factor=3,
                        downsample_factors=[[2, 2, 2], [2, 2, 2], [2, 2, 2]],
                        padding_type="valid",
                        num_conv_passes=2,
                        down_kernel_size=[3, 3, 3],
                        activation_type=tf.nn.relu,
                        downsample_type="max_pool",
                        voxel_size=(1, 1, 1),
                        name="posterior")
    posterior.build()
    print("")

    # The original passed an undefined name ``z`` here, which raised a
    # NameError as soon as the graph was built.  During training FComb gets
    # a sample of the posterior, as in the sibling training graphs.
    f_comb = FComb(fmaps_in=unet.get_fmaps(),
                   sample_in=posterior.sample(),
                   num_1x1_convs=3,
                   num_channels=12,
                   padding_type='valid',
                   activation_type=tf.nn.relu,
                   voxel_size=(1, 1, 1))
    f_comb.build()

    pred_logits = tf.layers.conv3d(inputs=f_comb.get_fmaps(),
                                   filters=3,
                                   kernel_size=1,
                                   padding='valid',
                                   data_format="channels_first",
                                   activation=None,
                                   name="affs")
    print("")

    pred_affs = tf.sigmoid(pred_logits)

    output_shape_batched = pred_logits.get_shape().as_list()
    output_shape = output_shape_batched[1:]  # strip the batch dimension

    pred_logits = tf.squeeze(pred_logits, axis=[0], name="pred_logits")
    pred_affs = tf.squeeze(pred_affs, axis=[0], name="pred_affs")

    # Ground truth and per-element loss weights share the unbatched output
    # shape (channel entry included).
    gt_affs_out = tf.placeholder(tf.float32,
                                 shape=output_shape,
                                 name="gt_affs")
    pred_affs_loss_weights = tf.placeholder(tf.float32,
                                            shape=output_shape,
                                            name="pred_affs_loss_weights")

    print("gt_affs_out: ", gt_affs_out.shape)
    print("pred_logits: ", pred_logits.shape)
    print("pred_affs: ", pred_affs.shape)
    print("")

    mse_loss = tf.losses.mean_squared_error(gt_affs_out, pred_affs,
                                            pred_affs_loss_weights)

    summary = tf.summary.scalar('mse_loss', mse_loss)

    opt = tf.train.AdamOptimizer()
    optimizer = opt.minimize(mse_loss)

    output_shape = output_shape[1:]  # additionally strip the channel entry
    print("input shape : %s" % (input_shape, ))
    print("output shape: %s" % (output_shape, ))

    tf.train.export_meta_graph(filename=setup_dir + "train_net.meta")

    config = {
        'raw': raw.name,
        'pred_affs': pred_affs.name,
        'gt_affs_in': gt_affs_in.name,
        'gt_affs_out': gt_affs_out.name,
        'pred_affs_loss_weights': pred_affs_loss_weights.name,
        'loss': mse_loss.name,
        'optimizer': optimizer.name,
        'input_shape': input_shape,
        'output_shape': output_shape,
        'prior': prior.get_fmaps().name,
        'posterior': posterior.get_fmaps().name,
        'latent_dims': latent_dims,
        'summary': summary.name,
    }
    with open(setup_dir + 'train_config.json', 'w') as f:
        json.dump(config, f)