def visualizeProductDistribution3(sess, input_dict, batch, obs_dist, transport_dist, rec_dist, sample_obs_dist, save_dir = '.', postfix = ''):
    """Plot observed data next to transported, reconstructed, model and random-sample parameters.

    For every observation type ('flat', 'image') whose entry in
    ``batch['observed']['data']`` is not None:

    1. The data is split along the last axis into per-component arrays, using the
       sizes from ``obs_dist.sample_properties``.
    2. For each of ``transport_dist``, ``rec_dist``, ``obs_dist`` and
       ``sample_obs_dist``, the first entry of ``get_interpretable_params()`` of
       every component distribution is reshaped to the matching data split.
    3. All of these are evaluated in one ``sess.run`` call.

    * 'flat': components whose property ``'dist'`` is not ``'cont'`` go to
      ``helper.visualize_flat``; ``'cont'`` components go to
      ``helper.visualize_vectors``.
    * 'image': the five groups are interleaved and drawn as one grid.
      Separately, at least 300 rows of the first ``sample_obs_dist`` component
      (gathered over repeated ``sess.run`` calls) are drawn as a square grid.

    Args:
        sess: session whose ``run(fetches, feed_dict=...)`` evaluates the tensors.
        input_dict: feed dict passed to every ``sess.run`` call.
        batch: dict providing ``['observed']['data'][obs_type]`` (array or None)
            and ``['observed']['properties'][obs_type]`` (dicts with a ``'dist'`` key).
        obs_dist, transport_dist, rec_dist, sample_obs_dist: product distributions
            exposing ``dist_list[obs_type]``; ``obs_dist`` also provides
            ``sample_properties``.
        save_dir: output directory prefix; ``'normal/'`` and
            ``'normal_sample_only/'`` are appended verbatim for images.
        postfix: file-name suffix forwarded to the helper plotting functions.

    Raises:
        ValueError: if the random-sample tensor yields zero rows. Without this
            check, the sample-collection loop would never terminate.
    """
    min_random_samples = 300

    def reshaped_params(dist, obs_type, sample_split):
        # First interpretable parameter of each component, shaped like the matching data split.
        return [tf.reshape(e.get_interpretable_params()[0], list(sample_split[i].shape))
                for i, e in enumerate(dist.dist_list[obs_type])]

    def to_object_array(groups):
        # Flatten the groups into a 1-D object array of per-component arrays.
        # np.array() on such a ragged list raises on NumPy >= 1.24, so build it explicitly.
        arrays = [arr for group in groups for arr in group]
        out = np.empty(len(arrays), dtype=object)
        for idx, arr in enumerate(arrays):
            out[idx] = arr
        return out

    sample = batch['observed']['data']
    for obs_type in ['flat', 'image']:
        if sample[obs_type] is None:
            continue
        sample_split = helper.split_tensor_np(sample[obs_type], -1, [e['size'][-1] for e in obs_dist.sample_properties[obs_type]])
        param_split_tf = reshaped_params(obs_dist, obs_type, sample_split)
        transport_split_tf = reshaped_params(transport_dist, obs_type, sample_split)
        rec_split_tf = reshaped_params(rec_dist, obs_type, sample_split)
        rand_param_split_tf = reshaped_params(sample_obs_dist, obs_type, sample_split)
        param_split, transport_split, rec_split, rand_param_split = sess.run(
            [param_split_tf, transport_split_tf, rec_split_tf, rand_param_split_tf], feed_dict=input_dict)

        # Group order is the plotting order: data, transport, reconstruction, params, random sample.
        groups = [sample_split, transport_split, rec_split, param_split, rand_param_split]
        samples_params_np = to_object_array(groups)

        if obs_type == 'flat':
            properties = batch['observed']['properties'][obs_type]
            # The mask needs one entry per array in samples_params_np, i.e. one copy per group.
            cont_var_filter = np.tile(np.asarray([e['dist'] == 'cont' for e in properties], dtype=bool), len(groups))
            not_cont_var_filter = np.tile(np.asarray([e['dist'] != 'cont' for e in properties], dtype=bool), len(groups))
            if not_cont_var_filter.any():
                helper.visualize_flat(samples_params_np[not_cont_var_filter], save_dir=save_dir, postfix=postfix+'_'+obs_type)
            if cont_var_filter.any():
                helper.visualize_vectors(samples_params_np[cont_var_filter], save_dir=save_dir, postfix=postfix+'_'+obs_type)

        if obs_type == 'image':
            # Draw repeatedly until at least min_random_samples rows of the first
            # sample_obs_dist component are available. An earlier note recorded
            # a shape of (300, 1, 64, 64, 3) for this array.
            chunks = []
            n_rows = 0
            while n_rows < min_random_samples:
                chunk = sess.run(rand_param_split_tf, feed_dict=input_dict)[0]
                if chunk.shape[0] == 0:
                    raise ValueError('sample_obs_dist produced an empty batch; cannot collect random samples')
                chunks.append(chunk)
                n_rows += chunk.shape[0]
            rand_param_split2 = np.concatenate(chunks, axis=0)

            samples_params_np_interleaved = helper.interleave_data(samples_params_np)
            helper.visualize_images2(samples_params_np_interleaved, block_size=[sample_split[0].shape[0], len(samples_params_np)], save_dir=save_dir+'normal/', postfix=postfix+'_'+obs_type+'_normal')
            # Largest square grid that fits in the collected samples.
            side = int(np.sqrt(rand_param_split2.shape[0]))
            helper.visualize_images2(rand_param_split2[:side**2, ...], block_size=[side, side], save_dir=save_dir+'normal_sample_only/', postfix=postfix+'_'+obs_type+'_normal_sample_only')
def visualizeProductDistribution(sess, input_dict, batch, obs_dist, sample_obs_dist, save_dir = '.', postfix = ''):
    """Plot observed data next to model parameters and random-sample parameters.

    For every observation type ('flat', 'image') whose entry in
    ``batch['observed']['data']`` is not None:

    1. The data is split along the last axis into per-component arrays, using the
       sizes from ``obs_dist.sample_properties``.
    2. For ``obs_dist`` and ``sample_obs_dist``, the first entry of
       ``get_interpretable_params()`` of every component is reshaped to the
       matching split and evaluated with ``sess``.
    3. The three groups (data, params, random params) are handed to the
       ``helper`` plotting functions.

    The reshape ops are cached inside ``input_dict`` under the key
    ``'_visualizeProductDistribution_tensor_cache'``. The cache is keyed by
    observation type and split shapes, so repeated calls with the same
    ``input_dict`` reuse the ops. That entry is excluded from the feed passed
    to ``sess.run``.

    Args:
        sess: session whose ``run(fetches, feed_dict=...)`` evaluates the tensors.
        input_dict: feed dict for ``sess.run``; also receives the tensor cache entry.
        batch: dict providing ``['observed']['data'][obs_type]`` (array or None)
            and ``['observed']['properties'][obs_type]`` (dicts with a ``'dist'`` key).
        obs_dist: product distribution exposing ``sample_properties`` and ``dist_list``.
        sample_obs_dist: product distribution exposing ``dist_list``.
        save_dir: output directory forwarded to the helper plotting functions.
        postfix: file-name suffix forwarded to the helper plotting functions.
    """
    cache_key = '_visualizeProductDistribution_tensor_cache'
    tensor_cache = input_dict.setdefault(cache_key, {})
    # sess.run only accepts tensors (or tensor names) as feed keys, so the cache
    # entry must not be fed. The isinstance check avoids invoking a tensor's __ne__.
    feed_dict = {k: v for k, v in input_dict.items() if not (isinstance(k, str) and k == cache_key)}

    def reshaped_params(dist, obs_type, sample_split):
        # First interpretable parameter of each component, shaped like the matching data split.
        return [tf.reshape(e.get_interpretable_params()[0], list(sample_split[i].shape))
                for i, e in enumerate(dist.dist_list[obs_type])]

    def to_object_array(groups):
        # Flatten the groups into a 1-D object array of per-component arrays.
        # np.array() on such a ragged list raises on NumPy >= 1.24, so build it explicitly.
        arrays = [arr for group in groups for arr in group]
        out = np.empty(len(arrays), dtype=object)
        for idx, arr in enumerate(arrays):
            out[idx] = arr
        return out

    sample = batch['observed']['data']
    for obs_type in ['flat', 'image']:
        if sample[obs_type] is None:
            continue
        sample_split = helper.split_tensor_np(sample[obs_type], -1, [e['size'][-1] for e in obs_dist.sample_properties[obs_type]])
        # 'flat' and 'image' use different component distributions, and the reshape
        # targets bake in the split shapes, so both go into the cache key.
        key = (obs_type, tuple(tuple(s.shape) for s in sample_split))
        if key not in tensor_cache:
            tensor_cache[key] = (reshaped_params(obs_dist, obs_type, sample_split),
                                 reshaped_params(sample_obs_dist, obs_type, sample_split))
        param_split_tf, rand_param_split_tf = tensor_cache[key]
        param_split = sess.run(param_split_tf, feed_dict=feed_dict)
        rand_param_split = sess.run(rand_param_split_tf, feed_dict=feed_dict)

        groups = [sample_split, param_split, rand_param_split]
        samples_params_np = to_object_array(groups)

        if obs_type == 'flat':
            properties = batch['observed']['properties'][obs_type]
            # The mask needs one entry per array in samples_params_np, i.e. one copy per group.
            cont_var_filter = np.tile(np.asarray([e['dist'] == 'cont' for e in properties], dtype=bool), len(groups))
            not_cont_var_filter = np.tile(np.asarray([e['dist'] != 'cont' for e in properties], dtype=bool), len(groups))
            if not_cont_var_filter.any():
                helper.visualize_flat(samples_params_np[not_cont_var_filter], save_dir=save_dir, postfix=postfix+'_'+obs_type)
            if cont_var_filter.any():
                helper.visualize_vectors(samples_params_np[cont_var_filter], save_dir=save_dir, postfix=postfix+'_'+obs_type)
        if obs_type == 'image':
            samples_params_np_interleaved = helper.interleave_data(samples_params_np)
            helper.visualize_images(samples_params_np_interleaved, save_dir=save_dir, postfix=postfix+'_'+obs_type)
def visualizeProductDistribution4(sess, model, input_dict, batch, real_dist, transport_dist, reg_target_dist, rec_dist, obs_dist, sample_obs_dist, real_data = None, save_dir = '.', postfix = '', postfix2 = None, b_zero_one_range=True):
    """Plot data, real, transported, regularisation-target, reconstructed, model and random-sample parameters.

    For every observation type ('flat', 'image') whose entry in
    ``batch['observed']['data']`` is not None:

    1. The data is split along the last axis into per-component arrays, using the
       sizes from ``obs_dist.sample_properties``.
    2. For each of ``real_dist``, ``transport_dist``, ``reg_target_dist``,
       ``rec_dist``, ``obs_dist`` and ``sample_obs_dist``, the first entry of
       ``get_interpretable_params()`` of every component is reshaped to the
       matching split.
    3. All of these are evaluated in one ``sess.run`` call.

    The reshape ops are cached on ``model._visualize_split_tf_cache``, keyed by
    observation type and split shapes.

    * 'flat': the data split plus the transport, reg-target, rec, param and
      random groups are plotted. Components whose property ``'dist'`` is not
      ``'cont'`` go to ``helper.visualize_flat``; ``'cont'`` components go to
      ``helper.visualize_vectors``.
    * 'image': at least 400 rows of the first ``sample_obs_dist`` component are
      gathered over repeated ``sess.run`` calls. If ``b_zero_one_range`` is set,
      every fetched group except the real one (and those rows) is clipped to
      [0, 1]. Then three grids are drawn:
        - the real/transport/reg-target/rec/param/random groups, interleaved;
        - the largest square grid of random samples;
        - if ``real_data`` has the same shape as that square grid, ``real_data``
          stacked above it.

    Args:
        sess: session whose ``run(fetches, feed_dict=...)`` evaluates the tensors.
        model: object on which the tensor cache attribute is stored.
        input_dict: feed dict passed to every ``sess.run`` call.
        batch: dict providing ``['observed']['data'][obs_type]`` (array or None)
            and ``['observed']['properties'][obs_type]`` (dicts with a ``'dist'`` key).
        real_dist, transport_dist, reg_target_dist, rec_dist, obs_dist,
            sample_obs_dist: product distributions exposing ``dist_list[obs_type]``.
        real_data: optional array compared against the square sample grid.
        save_dir: output path prefix; '_normal/', '_sample_only/' and
            '_sample_and_real/' are appended verbatim.
        postfix: suffix forwarded (with a per-plot prefix) to the plotting helpers.
        postfix2: optional second suffix. When None, ``postfix2`` is not passed,
            so ``helper.visualize_images2`` uses its own default.
        b_zero_one_range: clip image outputs to [0, 1] before plotting.

    Raises:
        ValueError: if the random-sample tensor yields zero rows. Without this
            check, the sample-collection loop would never terminate.
    """
    min_random_samples = 400

    def reshaped_params(dist, obs_type, sample_split):
        # First interpretable parameter of each component, shaped like the matching data split.
        return [tf.reshape(e.get_interpretable_params()[0], list(sample_split[i].shape))
                for i, e in enumerate(dist.dist_list[obs_type])]

    def to_object_array(groups):
        # Flatten the groups into a 1-D object array of per-component arrays.
        # np.array() on such a ragged list raises on NumPy >= 1.24, so build it explicitly.
        arrays = [arr for group in groups for arr in group]
        out = np.empty(len(arrays), dtype=object)
        for idx, arr in enumerate(arrays):
            out[idx] = arr
        return out

    def clipped(split):
        return [np.clip(arr, 0, 1) for arr in split]

    def postfix2_kwargs(prefix):
        # postfix2 defaults to None; 'prefix' + None would raise TypeError, so omit the
        # argument and let visualize_images2 fall back to its own default.
        return {} if postfix2 is None else {'postfix2': prefix + postfix2}

    # Cache reshape ops on the model so repeated calls do not keep adding graph ops.
    # The key includes obs_type ('flat' and 'image' use different components) and the
    # split shapes (baked into the reshape targets).
    tensor_cache = getattr(model, '_visualize_split_tf_cache', None)
    if tensor_cache is None:
        tensor_cache = {}
        model._visualize_split_tf_cache = tensor_cache

    sample = batch['observed']['data']
    for obs_type in ['flat', 'image']:
        if sample[obs_type] is None:
            continue
        sample_split = helper.split_tensor_np(sample[obs_type], -1, [e['size'][-1] for e in obs_dist.sample_properties[obs_type]])
        key = (obs_type, tuple(tuple(s.shape) for s in sample_split))
        if key not in tensor_cache:
            tensor_cache[key] = [reshaped_params(dist, obs_type, sample_split)
                                 for dist in (real_dist, transport_dist, reg_target_dist, rec_dist, obs_dist, sample_obs_dist)]
        split_tfs = tensor_cache[key]
        rand_param_split_tf = split_tfs[5]
        real_split, transport_split, reg_target_split, rec_split, param_split, rand_param_split = sess.run(split_tfs, feed_dict=input_dict)

        if obs_type == 'flat':
            groups = [sample_split, transport_split, reg_target_split, rec_split, param_split, rand_param_split]
            samples_params_np = to_object_array(groups)
            properties = batch['observed']['properties'][obs_type]
            # The mask needs one entry per array in samples_params_np, i.e. one copy per group.
            cont_var_filter = np.tile(np.asarray([e['dist'] == 'cont' for e in properties], dtype=bool), len(groups))
            not_cont_var_filter = np.tile(np.asarray([e['dist'] != 'cont' for e in properties], dtype=bool), len(groups))
            if not_cont_var_filter.any():
                helper.visualize_flat(samples_params_np[not_cont_var_filter], save_dir=save_dir, postfix=postfix+'_'+obs_type)
            if cont_var_filter.any():
                helper.visualize_vectors(samples_params_np[cont_var_filter], save_dir=save_dir, postfix=postfix+'_'+obs_type)

        if obs_type == 'image':
            # Draw repeatedly until at least min_random_samples rows of the first
            # sample_obs_dist component are available. An earlier note recorded
            # a shape of (300, 1, 64, 64, 3) for this array.
            chunks = []
            n_rows = 0
            while n_rows < min_random_samples:
                chunk = sess.run(rand_param_split_tf, feed_dict=input_dict)[0]
                if chunk.shape[0] == 0:
                    raise ValueError('sample_obs_dist produced an empty batch; cannot collect random samples')
                chunks.append(chunk)
                n_rows += chunk.shape[0]
            rand_param_split2 = np.concatenate(chunks, axis=0)

            if b_zero_one_range:
                # Clip every component of the generated groups; real_split is left untouched.
                transport_split = clipped(transport_split)
                reg_target_split = clipped(reg_target_split)
                rec_split = clipped(rec_split)
                param_split = clipped(param_split)
                rand_param_split = clipped(rand_param_split)
                rand_param_split2 = np.clip(rand_param_split2, 0, 1)

            samples_params_np = to_object_array([real_split, transport_split, reg_target_split, rec_split, param_split, rand_param_split])
            samples_params_np_interleaved = helper.interleave_data(samples_params_np)
            helper.visualize_images2(samples_params_np_interleaved, block_size=[sample_split[0].shape[0], len(samples_params_np)], save_dir=save_dir+'_normal/', postfix='normal_'+postfix, **postfix2_kwargs('normal_'))

            # Largest square grid that fits in the collected samples.
            side = int(np.sqrt(rand_param_split2.shape[0]))
            square_samples = rand_param_split2[:side**2, ...]
            helper.visualize_images2(square_samples, block_size=[side, side], save_dir=save_dir+'_sample_only/', postfix='sample_only_'+postfix, **postfix2_kwargs('sample_only_'))
            if real_data is not None and real_data.shape == square_samples.shape:
                helper.visualize_images2(np.concatenate([real_data, square_samples], axis=0), block_size=[side, 2*side], save_dir=save_dir+'_sample_and_real/', postfix='sample_and_real_'+postfix, **postfix2_kwargs('sample_and_real_'))