def generate_oae_action(known_transisitons):
    """Decode every state under every action label seen in the known transitions.

    The known transitions are encoded with ``oae.encode_action``; each label
    whose rounded activation count is positive becomes one candidate action
    (a zero array with a single 1 at ``[i, 0, label]``).  Every candidate is
    then paired with every state and decoded with ``oae.decode``.

    Args:
        known_transisitons: 2-D array of transitions.  Each row is treated as
            two halves of width ``shape[1] // 2``; the array is reshaped so
            that every half becomes one state row.

    Returns:
        The decoded transitions as int8, after ``set_difference`` against
        ``known_transisitons`` and an in-place ``random.shuffle``.
    """
    print("listing actions")
    encoded = oae.encode_action(known_transisitons, batch_size=1000)
    actions = encoded.round()
    histogram = np.squeeze(actions.sum(axis=0, dtype=int))
    used_labels = np.where(histogram > 0)[0]

    available_actions = np.zeros(
        (np.count_nonzero(histogram), actions.shape[1], actions.shape[2]),
        dtype=int)
    # Row i of the candidate set gets a 1 at the i-th used label.
    available_actions[np.arange(len(used_labels)), 0, used_labels] = 1

    N = known_transisitons.shape[1] // 2
    states = known_transisitons.reshape(-1, N)

    print("start generating transitions")
    # s1,s2,s3,s1,s2,s3,....
    tiled_states = repeat_over(states, len(available_actions), axis=0)
    # a1,a1,a1,a2,a2,a2,....
    repeated_actions = np.repeat(available_actions, len(states), axis=0)
    decoded = oae.decode([tiled_states, repeated_actions], batch_size=1000)
    y = decoded.round().astype(np.int8)

    print("remove known transitions")
    y = set_difference(y, known_transisitons)
    print("shuffling")
    random.shuffle(y)
    return y
def type2(invalid, message): nonlocal c c += 1 invalid = set_difference(invalid, valid) print("invalid", c, invalid.shape, "---", message) discriminator.report(invalid, train_data_to=np.zeros((len(invalid), ))) print( "type2 error:", np.mean( np.round(discriminator.discriminate(invalid, batch_size=1000))) * 100, "%") p = latplan.util.puzzle_module(sae.local("")) count = 0 batch = 10000 for i in range(len(invalid) // batch): pre_images = sae.decode_binary(invalid[batch * i:batch * (i + 1), :N], batch_size=1000) suc_images = sae.decode_binary(invalid[batch * i:batch * (i + 1), N:], batch_size=1000) validation = p.validate_transitions([pre_images, suc_images]) count += np.count_nonzero(validation) print(count, "valid actions in invalid", c)
def generate_nop(data):
    """Build self-loop transitions (a state paired with itself) absent from ``data``.

    Args:
        data: 2-D array of transitions; each row is split at
            ``shape[1] // 2`` into a first and a second half.

    Returns:
        ``set_difference`` of the self-paired rows against ``data``.
    """
    half = data.shape[1] // 2
    # Stack first halves on top of second halves: one row per state.
    every_state = np.concatenate((data[:, :half], data[:, half:]), axis=0)
    self_loops = np.concatenate((every_state, every_state), axis=1)
    return set_difference(self_loops, data)
def permute_suc(data):
    """Pair each first half of ``data`` with a shuffled second half.

    Args:
        data: 2-D array of transitions; each row is split at
            ``shape[1] // 2`` into a first and a second half.

    Returns:
        ``set_difference`` of the re-paired rows against ``data``.
    """
    half = data.shape[1] // 2
    predecessors = data[:, :half]
    # Shuffle a copy so the caller's array is left untouched.
    shuffled_successors = np.copy(data[:, half:])
    random.shuffle(shuffled_successors)
    candidates = np.concatenate((predecessors, shuffled_successors), axis=1)
    return set_difference(candidates, data)
def generate_random_action(data, sae):
    """Pair known states with randomly generated successor states.

    Both halves of every row of ``data`` are stacked into one list of states.
    Successor candidates come from two calls to ``generate_random``; the
    resulting pairs that already occur in ``data`` are removed.

    Args:
        data: 2-D array of transitions; each row is split at
            ``shape[1] // 2`` into a first and a second half.
        sae: Passed through unchanged to ``generate_random``.

    Returns:
        ``set_difference`` of the generated pairs against ``data``.  The
        number of rows is at most the number of stacked states and may be
        smaller when fewer random states are produced.
    """
    from state_discriminator3 import generate_random

    dim = data.shape[1] // 2
    # reconstructable, maybe invalid
    pre = np.concatenate((data[:, :dim], data[:, dim:]), axis=0)
    suc = np.concatenate(
        (generate_random(pre, sae), generate_random(pre, sae)),
        axis=0)[:len(pre)]
    # The slice above only caps `suc` from above.  The `generate_random`
    # visible in this source prunes rows, so (assuming the imported one
    # behaves the same) `suc` can be shorter than `pre`; trim `pre` to match,
    # otherwise the axis=1 concatenation raises on mismatched row counts.
    pre = pre[:len(suc)]
    actions_invalid = np.concatenate((pre, suc), axis=1)
    return set_difference(actions_invalid, data)
def prepare_oae_per_action_PU3(known_transisitons):
    """Yield one binary-classification dataset per action label.

    Loads the action autoencoder and discriminators through
    ``default_networks``, encodes the known transitions into action labels,
    and then, for each label index:

    * yields ``None`` when no known transition has that label;
    * otherwise decodes every state under that label, keeps the results whose
      second half scores above 0.5 with the combined discriminator, removes
      the known transitions for that label, and yields the 4-tuple returned
      by ``prepare_binary_classification_data``.

    Args:
        known_transisitons: 2-D array of transitions; each row is treated as
            two halves of width ``shape[1] // 2``.

    Yields:
        ``None`` or ``(train_in, train_out, test_in, test_out)``, one item per
        label index in ``range(actions.shape[2])``.

    Raises:
        AssertionError: If the encoded actions have fewer than two labels.
    """
    print("", sep="\n")
    N = known_transisitons.shape[1] // 2
    states = known_transisitons.reshape(-1, N)

    # NOTE(review): both `ae` and `sae` are read from an outer scope here;
    # confirm that mixing the two is intended.
    oae = default_networks['ActionAE'](ae.local("_aae/")).load()
    actions = oae.encode_action(
        known_transisitons, batch_size=1000).round().astype(int)
    L = actions.shape[2]
    assert L > 1
    histogram = np.squeeze(actions.sum(axis=0))
    print(histogram)

    sd3 = default_networks['PUDiscriminator'](ae.local("_sd3/")).load()
    try:
        cae = default_networks['SimpleCAE'](sae.local("_cae/")).load()
        combined_discriminator = default_networks['CombinedDiscriminator'](
            ae, cae, sd3)
    except Exception:
        # Best-effort fallback when the CAE cannot be loaded or combined.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate instead of being mistaken for a missing CAE.
        combined_discriminator = default_networks['CombinedDiscriminator2'](
            ae, sd3)

    for label in range(L):
        print("label", label)
        rows_with_label = np.where(actions[:, :, label] > 0.5)[0]
        known_transisitons_for_this_label = known_transisitons[rows_with_label]
        if len(known_transisitons_for_this_label) == 0:
            yield None
            continue

        # Apply this single label to every state.
        _actions = np.zeros(
            (len(states), actions.shape[1], actions.shape[2]), dtype=int)
        _actions[:, :, label] = 1
        y = oae.decode([states, _actions],
                       batch_size=1000).round().astype(np.int8)

        # prune invalid states
        ind = np.where(
            np.squeeze(combined_discriminator(y[:, N:], batch_size=1000)) > 0.5)[0]
        y = y[ind]
        y = set_difference(y, known_transisitons_for_this_label)
        print(y.shape, known_transisitons_for_this_label.shape)

        train_in, train_out, test_in, test_out = prepare_binary_classification_data(
            known_transisitons_for_this_label, y)
        yield (train_in, train_out, test_in, test_out)
def generate_random(data, sae, batch=None):
    """Generate random binary rows that survive an ``sae`` decode/encode round trip.

    Random 0/1 rows are repeatedly passed through ``sae.decode`` followed by
    ``sae.encode`` until the ``bce`` loss between consecutive iterates stops
    improving (ratio above 0.99), drops to machine epsilon, or 50 iterations
    have run.  Rows whose ``bce(..., (1, ))`` value after one more round trip
    is not below epsilon are discarded, and rows also present in ``data`` are
    removed.

    Args:
        data: 2-D array; supplies the row width and, when ``batch`` is None,
            the number of rows to sample.
        sae: Object providing ``decode`` and ``encode``.
        batch: Number of random rows to start from; defaults to
            ``data.shape[0]``.

    Returns:
        ``set_difference`` of the surviving rounded rows against
        ``data.round()``.
    """
    import sys

    threshold = sys.float_info.epsilon
    rate_threshold = 0.99
    max_repeat = 50

    def roundtrip(codes):
        # decode to images, then encode back
        return sae.encode(sae.decode(codes, batch_size=2000), batch_size=2000)

    if batch is None:
        batch = data.shape[0]
    N = data.shape[1]
    candidates = random.randint(0, 2, (batch, N), dtype=np.int8)

    # Iterate the round trip until the loss saturates or is small enough.
    loss = 1000000000
    for _ in range(max_repeat):
        reconstructed = roundtrip(candidates)
        prev_loss, loss = loss, bce(candidates, reconstructed)
        ratio = loss / prev_loss
        if len(candidates) > 3000:
            print(loss, ratio)
        candidates = reconstructed
        if ratio > rate_threshold:
            if len(candidates) > 3000:
                print("improvement saturated: loss / prev_loss = ", ratio,
                      ">", rate_threshold)
            break
        if loss <= threshold:
            print("good amount of loss:", loss, "<", threshold)
            break
    candidates = candidates.round().astype(np.int8)

    # Keep only rows selected by the second bce call, which passes (1, ) as
    # its third argument (presumably a per-row loss -- confirm against bce).
    row_loss = bce(candidates, roundtrip(candidates), (1, ))
    candidates = candidates[np.where(row_loss < threshold)[0]]

    return set_difference(candidates.round(), data.round())
def generate_aae_action(known_transisitons):
    """Decode every state under one randomly chosen action label.

    Each state row gets an entry of ``all_labels`` picked with
    ``np.random.choice`` and the pairs are decoded with ``aae.decode``.
    ``all_labels`` and ``aae`` are read from an outer scope.

    Args:
        known_transisitons: 2-D array of transitions.  Each row is treated as
            two halves of width ``shape[1] // 2``; the array is reshaped so
            that every half becomes one state row.

    Returns:
        The decoded transitions as int8, after ``set_difference`` against
        ``known_transisitons`` and an in-place ``random.shuffle``.
    """
    # The nested `repeat_over` helper that used to be defined here was never
    # called in this function, so it has been dropped.
    N = known_transisitons.shape[1] // 2
    states = known_transisitons.reshape(-1, N)

    print("start generating transitions")
    chosen = np.random.choice(len(all_labels), len(states))
    random_actions = all_labels[chosen]
    y = aae.decode([states, random_actions],
                   batch_size=1000).round().astype(np.int8)

    print("remove known transitions")
    y = set_difference(y, known_transisitons)
    print("shuffling")
    random.shuffle(y)
    return y