def load_data(path_data, action_space, force_reload=False):
    """Load (or build and cache) the processed train/eval split for the clip model.

    Parameters
    ----------
    path_data : str
        Directory containing the raw pickled optimization results.
    action_space : int
        Action dimensionality; inputs carry 2*action_space columns
        (mu0 ++ logsigma0) and outputs 4*action_space (max ++ min solutions).
    force_reload : bool
        When True, rebuild the processed cache even if it already exists.

    Returns
    -------
    tuple
        (train_x, train_y, train_weight, eval_x, eval_y, eval_weight)
    """
    path_data_processed = path_data + ', processed'
    file_data_processed = path_data_processed + '/data'
    # Fast path: reuse the cached, already-processed dataset.
    if not force_reload and os.path.exists(file_data_processed):
        print(f'load data from {file_data_processed}')
        vs = load_vars(file_data_processed)
        return vs
    print(f'load data from {path_data}')
    tools.mkdir(path_data_processed)
    files = tools.get_files(path_rel=path_data, sort=True)
    # Collect per-file arrays and concatenate ONCE at the end. The original
    # grew inputs_final/outputs_final with np.concatenate inside the loop,
    # which is O(n^2) in total copied data.
    inputs_list, outputs_list = [], []
    # Fix: np.int was removed in NumPy 1.24 — use the builtin int dtype.
    counts = np.zeros((len(files)), dtype=int)
    for ind, f in enumerate(files):
        mu0s_ats_batch, logsigma0s_batch, ress = load_vars(f)
        inputs = np.concatenate((mu0s_ats_batch, logsigma0s_batch), axis=-1)
        max_values = np.array([res['max'].x for res in ress])
        min_values = np.array([res['min'].x for res in ress])
        outputs = np.concatenate((max_values, min_values), axis=-1)
        inputs_list.append(inputs)    # shape: (None, 2 * action_space)
        outputs_list.append(outputs)  # shape: (None, 4 * action_space)
        counts[ind] = mu0s_ats_batch.shape[0]
    if inputs_list:
        inputs_final = np.concatenate(inputs_list)
        outputs_final = np.concatenate(outputs_list)
    else:
        # No data files: keep the documented empty shapes.
        inputs_final = np.zeros((0, 2 * action_space))
        outputs_final = np.zeros((0, 4 * action_space))
    # Per-sample weights: files with fewer samples get proportionally larger
    # weight, so every file contributes equally on average.
    cnt_normalize = counts.mean()
    weights = []
    for cnt in counts:
        weight = cnt_normalize * 1. / cnt * np.ones(cnt)
        weights.append(weight)
    weights = np.concatenate(weights, axis=0)
    # Drop rows whose optimization output contains NaN or Inf.
    inds_reserve = np.logical_and(~np.isnan(outputs_final).any(axis=1),
                                  ~np.isinf(outputs_final).any(axis=1))
    inputs_final = inputs_final[inds_reserve]
    outputs_final = outputs_final[inds_reserve]
    weights = weights[inds_reserve]
    # Shuffle inputs/outputs/weights with a single shared permutation so the
    # triples stay aligned.
    N = inputs_final.shape[0]
    inds_shuffle = np.random.permutation(N)
    inputs_final = inputs_final[inds_shuffle]
    outputs_final = outputs_final[inds_shuffle]
    weights = weights[inds_shuffle]
    # The last 500 (shuffled) samples are held out for evaluation.
    ind_split = -500
    train_x, train_y, train_weight = \
        inputs_final[:ind_split], outputs_final[:ind_split], weights[:ind_split]
    eval_x, eval_y, eval_weight = \
        inputs_final[ind_split:], outputs_final[ind_split:], weights[ind_split:]
    save_vars(file_data_processed, train_x, train_y, train_weight,
              eval_x, eval_y, eval_weight)
    return train_x, train_y, train_weight, eval_x, eval_y, eval_weight
def get_tabular(self, delta):
    """Return the cached tabular for `delta`, loading or creating it on demand.

    Lookup order (double-checked-locking style around a file lock):
      1. in-memory cache `self.deltas_dict`
      2. non-empty pickle at `save_path` (lock-free read)
      3. under a FileLocker: re-check the file (another process may have
         written it while we waited), otherwise create and persist it.

    Parameters:
        delta: KL bound; used (formatted to 16 decimals) as the cache key
               in both the file name and the lock name.
    Returns:
        The tabular object stored in `self.deltas_dict[delta]`.
    """
    save_path = f'{path_root_tabular}/{delta:.16f}_atari'
    if delta in self.deltas_dict:
        # Already cached in memory — nothing to do.
        pass  # TODO: file lock
    elif os.path.exists(save_path) and os.path.getsize(save_path) > 0:
        # A non-empty file exists: load without taking the lock.
        self.deltas_dict[delta] = tools.load_vars(save_path)
    else:
        # NOTE: the lock directory name has a typo ("tabluar") — it must match
        # the global `path_root_tabluar_locker` defined elsewhere in the project.
        with tools_process.FileLocker(
                f'{path_root_tabluar_locker}/{delta:.16f}'):
            # Re-check inside the lock: another process may have created the
            # file while we were waiting.
            if os.path.exists(
                    save_path) and os.path.getsize(save_path) > 0:
                self.deltas_dict[delta] = tools.load_vars(save_path)
            else:
                self.deltas_dict[delta] = self.create_tabular(delta)
                tools.save_vars(save_path, self.deltas_dict[delta])
    return self.deltas_dict[delta]
def learn(*, policy, env, nsteps, total_timesteps, ent_coef, lr,
          vf_coef=0.5, max_grad_norm=0.5, gamma=0.99, lam=0.95,
          log_interval=1, nminibatches=4, noptepochs=4, save_interval=10,
          load_path=None, clipped_type, args=None):
    """PPO-style training loop with a pluggable clipping scheme (incl. KL2Clip).

    Samples rollouts with `Runner`, optimizes the policy for `noptepochs`
    epochs of `nminibatches` minibatches per update, logs to the baselines
    logger plus a TF summary writer, and periodically checkpoints the model.

    Parameters (keyword-only):
        policy/env: policy constructor and vectorized environment.
        nsteps: rollout length per environment per update.
        total_timesteps: total environment steps to train for.
        ent_coef/vf_coef/max_grad_norm: loss coefficients and grad clipping.
        lr: learning rate — float (made constant) or schedule callable.
        gamma/lam: discount and GAE lambda.
        log_interval/save_interval: cadence (in updates) of logging/saving.
        nminibatches/noptepochs: minibatch count and optimization epochs.
        load_path: optional checkpoint to restore before training.
        clipped_type: name of a ClipType member selecting the clip scheme.
        args: namespace with cliprange / delta_kl / model_dir / log_dir etc.

    Returns:
        The trained model.
    """
    clipped_type = ClipType[clipped_type]
    print(f'Logger.CURRENT.dir is {logger.Logger.CURRENT.dir}')
    # Accept either a constant learning rate or a schedule f(frac)->lr.
    if isinstance(lr, float):
        lr = constfn(lr)
    else:
        assert callable(lr)
    total_timesteps = int(total_timesteps)
    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    nbatch = nenvs * nsteps
    nbatch_train = nbatch // nminibatches
    set_save_load_env(env)
    make_model = lambda: Model(policy=policy, ob_space=ob_space,
                               ac_space=ac_space, nbatch_act=nenvs,
                               nbatch_train=nbatch_train, nsteps=nsteps,
                               ent_coef=ent_coef, vf_coef=vf_coef,
                               max_grad_norm=max_grad_norm,
                               clipped_type=clipped_type, args=args)
    # Persist the model factory so checkpoints can be rebuilt later.
    if save_interval and args.model_dir:
        import cloudpickle
        with open(osp.join(args.model_dir, 'make_model.pkl'), 'wb') as fh:
            fh.write(cloudpickle.dumps(make_model))
    model = make_model()
    writer = model.create_summary_writer(args.log_dir)
    if load_path is not None:
        model.load(load_path)
    runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)
    # Rolling buffer of recent episode infos for reward/length means.
    epinfobuf = deque(maxlen=100)
    # epinfobuf = deque(maxlen=3)
    tfirststart = time.time()
    # assert len(ac_space.shape) == 1
    if clipped_type == ClipType.kl2clip:
        # Pick the continuous or discrete KL2Clip implementation by space type.
        if isinstance(env.action_space, gym.spaces.box.Box):
            from baselines.TRPPO.KL2Clip_reduce_v3.KL2Clip_reduce import KL2Clip
            kl2clip = KL2Clip(dim=ac_space.shape[0], batch_size=nsteps,
                              use_tabular=args.use_tabular)
        elif isinstance(env.action_space, gym.spaces.discrete.Discrete):
            from baselines.TRPPO.KL2Clip_discrete.KL2Clip_discrete import KL2Clip
            kl2clip = KL2Clip(dim=ac_space.n, batch_size=nsteps,
                              use_tabular=args.use_tabular)
        else:
            raise NotImplementedError('Please run atari or mujoco!')
        assert not (args.cliprange is None and args.delta_kl is None)
        # If only delta_kl was given, derive the equivalent cliprange from it.
        if args.cliprange is None and args.delta_kl is not None:
            args.cliprange = kl2clip.get_cliprange_by_delta(args.delta_kl)
            print('********************************')
            print(
                f'We set cliprange={args.cliprange} according to delta_kl={args.delta_kl}, dim={ac_space.shape[0]}'
            )
    cliprange = args.cliprange
    if isinstance(cliprange, float) or cliprange is None:
        cliprange = constfn(cliprange)
    else:
        assert callable(cliprange)
    nupdates = total_timesteps // nbatch
    eprewmean_max = -np.inf
    # Annealing schedule: alpha stays 1 for the first third of training, then
    # decays linearly to -0.5. NOTE(review): computed but not read below in
    # this view — presumably consumed elsewhere; confirm before removing.
    alphas_kl2clip_decay = np.zeros(nupdates, dtype=np.float32)
    alphas_kl2clip_decay[0:nupdates // 3] = 1
    alphas_kl2clip_decay[nupdates // 3:] = np.linspace(
        1, -0.5, nupdates - nupdates // 3)
    for update in range(1, nupdates + 1):
        assert nbatch % nminibatches == 0
        nbatch_train = nbatch // nminibatches
        tstart = time.time()
        # frac decays 1 -> ~0 over training; drives lr/cliprange schedules.
        frac = 1.0 - (update - 1.0) / nupdates
        lrnow = lr(frac)
        if isinstance(env.action_space, gym.spaces.Box):
            cliprangenow = cliprange(frac)
        elif isinstance(env.action_space, gym.spaces.Discrete):
            cliprangenow = (lambda _: cliprange(None) * _)(frac)  # anneal
        # Use the runner to sample data from the current model (old pi_theta).
        obs, returns, masks, actions, values, neglogpacs, policyflats, states, epinfos = runner.run(
        )  # pylint: disable=E0632
        # Advantage normalization (author addition, marked "xiaoming").
        advs = returns - values
        advs = (advs - advs.mean()) / (advs.std() + 1e-8)
        # For continuous control, keep only the newest episode infos.
        if isinstance(env.action_space, gym.spaces.box.Box):
            epinfobuf.clear()
        epinfobuf.extend(epinfos)
        mblossvals = []
        if clipped_type == ClipType.kl2clip:
            if isinstance(env.action_space, gym.spaces.Discrete):
                pas = np.exp(-neglogpacs)  # action probabilities
            else:
                pas = None
            # Per-sample adaptive clip ranges from the KL2Clip optimizer.
            ress = kl2clip(
                mu0_logsigma0_cat=policyflats,
                a=actions,
                pas=pas,
                delta=args.delta_kl,
                clipcontroltype=args.kl2clip_clipcontroltype,
                cliprange=cliprange,
            )
            cliprange_max = ress.ratio.max
            cliprange_min = ress.ratio.min
            save_vars(osp.join(args.model_dir, 'cliprange_max', f'{update}'),
                      cliprange_max)
            save_vars(osp.join(args.model_dir, 'cliprange_min', f'{update}'),
                      cliprange_min)
            if not args.model_dir.__contains__('Humanoid') and isinstance(
                    env.action_space, gym.spaces.box.Box):
                save_vars(osp.join(args.model_dir, 'actions', f'{update}'),
                          actions)
        elif clipped_type == ClipType.judgekl:
            pass
        # Dump diagnostics (skipped for Humanoid to save disk space, apparently).
        if not args.model_dir.__contains__('Humanoid') and isinstance(
                env.action_space, gym.spaces.box.Box):
            save_vars(osp.join(args.model_dir, 'mu0_logsigma0', f'{update}'),
                      policyflats)
            save_vars(osp.join(args.model_dir, 'advs', f'{update}'), advs)
        if states is None:  # nonrecurrent version
            inds = np.arange(nbatch)
            kls = []
            ratios = []
            for ind_epoch in range(noptepochs):
                np.random.shuffle(inds)
                for start in range(0, nbatch, nbatch_train):
                    end = start + nbatch_train
                    # mini-batch indexes
                    mbinds = inds[start:end]
                    slices = (arr[mbinds]
                              for arr in (obs, returns, masks, actions, values,
                                          neglogpacs, advs))
                    policyflats_batch, = (arr[mbinds]
                                          for arr in (policyflats, ))
                    if clipped_type == ClipType.kl2clip:
                        cliprange_min_batch, cliprange_max_batch = (
                            arr[mbinds]
                            for arr in (cliprange_min, cliprange_max))
                        if isinstance(env.action_space, gym.spaces.Discrete):
                            # Anneal the per-sample clip ranges toward 1.
                            cliprange_min_batch = 1 - (
                                1. - cliprange_min_batch) * frac  # anneal
                            cliprange_max_batch = 1 + (cliprange_max_batch -
                                                       1) * frac
                        *ress, kl, ratio = model.train(
                            lrnow,
                            *slices,
                            cliprange=cliprangenow,
                            cliprange_min=cliprange_min_batch,
                            cliprange_max=cliprange_max_batch,
                            policyflats=policyflats_batch)
                    else:
                        *ress, kl, ratio = model.train(
                            lrnow,
                            *slices,
                            cliprange=cliprangenow,
                            policyflats=policyflats_batch)
                    # Only keep KL/ratio diagnostics from the final epoch.
                    if ind_epoch == noptepochs - 1:
                        kls.append(kl)
                        ratios.append(ratio)
                    mblossvals.append(ress)
            # Undo the last shuffle so kls/ratios line up with sample order.
            inds2position = {}
            for position, ind in enumerate(inds):
                inds2position[ind] = position
            inds_reverse = [inds2position[ind] for ind in range(len(inds))]
            kls, ratios = (np.concatenate(arr, axis=0)[inds_reverse]
                           for arr in (kls, ratios))
            save_vars(osp.join(args.model_dir, 'kls, ratios', f'{update}'),
                      kls, ratios)
        else:  # recurrent version
            assert nenvs % nminibatches == 0
            envsperbatch = nenvs // nminibatches
            envinds = np.arange(nenvs)
            flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
            envsperbatch = nbatch_train // nsteps
            for _ in range(noptepochs):
                np.random.shuffle(envinds)
                for start in range(0, nenvs, envsperbatch):
                    end = start + envsperbatch
                    mbenvinds = envinds[start:end]
                    mbflatinds = flatinds[mbenvinds].ravel()
                    slices = (arr[mbflatinds]
                              for arr in (obs, returns, masks, actions, values,
                                          neglogpacs))
                    mbstates = states[mbenvinds]
                    # TODO: KL2Clip is not implemented for the recurrent path.
                    raise NotImplementedError
                    mblossvals.append(
                        model.train(lrnow, cliprangenow, *slices, mbstates))
        lossvals = np.mean(mblossvals, axis=0)
        tnow = time.time()
        fps = int(nbatch / (tnow - tstart))
        eprewmean_newmax = False
        if update % log_interval == 0 or update == 1:
            ev = explained_variance(values, returns)
            logger.logkv("serial_timesteps", update * nsteps)
            logger.logkv("nupdates", update)
            logger.logkv("total_timesteps", update * nbatch)
            logger.logkv("fps", fps)
            logger.logkv("explained_variance", float(ev))
            eprewmean = safemean([epinfo['r'] for epinfo in epinfobuf])
            logger.logkv('eprewmean', eprewmean)
            # Track the best mean episode reward seen so far.
            if eprewmean > eprewmean_max:
                eprewmean_newmax = True
                eprewmean_max = eprewmean
            logger.logkv('eplenmean',
                         safemean([epinfo['l'] for epinfo in epinfobuf]))
            logger.logkv('time_elapsed', tnow - tfirststart)
            for (lossval, lossname) in zip(lossvals, model.loss_names):
                logger.logkv(lossname, lossval)
            # Mirror all logger key/values into a TensorBoard summary.
            # print(logger.Logger.CURRENT.name2val)
            summary = tf.Summary()
            for k, v in logger.Logger.CURRENT.name2val.items():
                summary.value.add(tag=k, simple_value=v)
            timesteps = update * nbatch
            writer.add_summary(summary, global_step=timesteps)
            logger.dumpkvs()
        if save_interval and (update % save_interval == 0 or update == 1
                              or update == nupdates
                              # or eprewmean_newmax  # TODO: atari
                              ) and args.model_dir:
            checkdir = osp.join(args.model_dir, 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = osp.join(checkdir, '%.5i' % update)
            print('Saving to', savepath)
            model.save(savepath)
    env.close()
    return model
def load_data_normal(path_data, USE_MULTIPROCESSING=True):
    """Load (or build and cache) the fsolve-verified train/eval dataset.

    Reads pickled optimizer results under `path_data`, recomputes mu via the
    scipy-fsolve path (optionally in a 4-worker process pool), compares it to
    the TF-optimizer solution, and caches the shuffled, NaN-filtered split.

    Parameters:
        path_data: directory holding the raw '.pkl' result files.
        USE_MULTIPROCESSING: solve in a 4-process pool when True.
    Returns:
        (train_x, train_y, train_weight, eval_x, eval_y, eval_weight)
    """
    path_save = f'{path_data}/train_preprocessed_reduce_v3'
    # Fast path: reuse the cached, already-processed dataset.
    if os.path.exists(f'{path_save}/data'):
        print(f'load data from {path_save}/data')
        vs = load_vars(f'{path_save}/data')
        return vs
    tools.mkdir(f'{path_data}/train_preprocessed')
    files = tools.get_files(path_rel=path_data, only_sub=False, sort=False,
                            suffix='.pkl')
    actions, deltas, max_mu_logsigma, min_mu_logsigma = [], [], [], []
    # NOTE(review): only the FIRST file is processed (files[:1]) — looks like
    # leftover debugging; confirm before extending to the full file list.
    for ind, f in enumerate(files[:1]):
        a_s_batch, _, _, ress_tf = load_vars(f)
        actions.append(a_s_batch)
        deltas.append(np.ones_like(a_s_batch) * ress_tf.delta)
        min_mu_logsigma.append(ress_tf.x.min)
        max_mu_logsigma.append(ress_tf.x.max)
    actions = np.concatenate(actions, axis=0)
    deltas = np.concatenate(deltas, axis=0)
    min_mu_logsigma = np.concatenate(min_mu_logsigma, axis=0)
    max_mu_logsigma = np.concatenate(max_mu_logsigma, axis=0)
    # Keep only the mu half of each (mu, logsigma) concatenation.
    min_mu_tfopt, _ = np.split(min_mu_logsigma, indices_or_sections=2, axis=-1)
    max_mu_tfopt, _ = np.split(max_mu_logsigma, indices_or_sections=2, axis=-1)
    time0 = time.time()
    calculate_mu = get_calculate_mu_func(True)
    # TODO: below recomputes mu via fsolve (mu_logsigma_fsolve).
    if USE_MULTIPROCESSING:
        # Fix: close/join the worker pool deterministically — the original
        # created Pool(4) and never released it.
        with multiprocessing.Pool(4) as p:
            min_mu_fsolve = p.map(calculate_mu,
                                  zip(min_mu_tfopt, actions, deltas))
            max_mu_fsolve = p.map(calculate_mu,
                                  zip(max_mu_tfopt, actions, deltas))
    else:
        min_mu_fsolve = list(map(calculate_mu,
                                 zip(min_mu_tfopt, actions, deltas)))
        max_mu_fsolve = list(map(calculate_mu,
                                 zip(max_mu_tfopt, actions, deltas)))
    # Each fsolve result is a tuple; keep only the solution array.
    min_mu_fsolve = [_[0] for _ in min_mu_fsolve]
    max_mu_fsolve = [_[0] for _ in max_mu_fsolve]
    time1 = time.time()
    print(time1 - time0)
    # Side-by-side comparison of TF-optimizer vs fsolve solutions.
    mu_tf_opt = np.concatenate((min_mu_tfopt, max_mu_tfopt), axis=1)
    mu_fsolve = np.stack(
        (np.concatenate(min_mu_fsolve, axis=0).squeeze(),
         np.concatenate(max_mu_fsolve, axis=0).squeeze()), axis=1)
    print(mu_tf_opt - mu_fsolve)
    # Shuffle, drop NaN rows, and split columns into inputs/outputs:
    # inputs = (actions, deltas), outputs = (mu_min, mu_max).
    inds_shuffle = np.random.permutation(actions.shape[0])
    all_ = np.concatenate((actions, deltas, mu_fsolve), axis=1)[inds_shuffle]
    all_ = all_[~np.isnan(all_).any(axis=1)]
    inputs_all, outputs_all = np.split(all_, indices_or_sections=2, axis=1)
    weights = np.ones(shape=(inputs_all.shape[0], ))
    print(outputs_all.shape)
    # The last 3000 (shuffled) samples are held out for evaluation.
    ind_split = -3000
    train_x, train_y, train_weight = \
        inputs_all[:ind_split], outputs_all[:ind_split], weights[:ind_split]
    eval_x, eval_y, eval_weight = \
        inputs_all[ind_split:], outputs_all[ind_split:], weights[ind_split:]
    save_vars(f'{path_save}/data', train_x, train_y, train_weight,
              eval_x, eval_y, eval_weight)
    return train_x, train_y, train_weight, eval_x, eval_y, eval_weight,
def tes_3d_data():
    """Compare TF-optimizer vs scipy-fsolve clip ratios and scatter them in 3D.

    Walks the training-data directories, re-solves each batch with the scipy
    KL2Clip path, caches the paired results in 'aa.pkl', then plots both
    ratio surfaces over (action, delta).
    """
    if tools.ispc('xiaoming'):
        path_root = '/media/root/新加卷/KL2Clip'
    else:
        path_root = ''
    import plt_tools
    from baselines.common.tools import load_vars, save_vars
    import matplotlib.pyplot as plt
    if 1:
        dim = 1
        # tf.logging.set_verbosity(tf.logging.INFO)
        files = []
        path_data = f'{path_root}/data/train'
        for dir in sorted(os.listdir(path_data)):
            dir_pickle = os.path.join(path_data, dir)
            # Fix: the original bare `except:` also swallowed KeyboardInterrupt
            # and SystemExit; catch only the failures this lookup can raise
            # (non-directory entries -> OSError, too few entries -> IndexError).
            try:
                file_path = os.listdir(dir_pickle)[0] if os.listdir(dir_pickle)[0].endswith('pkl') else \
                    os.listdir(dir_pickle)[1]
            except (OSError, IndexError):
                continue
            files.append(os.path.join(dir_pickle, file_path))
        tfoptsssss = []
        scipyfsolvesssss = []
        a_delta = []
        for ind, f in enumerate(files):
            print(f)
            actions, _, _, ress_tf = load_vars(f)
            delta = ress_tf.delta
            ratio_min_tfopt, ratio_max_tfopt = ress_tf.ratio.min, ress_tf.ratio.max
            kl2clip = KL2Clip(dim=dim)
            # Standard-normal reference policy: mu0 = logsigma0 = 0.
            x0 = np.zeros(shape=(actions.shape[0], 2), dtype=np.float32)
            # Sort everything by action so the curves render cleanly.
            inds = np.argsort(actions, axis=0)
            inds = inds.reshape(-1)
            actions = actions[inds]
            ratio_min_tfopt, ratio_max_tfopt = ratio_min_tfopt[
                inds], ratio_max_tfopt[inds]
            ress = kl2clip(mu0_logsigma0_cat=x0, a=actions, delta=delta)
            ratio_min_scipyfsolve, ratio_max_scipyfsolve = ress.ratio.min, ress.ratio.max
            a_delta.append(
                np.concatenate((actions, delta * np.ones_like(actions)),
                               axis=1))
            tfoptsssss.append(ratio_max_tfopt)
            scipyfsolvesssss.append(ratio_max_scipyfsolve)
        save_vars('aa.pkl', a_delta, tfoptsssss, scipyfsolvesssss)
    a_delta, tfoptsssss, scipyfsolvesssss = load_vars('aa.pkl')

    # Renamed from `filter` to avoid shadowing the builtin.
    def subsample(arr):
        # Keep every 30th point per file so the scatter stays readable.
        for ind in range(len(arr)):
            arr[ind] = arr[ind][0::30]
        return arr

    a_delta, tfoptsssss, scipyfsolvesssss = [
        subsample(item) for item in (a_delta, tfoptsssss, scipyfsolvesssss)
    ]
    a_delta = np.concatenate(a_delta, axis=0)
    tfoptsssss = np.concatenate(tfoptsssss, axis=0)
    scipyfsolvesssss = np.concatenate(scipyfsolvesssss, axis=0)
    fig = plt.figure()
    # Fix: Figure.gca(projection=...) was removed in matplotlib 3.6.
    ax = fig.add_subplot(projection='3d')
    ax.view_init(0, 0)
    # Fix: the original passed '_tfopt'/'_scipyfsolve' positionally, which
    # Axes3D.scatter interprets as the (invalid) `zdir` argument; pass them
    # as labels instead (the leading '_' keeps them out of any legend).
    ax.scatter(a_delta[:, 0], a_delta[:, 1], tfoptsssss,
               label='_tfopt', s=1, color='black')
    ax.scatter(a_delta[:, 0], a_delta[:, 1], scipyfsolvesssss,
               label='_scipyfsolve', s=1, color='red')
    plt_tools.set_postion()
    plt_tools.set_size()
    # plt_tools.set_equal()
    plt.show()
def prepare_data(dim, delta, sharelogsigma, clipcontroltype, cliprange,
                 clip_clipratio, search_delta=False):
    """Generate (and cache) KL2Clip training data for one (dim, delta) pair,
    then visualize the resulting max/min clip ratios.

    For each logsigma0 setting, runs the TF optimizer `opt` over a grid of
    actions, pickles the results under the per-configuration directory, and
    (unless `search_delta`) plots good/bad constraint-satisfaction points.

    Parameters:
        dim: action dimensionality (only dim == 1 is implemented).
        delta: KL bound fed to the optimizer.
        sharelogsigma / clipcontroltype / cliprange / clip_clipratio:
            forwarded to the KL2Clip optimizer.
        search_delta: when True, collect the raw results and stop after the
            first logsigma0 instead of plotting.
    """
    global ress_tf_last
    path_data = path_root + '/KL2Clip/data/train_lambda'
    Name = f'dim={dim}, delta={delta}, train'
    path_data_processed = path_data + f'/{Name}'
    tools.mkdir(path_data_processed)
    if dim == 1:
        logsigma0s = np.array([0])
    else:
        raise NotImplementedError
    logsigma0s = logsigma0s.reshape((-1, dim))
    batch_size = 2048
    mu = np.zeros((dim, ))
    opt = KL2Clip(dim=dim, batch_size=batch_size,
                  sharelogsigma=sharelogsigma,
                  clipcontroltype=clipcontroltype, cliprange=cliprange)

    def get_fn_sample():
        # Build TF graph functions: a sampler for the diagonal Gaussian and a
        # density evaluator p(a | mu0, logsigma0).
        mu0 = tf.placeholder(shape=[dim], dtype=tf.float32)
        a = tf.placeholder(shape=[batch_size, dim], dtype=tf.float32)
        logsigma0 = tf.placeholder(shape=[dim], dtype=tf.float32)
        sample_size = tf.placeholder(shape=(), dtype=tf.int32)
        dist = DiagGaussianPd(tf.concat((mu0, logsigma0), axis=0))
        samples = dist.sample(sample_size)
        fn_sample = U.function([mu0, logsigma0, sample_size], samples)
        fn_p = U.function([mu0, logsigma0, a], dist.p(a))
        return fn_sample, fn_p

    sess = U.make_session(make_default=True)
    results = []
    fn_sample, fn_p = get_fn_sample()
    for logsigma0 in logsigma0s:
        prefix_save = f'{path_data_processed}/logsigma0={logsigma0}'
        Name_f = f"{Name},logsigma0={logsigma0}"
        file_fig = f'{prefix_save}.png'
        # a_s_batch = fn_sample( mu, logsigma0, batch_size )
        # Deterministic action grid instead of sampling (see commented line).
        a_s_batch = np.linspace(-5, 5, batch_size).reshape((-1, 1))
        logsigma0s_batch = np.tile(logsigma0, (batch_size, 1))
        print(a_s_batch.max(), a_s_batch.min())
        # --- sort the data: have problem in 2-dim
        # inds = np.argsort(a_s_batch, axis=0)
        # inds = inds.reshape(-1)
        # a_s_batch = a_s_batch[inds]
        # logsigma0s_batch = logsigma0s_batch[inds]
        # tools.reset_time()
        # a_s_batch.fill(0)
        # print(a_s_batch.shape)
        # a_s_batch[0, :]=0
        # if search_delta:
        #     for i in range( batch_size):
        #         a_s_batch[i,:] = 0.001 * (batch_size-i)
        # Run the optimizer only if this configuration is not cached yet.
        if not os.path.exists(f'{prefix_save}.pkl'):
            # ress_tf = opt( mu0_logsigma0_tuple=(a_s_batch, logsigma0s_batch), a=None, delta=delta, clip_clipratio=clip_clipratio)
            ress_tf = opt(mu0_logsigma0_tuple=(np.zeros_like(logsigma0s_batch),
                                               logsigma0s_batch),
                          a=a_s_batch, delta=delta,
                          clip_clipratio=clip_clipratio)
            print(a_s_batch[0], ress_tf.x.max[0], ress_tf.x.min[0])
            save_vars(f'{prefix_save}.pkl', a_s_batch, logsigma0,
                      logsigma0s_batch, ress_tf)
            print(prefix_save)
        # Always reload from the cache so both paths use identical data.
        a_s_batch, logsigma0, logsigma0s_batch, ress_tf = load_vars(
            f'{prefix_save}.pkl')
        if search_delta:
            results.append(ress_tf)
            break
        if cliprange == clipranges[0]:  # TODO tmp
            fig = plt.figure(figsize=(20, 10))
        markers = ['^', '.']
        colors = [['blue', 'red'], ['green', 'hotpink']]
        # for ind, opt_name in enumerate(['max']):
        for ind, opt_name in enumerate(['max', 'min']):
            # if ind == 1:
            #     continue
            # --- plot the tensorflow result for this bound (max or min).
            ratios, cons = ress_tf.ratio[opt_name], ress_tf.con[opt_name]
            print(
                f'clip-{opt_name}_mean:{ratios.mean()}, clip-{opt_name}_min:{ratios.min()}, clip-{opt_name}_max:{ratios.max()}'
            )
            if search_delta:
                continue
            if DEBUG:
                pass
            # Split points by whether the KL constraint is satisfied.
            inds_good = cons <= get_ConstraintThreshold(ress_tf.delta)
            inds_bad = np.logical_not(inds_good)
            if dim == 1:
                if ind == 0 and 1:
                    # Experimental alternative ratio curve derived from the
                    # density p(a); plotting lines are commented out.
                    ps = fn_p(mu, logsigma0, a_s_batch)
                    # +np.abs(ps.max()) + 1
                    ratio_new = -np.log(ps)
                    ratio_new = ratio_new - ratio_new.min() + ratios.min()
                    alpha = np.exp(-ps * 2)
                    print(alpha)
                    # plt.scatter(a_s_batch, ratio_new, s=5, label='ratio_new0')
                    ratio_new = ratio_new.min() + alpha * (ratio_new -
                                                           ratio_new.min())
                    # plt.scatter( a_s_batch, ratio_new, s=5, label='ratio_new1' )
                    # ps = -ps
                    # ratios = ps - ps.min() + ratios.min()
                    # print( ps.min() )
                    # ratios_new =np.square( a_s_batch-mu ) * np.exp( -logsigma0 )
                    # ratio_min = ps / (ps.max()-ps.min()) * ress_tf.ratio.min.max()
                    # plt.scatter(a_s_batch, ratio_min, s=5, label='square')
                    # plt.scatter(a_s_batch, 1./ratio_min, s=5, label='square')
                    # plt.scatter(a_s_batch, 1./ratios, s=5, label='1/max')

                    def plot_new(alpha):
                        # Helper for the (currently disabled) annealed-clip plots.
                        clip_max_new, clip_min_new = get_clip_new(
                            alpha, ress_tf.ratio['max'], ress_tf.ratio['min'],
                            clipcontroltype=clipcontroltype)
                        plt.scatter(a_s_batch, clip_max_new, s=5,
                                    label=f'clip_max_{alpha}')
                        plt.scatter(a_s_batch, clip_min_new, s=5,
                                    label=f'clip_min_{alpha}')

                    if ind == 0:
                        pass
                        # plot_new(0.5)
                        # plot_new(0.5)
                        # plot_new(-1)
                plt.scatter(a_s_batch[inds_good], ratios[inds_good],
                            label='ratio_predict-good_' + opt_name, s=5,
                            color=colors[ind][0], marker=markers[ind])
                plt.scatter(a_s_batch[inds_bad], ratios[inds_bad],
                            label='ratio_predict-bad_' + opt_name, s=5,
                            color=colors[ind][1], marker=markers[ind])
            elif dim == 2:
                ax = fig.gca(projection='3d')
                # ax.view_init(30, 30)
                ax.view_init(90, 90)
                # ax.plot_trisurf(a_s_batch[:, 0], a_s_batch[:, 1], ratios)
                ax.scatter(a_s_batch[inds_good, 0], a_s_batch[inds_good, 1],
                           ratios[inds_good],
                           label='ratio_predict-good_' + opt_name, s=5,
                           color=colors[ind][0], marker=markers[ind])
                ax.scatter(a_s_batch[inds_bad, 0], a_s_batch[inds_bad, 1],
                           ratios[inds_bad],
                           label='ratio_predict-bad_' + opt_name, s=5,
                           color=colors[ind][1], marker=markers[ind])
        if dim <= 2 and not search_delta:
            plt.title(
                Name_f +
                f'\nstep:{ress_tf.step},rate_satisfycon:{ress_tf.rate_satisfycon_}, rate_statisfydifference_:{ress_tf.rate_statisfydifference_}, difference_max_:{ress_tf.difference_max_}'
            )
            plt.legend(loc='best')
            if not DEBUG:
                plt.savefig(file_fig)
    opt.close()
    if dim <= 2 and not search_delta:
        if DEBUG:
            if cliprange == clipranges[-1]:
                plt_tools.set_postion()
                plt.show()
        plt.close()