Exemple #1
0
	def plot_example_compound_ops_grid(self, **kwargs):

		'''
		Plots an N_rows x N_cols grid of randomly generated compound ops.

		Each cell is a sample from produce_compound_op_sample(force_op_sample=True,
		return_all_canvs=True). The cell shows the sample's 'canv_ideal' with one
		highlight box per entry of its 'params_list'. No fitting is done here.

		kwargs:
			N_rows (int, default 3): number of grid rows.
			N_cols (int, default 8): number of grid columns.
			All kwargs are also forwarded to produce_compound_op_sample() and
			plot_utils.plot_image_grid().
		'''

		N_rows = kwargs.get('N_rows', 3)
		N_cols = kwargs.get('N_cols', 8)

		op_sample_grid = [
			[self.produce_compound_op_sample(force_op_sample=True, return_all_canvs=True, **kwargs) for _ in range(N_cols)]
			for _ in range(N_rows)
		]

		canv_ideal_grid = [[s['canv_ideal'] for s in s_row] for s_row in op_sample_grid]

		# One list of boxes per cell. Each primitive's params are passed through
		# corners_to_center_xy_wh and then center_xy_wh_to_corner_xy_wh before drawing.
		canv_boxes_grid = [
			[[self.center_xy_wh_to_corner_xy_wh(self.corners_to_center_xy_wh(p)) for p in s['params_list']] for s in s_row]
			for s_row in op_sample_grid
		]

		plot_utils.plot_image_grid(canv_ideal_grid, highlight_boxes=canv_boxes_grid, **kwargs)
Exemple #2
0
	def plot_inspect_op_grid(self, inspect_dict_list_in, **kwargs):

		'''
		For inspecting samples during PT that scored very low.

		Plots the first (up to) 15 entries of inspect_dict_list_in as a 2-row grid:
		- top row: each entry's 'canv_spec', labelled with its 'log_prob' and
		  'target_op_str';
		- bottom row: each entry's 'canv_ideal'.

		kwargs:
			base_title (str, optional): used as the plot title when given.
			All kwargs are forwarded to plot_utils.plot_image_grid().
		'''

		max_N = 15
		# Slice before copying so only the entries that get plotted are deep-copied.
		inspect_dict_list = deepcopy(inspect_dict_list_in[:max_N])

		canv_spec_list = [d['canv_spec'] for d in inspect_dict_list]
		canv_ideal_list = [d['canv_ideal'] for d in inspect_dict_list]

		canv_spec_labels = ['log_prob = {:.2f},\ntarget op = {}'.format(d['log_prob'], d['target_op_str']) for d in inspect_dict_list]

		label_grid = [
			canv_spec_labels,
			[],
		]

		grid = [
			canv_spec_list,
			canv_ideal_list,
		]

		# No score is computed in this method, so the title is just base_title.
		# Appending an undefined score here would raise NameError.
		base_title = kwargs.get('base_title', None)
		plot_title = '' if base_title is None else base_title

		plot_utils.plot_image_grid(grid, label_grid=label_grid, plot_title=plot_title, **kwargs)
def plot_structs(mols, labels, dataset):
    """
    Render every molecule in `mols` and show them as a single-row image grid.

    :param mols:     Molecules passed one by one to Draw.MolToImage.
    :param labels:   Column labels, one per molecule.
    :param dataset:  Text for the left-hand row label.
    """
    # Each molecule becomes a 100x100 image without kekulization.
    mol_images = [Draw.MolToImage(mol, kekulize=False, size=(100, 100)) for mol in mols]

    plot_image_grid(
        [mol_images],
        c=5,
        col_labels=labels,
        col_rotation=0,
        row_labels_left=[(dataset, '')],
        row_labels_right=[],
        super_col_labels=[],
    )
    def on_epoch_end(self, epoch, logs=None):
        """
        Plot into Tensorboard a grid of image results.

        Runs the model on self.input_images and optionally post-processes the
        results with self.postprocess_fn. It then redraws self.fig with rows for
        the inputs, the predictions and (when available) the targets, and hands
        the figure to figure_to_summary inside the summary writer's default
        context.

        :param epoch:   Epoch num
        :param logs:    (unused) Dictionary of loss/metrics value for the epoch.
                        Defaults to None instead of a shared mutable dict.
        """

        # Get predictions with current model:
        predicted_images = self.model.predict_on_batch(self.input_images)
        if self.postprocess_fn is not None:
            input_images, predicted_images, target_images = self.postprocess_fn(
                self.input_images, predicted_images, self.target_images)
        else:
            input_images, target_images = self.input_images, self.target_images

        # Fill figure with images; the target row is only added when targets exist.
        grid_imgs = [input_images, predicted_images]
        if target_images is not None:
            grid_imgs.append(target_images)
        self.fig.clf()
        self.fig = plot_image_grid(grid_imgs,
                                   titles=self.image_titles,
                                   figure=self.fig,
                                   grayscale=self.grayscale,
                                   transpose=self.transpose)

        with self.summary_writer.as_default():
            # NOTE(review): presumably figure_to_summary records the image through
            # the default writer (its return value is not used) — confirm.
            figure_to_summary(self.fig, self.tag, epoch)
        self.summary_writer.flush()
Exemple #5
0
	def plot_inspect_params_grid(self, inspect_dict_list_in, **kwargs):

		'''
		For inspecting samples that scored really badly during PT.

		Plots the first (up to) 10 entries of inspect_dict_list_in as a 2-row grid:
		- top row: each entry's 'canv_spec', labelled with 'log_prob',
		  'params_mu' and 'params_sigma', with two highlight boxes built from
		  'params_sampled' and 'params';
		- bottom row: each entry's 'canv_ideal'.

		kwargs:
			base_title (str, optional): used as the plot title when given.
			All kwargs are forwarded to plot_utils.plot_image_grid().
		'''

		max_N = 10
		# Slice before copying so only the entries that get plotted are deep-copied.
		inspect_dict_list = deepcopy(inspect_dict_list_in[:max_N])

		canv_spec_list = [d['canv_spec'] for d in inspect_dict_list]
		canv_ideal_list = [d['canv_ideal'] for d in inspect_dict_list]

		canv_spec_labels = [
			'log_prob = {:.2f},\nmu = {}\nsigma = {}'.format(
				d['log_prob'],
				[f'{m:.2f}' for m in d['params_mu']],
				[f'{sig:.2f}' for sig in d['params_sigma']],
			)
			for d in inspect_dict_list
		]

		# Two boxes per cell. 'params_sampled' goes straight to center_xy_wh_to_corner_xy_wh.
		# 'params' is first passed through corners_to_center_xy_wh.
		boxes_row = [
			[
				self.center_xy_wh_to_corner_xy_wh(d['params_sampled']),
				self.center_xy_wh_to_corner_xy_wh(self.corners_to_center_xy_wh(d['params'])),
			]
			for d in inspect_dict_list
		]

		label_grid = [
			canv_spec_labels,
			[],
		]

		grid = [
			canv_spec_list,
			canv_ideal_list,
		]

		highlight_boxes = [
			boxes_row,
			[],
		]

		# No score is computed in this method, so the title is just base_title.
		# Appending an undefined score here would raise NameError.
		base_title = kwargs.get('base_title', None)
		plot_title = '' if base_title is None else base_title

		plot_utils.plot_image_grid(grid, label_grid=label_grid, plot_title=plot_title, highlight_boxes=highlight_boxes, **kwargs)
Exemple #6
0
	def plot_primitives_grid(self, policy_model, **kwargs):

		'''
		Fits policy_model to a grid of primitive examples and plots the fitted rects.

		The grid comes from kwargs['prim_sample_grid'] when given. Otherwise it is an
		N_rows x N_cols grid (kwargs, defaults 3 x 6) of get_random_shape_canv() samples.
		For every cell:
		- policy_model.policy_params(canv_spec) is run and an action is drawn with
		  train_utils.get_action_dict;
		- the action's 'params' are drawn as boxes over 'canv_spec';
		- the cell is labelled with cu.F1_score(primitive_rect(*params), canv_ideal).
		The title shows the mean F1 score, under kwargs['base_title'] if present.

		Returns the plotted prim_sample_grid, so it can be passed back in through
		the prim_sample_grid kwarg.
		'''

		prim_sample_grid = kwargs.get('prim_sample_grid', None)
		if prim_sample_grid is None:
			n_rows = kwargs.get('N_rows', 3)
			n_cols = kwargs.get('N_cols', 6)
			prim_sample_grid = [
				[self.get_random_shape_canv(**kwargs) for _ in range(n_cols)]
				for _ in range(n_rows)
			]

		spec_rows = [[cell['canv_spec'] for cell in row] for row in prim_sample_grid]
		ideal_rows = [[cell['canv_ideal'] for cell in row] for row in prim_sample_grid]

		# First run the network over every cell, then draw an action for every
		# cell. Both passes sweep the grid row by row.
		policy_outputs = [
			[policy_model.policy_params(cell['canv_spec']) for cell in row]
			for row in prim_sample_grid
		]
		fitted_params = [
			[train_utils.get_action_dict(out)['params'].squeeze().tolist() for out in row]
			for row in policy_outputs
		]

		f1_rows = []
		for params_row, ideal_row in zip(fitted_params, ideal_rows):
			f1_rows.append([cu.F1_score(self.primitive_rect(*p), ideal) for p, ideal in zip(params_row, ideal_row)])

		label_rows = [['F1 score = {:.3f}'.format(f1) for f1 in row] for row in f1_rows]

		def _cell_boxes(p):
			# Two overlays per cell: the raw params, plus the params round-tripped
			# through the grid-corner helpers.
			return [
				self.center_xy_wh_to_corner_xy_wh(p),
				self.center_xy_wh_to_corner_xy_wh(self.corners_to_center_xy_wh(self.center_xy_wh_to_grid_corners(p))),
			]

		box_rows = [[_cell_boxes(p) for p in row] for row in fitted_params]

		title_score = f'Mean score = {np.mean(f1_rows):.2f}'
		base_title = kwargs.get('base_title', None)
		plot_title = title_score if base_title is None else base_title + '\n' + title_score

		plot_utils.plot_image_grid(spec_rows, highlight_boxes=box_rows, label_grid=label_rows, plot_title=plot_title, **kwargs)

		return prim_sample_grid
Exemple #7
0
	def plot_example_ops_grid(self, **kwargs):

		'''
		Plots example 2-primitive ops, but doesn't do any fitting.

		Draws N_cols (kwargs, default 10) samples from get_op_sample() and plots a
		3-row grid:
		- row 0: 'canv_ideal', with one box per entry of 'params_list';
		- rows 1-2: 'canv_1' and 'canv_2', or an all-zeros canvas when the
		  sample has no such key.

		All kwargs are forwarded to get_op_sample() and plot_utils.plot_image_grid().
		'''

		N_cols = kwargs.get('N_cols', 10)
		# This extra sample is only used for its shape, to build the all-zeros
		# stand-in canvas.
		blank_canv = torch.zeros(self.get_op_sample()['canv_spec'].shape)

		op_sample_list = [self.get_op_sample(**kwargs) for _ in range(N_cols)]

		canv_ideal_list = [s['canv_ideal'] for s in op_sample_list]
		canv_1_true_list = [s.get('canv_1', blank_canv) for s in op_sample_list]
		canv_2_true_list = [s.get('canv_2', blank_canv) for s in op_sample_list]

		grid = [
			canv_ideal_list,
			canv_1_true_list,
			canv_2_true_list,
		]

		# Exactly one highlight-box row per grid row, as the sibling plot methods
		# do. Only the canv_ideal row gets boxes.
		boxes_grid = [
			[[self.center_xy_wh_to_corner_xy_wh(self.corners_to_center_xy_wh(p)) for p in s['params_list']] for s in op_sample_list],
			[],
			[],
		]

		plot_utils.plot_image_grid(grid, highlight_boxes=boxes_grid, **kwargs)
Exemple #8
0




# --- Reconstructions and the decoded latent path -----------------------------
plot_utils.plot_reconstructions(
    in_data,
    out_data,
    show_plot=show_plot,
    output_dir=output_dir,
)
plot_utils.plot_transformation(
    latent_path_recon,
    show_plot=show_plot,
    output_dir=output_dir,
)

# --- First two latent dimensions: plain, with the path, with the grid ---------
plot_utils.plot_2D_latent_space(
    latent_data[:, :2],
    label_list=labels,
    rel_fname='latent_space',
    show_plot=show_plot,
    output_dir=output_dir,
)
plot_utils.plot_2D_latent_space(
    latent_data[:, :2],
    label_list=labels,
    highlight_points=latent_path[:, :2],
    rel_fname='latent_path',
    show_plot=show_plot,
    output_dir=output_dir,
)
plot_utils.plot_2D_latent_space(
    latent_data[:, :2],
    label_list=labels,
    highlight_points=latent_grid[:, :2],
    rel_fname='latent_grid',
    show_plot=show_plot,
    output_dir=output_dir,
)
plot_utils.plot_image_grid(latent_grid_recon, show_plot=show_plot, output_dir=output_dir)

# --- "Friend" batches: load comma-separated rows, reshape to square images ----
in_data_friend_raw = genfromtxt(
    os.path.join(source_dir, run_label + '_friend_input_batch.txt'),
    delimiter=',',
)
out_data_friend_raw = genfromtxt(
    os.path.join(source_dir, run_label + '_friend_output_batch.txt'),
    delimiter=',',
)
in_data_friend = in_data_friend_raw.reshape(-1, img_side_size, img_side_size)
out_data_friend = out_data_friend_raw.reshape(-1, img_side_size, img_side_size)

plot_utils.plot_reconstructions(
    in_data_friend,
    out_data_friend,
    rel_fname='friend_recons',
    show_plot=show_plot,
    output_dir=output_dir,
)

# --- Loss diagnostics over the main batch, with the friend batch highlighted --
plot_utils.plot_N_lossiest(
    in_data_raw,
    out_data_raw,
    rel_fname='N_worst.png',
    show_plot=show_plot,
    output_dir=output_dir,
)
plot_utils.plot_losses_hist(
    in_data_raw,
    out_data_raw,
    highlight_pts={'in_pts': in_data_friend_raw, 'out_pts': out_data_friend_raw},
    rel_fname='sample_losses_hist.png',
    show_plot=show_plot,
    output_dir=output_dir,
)
Exemple #9
0
	def plot_compound_ops_grid(self, **kwargs):

		'''
		Creates some compound ops, plots them, and optionally evaluates a policy on them.

		Draws N_cols (kwargs, default 10) samples with
		produce_compound_op_sample(force_op_sample=True, return_all_canvs=True).
		Row 0 always shows each sample's 'canv_ideal', with one box per entry of
		its 'params_list'. The remaining rows depend on kwargs:

		- eval_canv_1=True (requires policy_model):
		  - runs policy_model.policy_canv_1(canv_ideal, OHE('rect'));
		  - plots a canv_1 drawn from that output;
		  - boxes the 'params_list' entry whose canvas in 'all_canvs_list' gets
		    the highest canv_1 log-prob.
		- eval_canv_2=True (requires policy_model, only checked when eval_canv_1
		  is false):
		  - runs policy_model.policy_canv_2(canv_ideal, canv_1, OHE('rect'));
		  - plots the sample's canv_1 and a drawn canv_2.
		- otherwise: plots the sample's own 'canv_1' / 'canv_2', or an all-zeros
		  canvas when a key is missing.

		In the eval_canv_2 and default modes, canv_1 is boxed with params_list[-1]
		and canv_2 with params_list[:-1].

		All kwargs are forwarded to produce_compound_op_sample(),
		train_utils helpers and plot_utils.plot_image_grid().
		'''

		N_cols = kwargs.get('N_cols', 10)
		# This extra sample is only used for its shape, to build the all-zeros
		# stand-in canvas.
		blank_canv = torch.zeros(self.get_op_sample()['canv_spec'].shape)

		op_sample_list = [self.produce_compound_op_sample(force_op_sample=True, return_all_canvs=True, **kwargs) for _ in range(N_cols)]
		n_samples = len(op_sample_list)

		canv_ideal_list = [s['canv_ideal'] for s in op_sample_list]
		canv_params_list = [s['params_list'] for s in op_sample_list]

		def to_box(p):
			# Convert one primitive's params into the box form passed to plot_image_grid.
			return self.center_xy_wh_to_corner_xy_wh(self.corners_to_center_xy_wh(p))

		eval_canv_1 = kwargs.get('eval_canv_1', False)
		eval_canv_2 = kwargs.get('eval_canv_2', False)

		if eval_canv_1 or eval_canv_2:
			pm = kwargs.get('policy_model', None)
			assert pm is not None, 'Must supply a policy_model kwarg to eval!'
			target_op_OHE = self.op_str_to_OHE('rect')

		# Stays None in the eval_canv_1 mode, which plots no canv_2 row.
		canv_2_params_list = None

		if eval_canv_1:

			canv_1_list = []
			canv_1_params_list = []

			for target in op_sample_list:
				output_dict = pm.policy_canv_1(target['canv_ideal'], target_op_OHE)

				# Score every canvas in all_canvs_list under the canv_1 output. The
				# params of the highest-scoring one are used for the box overlay.
				# NOTE(review): assumes all_canvs_list[k] corresponds to
				# params_list[k] — confirm against produce_compound_op_sample.
				all_log_probs = [
					train_utils.get_log_probs_of_samples(output_dict, canv_1=canv_1, **kwargs)['canv_1_log_prob'].item()
					for canv_1 in target['all_canvs_list']
				]
				best_ind = np.argmax(all_log_probs)

				# The canvas shown is a fresh draw from the policy output, not the
				# best-scoring canvas above.
				canv_1_list.append(train_utils.get_action_dict(output_dict, **kwargs)['canv_1'])
				canv_1_params_list.append(target['params_list'][best_ind])

			grid = [
				canv_ideal_list,
				canv_1_list,
			]

			label_grid = [
				['canv_spec'] * n_samples,
				['sampled canv_1'] * n_samples,
			]

		elif eval_canv_2:

			canv_1_list = []
			canv_2_list = []

			for target in op_sample_list:
				output_dict = pm.policy_canv_2(target['canv_ideal'], target['canv_1'], target_op_OHE)
				canv_1_list.append(target['canv_1'])
				canv_2_list.append(train_utils.get_action_dict(output_dict, **kwargs)['canv_2'])

			canv_1_params_list = [s['params_list'][-1] for s in op_sample_list]
			canv_2_params_list = [s['params_list'][:-1] for s in op_sample_list]

			grid = [
				canv_ideal_list,
				canv_1_list,
				canv_2_list,
			]

			label_grid = [
				['canv_spec'] * n_samples,
				['input canv_1'] * n_samples,
				['sampled canv_2'] * n_samples,
			]

		else:

			canv_1_list = [s.get('canv_1', blank_canv) for s in op_sample_list]
			canv_2_list = [s.get('canv_2', blank_canv) for s in op_sample_list]
			canv_1_params_list = [s['params_list'][-1] for s in op_sample_list]
			canv_2_params_list = [s['params_list'][:-1] for s in op_sample_list]

			grid = [
				canv_ideal_list,
				canv_1_list,
				canv_2_list,
			]

			label_grid = [
				['canv_spec'] * n_samples,
				['ex canv_1'] * n_samples,
				['ex canv_2'] * n_samples,
			]

		# One box-row per grid row.
		# - row 0: every primitive of the sample;
		# - row 1: the single canv_1 primitive;
		# - row 2 (if present): the remaining primitives.
		boxes_grid = [
			[[to_box(p) for p in p_l] for p_l in canv_params_list],
			[[to_box(p)] for p in canv_1_params_list],
		]
		if canv_2_params_list is not None:
			boxes_grid.append([[to_box(p) for p in p_l] for p_l in canv_2_params_list])

		plot_utils.plot_image_grid(grid, label_grid=label_grid, highlight_boxes=boxes_grid, **kwargs)
Exemple #10
0
	def plot_operations_grid(self, policy_model, **kwargs):

		'''
		Creates some simple 2-primitive operation canvs and uses the policy model
		to try to recover the right canvs.

		Samples come from kwargs['op_sample_list'] when given. Otherwise N_cols
		(kwargs, default 10) get_op_sample() draws are used.

		For each sample, the policy is run as follows:
		- policy_model.policy_op(canv_spec) chooses an op;
		- for non-'rect' targets only, policy_canv_1 and then policy_canv_2
		  (conditioned on the drawn canv_1) draw the two sub-canvases;
		- the target op is applied to them to reconstruct canv_spec.

		The plotted rows are canv_spec, canv_1, canv_2 and the reconstruction.
		- Each drawn sub-canvas is scored (F1) against whichever true sub-canvas
		  it matches better, and boxed with that sub-canvas's params.
		- The title reports the mean recon / canv_1 / canv_2 scores, under
		  kwargs['base_title'] if present.

		Returns op_sample_list, so the same samples can be passed back in through
		the op_sample_list kwarg.
		'''

		N_cols = kwargs.get('N_cols', 10)
		# This extra sample is only used for its shape, to build the all-zeros
		# stand-in canvas.
		blank_canv = torch.zeros(self.get_op_sample()['canv_spec'].shape)

		op_sample_list = kwargs.get('op_sample_list', None)
		if op_sample_list is None:
			op_sample_list = [self.get_op_sample(**kwargs) for _ in range(N_cols)]

		canv_spec_list = [s['canv_spec'] for s in op_sample_list]
		canv_ideal_list = [s['canv_ideal'] for s in op_sample_list]
		canv_1_true_list = [s['canv_1'] for s in op_sample_list]
		canv_2_true_list = [s['canv_2'] for s in op_sample_list]

		canv_1_params_list = [s['canv_1_params'] for s in op_sample_list]
		canv_2_params_list = [s['canv_2_params'] for s in op_sample_list]

		target_op_strs = [s['op_str'] for s in op_sample_list]
		target_op_OHE = [self.op_str_to_OHE(op) for op in target_op_strs]

		op_action_dict_list = [train_utils.get_action_dict(policy_model.policy_op(c), **kwargs) for c in canv_spec_list]

		# Sub-canvases are only drawn for non-'rect' targets. 'rect' cells get
		# None here and blank_canv in the canvas lists.
		canv_1_action_dict_list = [
			train_utils.get_action_dict(policy_model.policy_canv_1(spec, ohe), **kwargs) if op != 'rect' else None
			for spec, ohe, op in zip(canv_spec_list, target_op_OHE, target_op_strs)
		]
		canv_1_list = [n['canv_1'] if n is not None else blank_canv for n in canv_1_action_dict_list]

		canv_2_action_dict_list = [
			train_utils.get_action_dict(policy_model.policy_canv_2(spec, c_1, ohe), **kwargs) if op != 'rect' else None
			for spec, ohe, c_1, op in zip(canv_spec_list, target_op_OHE, canv_1_list, target_op_strs)
		]
		canv_2_list = [n['canv_2'] if n is not None else blank_canv for n in canv_2_action_dict_list]

		canv_1_labels = ['canv_1' if n is not None else '' for n in canv_1_action_dict_list]
		canv_2_labels = ['canv_2' if n is not None else '' for n in canv_2_action_dict_list]

		sampled_op_strs = [self.op_ind_to_op_str(n['op_ind']) for n in op_action_dict_list]

		# Reconstruct canv_spec by applying the *target* op to the drawn
		# sub-canvases. 'rect' targets just reuse canv_spec.
		canv_spec_op_list = [
			self.apply_op(op, c_1, c_2) if op != 'rect' else c_spec
			for op, c_1, c_2, c_spec in zip(target_op_strs, canv_1_list, canv_2_list, canv_spec_list)
		]

		recon_score_list = [cu.F1_score(ideal, canv_op) for ideal, canv_op in zip(canv_ideal_list, canv_spec_op_list)]

		def best_match(canv, canv_1_true, canv_2_true, canv_1_p, canv_2_p):
			# Return (F1 score, params) for the true sub-canvas that canv matches
			# better. Each F1 is computed once; ties go to canv_2_true.
			f1_vs_1 = cu.F1_score(canv, canv_1_true)
			f1_vs_2 = cu.F1_score(canv, canv_2_true)
			if f1_vs_1 > f1_vs_2:
				return f1_vs_1, canv_1_p
			return f1_vs_2, canv_2_p

		canv1_matches = [
			best_match(c, t_1, t_2, p_1, p_2)
			for c, t_1, t_2, p_1, p_2 in zip(canv_1_list, canv_1_true_list, canv_2_true_list, canv_1_params_list, canv_2_params_list)
		]
		canv2_matches = [
			best_match(c, t_1, t_2, p_1, p_2)
			for c, t_1, t_2, p_1, p_2 in zip(canv_2_list, canv_1_true_list, canv_2_true_list, canv_1_params_list, canv_2_params_list)
		]

		canv1_score_list = [score for score, _ in canv1_matches]
		canv2_score_list = [score for score, _ in canv2_matches]
		canv1_params_list = [params for _, params in canv1_matches]
		canv2_params_list = [params for _, params in canv2_matches]

		boxes_grid = [
			[],
			[[self.center_xy_wh_to_corner_xy_wh(self.corners_to_center_xy_wh(p))] for p in canv1_params_list],
			[[self.center_xy_wh_to_corner_xy_wh(self.corners_to_center_xy_wh(p))] for p in canv2_params_list],
			[],
		]

		op_labels = ['sampled: {}\nusing target:\n{}\nScore: {:.3f}'.format(s, t, score) for s, t, score in zip(sampled_op_strs, target_op_strs, recon_score_list)]

		label_grid = [
			target_op_strs,
			canv_1_labels,
			canv_2_labels,
			op_labels,
		]

		grid = [
			canv_spec_list,
			canv_1_list,
			canv_2_list,
			canv_spec_op_list,
		]

		mean_score_recon = np.mean(recon_score_list)
		mean_score_canv1 = np.mean(canv1_score_list)
		mean_score_canv2 = np.mean(canv2_score_list)

		title_score = f'Mean recon score = {mean_score_recon:.2f}, Mean canv_1 score = {mean_score_canv1:.2f}, Mean canv_2 score = {mean_score_canv2:.2f}, '
		base_title = kwargs.get('base_title', None)
		plot_title = title_score if base_title is None else base_title + '\n' + title_score

		plot_utils.plot_image_grid(grid, label_grid=label_grid, plot_title=plot_title, highlight_boxes=boxes_grid, **kwargs)

		return op_sample_list