def get_store_neuron(token, T_list):
    """Select "store" neurons for ``token`` at each decoder position in ``T_list``.

    For every position ``T``, two sample sets are drawn from
    ``seq2seq.decoder_in_test``:
    - one matching ``token`` at position ``T``;
    - one with ``except_this_token=True``, presumably the complement.

    ``sample_getter.get_different_amount_sample`` combines them. Then, for each
    time step ``t`` in ``range(T + 2)``, the hidden states at ``t`` are labelled
    1 (first group) / 0 (second group). They are passed to
    ``call_classifier.call_recursive_rfe``, whose output is recorded.

    Positions where the combined sample index is ``None``, or has
    ``shape[1] < 5``, are skipped and do not appear in the result.

    Returns:
        dict mapping ``T`` -> dict mapping ``t`` -> features returned by
        ``call_recursive_rfe``.
    """
    print("\tGet Store")
    store = {}
    for T in T_list:
        print("=" * 50 + "\nT = %d" % T)
        with_token = sample_getter.get_sample_by_one_condition(
            seq2seq.decoder_in_test, token=token, position=T, N=1000)
        without_token = sample_getter.get_sample_by_one_condition(
            seq2seq.decoder_in_test, token=token, position=T, N=1000,
            except_this_token=True)
        sample_index = sample_getter.get_different_amount_sample(
            [with_token, without_token])

        too_few = sample_index is None or sample_index.shape[1] < 5
        if too_few:
            print("\tToo less samples in this condition.")
            continue  # nothing usable for this position; try the next one

        store[T] = per_step = {}
        hidden = state_getter.get_hidden_state(seq2seq, sample_index)
        for t in range(T + 2):
            print("=" * 50 + "\nt = %d" % t)
            # assumes hidden is (group, sample, time, units) -- TODO confirm
            step_state = hidden[:, :, t, :]
            n_per_group = step_state.shape[1]
            # First n_per_group rows get label 1, the next n_per_group get 0;
            # this matches the row-major order produced by the reshape below.
            labels = np.repeat(np.array([1, 0], dtype=int), n_per_group)
            inputs = np.reshape(step_state, [-1, seq2seq.units])
            inputs, labels = shuffle(inputs, labels, random_state=42)
            chosen = call_classifier.call_recursive_rfe(
                inputs, labels, max_count=target_units, one_threshold=0.5)
            per_step[t] = chosen
            print("features =", chosen)
    return store
def enable_important_neuron_one_token(token):
    """Evaluate the decoder with the important neurons for ``token`` ablated.

    Loads ``neuron_token=<token>.pickle`` from ``saved_path`` and reads its
    ``'store'``, ``'counter'`` and ``'ig'`` selections. For every position
    ``T`` in the IG selection, 100 samples are drawn and the unmodified model
    is checked first with ``verify_original_model``.

    For each step ``t``, the important set is:
    - the IG neurons at ``(T, t)`` that are also store neurons at ``(T, t)``,
      united with
    - the IG neurons at ``(T, t)`` that are also counter neurons at ``T``.

    The decoder is run through ``verification.verify_decoder`` twice, once in
    ``"enable"`` mode and once in ``"disable"`` mode, replacing values by zero.
    Each prediction is scored with ``evaluator.evaluate_token`` and
    ``evaluator.evaluate_autoencoder_at_time``. Results are printed only;
    nothing is returned.
    """
    pickle_path = os.path.join(saved_path, 'neuron_token=%d.pickle' % token)
    with open(pickle_path, 'rb') as handle:
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # files produced by this project.
        loaded = pickle.load(handle)
    store_neuron = loaded['store']
    counter_neuron = loaded['counter']
    ig_neuron = loaded['ig']

    for T, ig_by_step in ig_neuron.items():
        si1 = sample_getter.get_sample_by_one_condition(
            seq2seq.decoder_in_test, token=token, position=T, N=100)
        sample = seq2seq.encoder_in_test[si1]
        real = evaluator.get_evaluate_real(seq2seq, si1)
        verify_original_model(seq2seq, sample, real, token, T)

        for t, ig_selected in ig_by_step.items():
            important_store = get_intersection(store_neuron[T][t], ig_selected)
            important_counter = get_intersection(counter_neuron[T], ig_selected)
            important_neuron = list(set(important_store) | set(important_counter))
            print("T=%d, t=%d, (%d, %d, %d)" % (
                T, t, len(important_neuron),
                len(important_store), len(important_counter)))

            # Same evaluation for both ablation modes, enable first.
            for mode in ("enable", "disable"):
                print("\t", end="")
                pred = verification.verify_decoder(
                    seq2seq, sample, important_neuron, time_step=t,
                    mode=mode, replace_by="zero", verbose=2)
                evaluator.evaluate_token(pred, token, T + 3)
                print("\t\t\t", end="")
                evaluator.evaluate_autoencoder_at_time(real, pred, time_step=T, verbose=2)
def verify_store_one_token(token):
    """Run ``verify_store_one_step`` for every (T, t) in the saved IG selection.

    Loads ``neuron_token=<token>.pickle`` from ``saved_path`` and reads its
    ``'store'`` and ``'ig'`` entries. For each position ``T`` in the IG
    selection:
    - 100 samples are drawn;
    - the unmodified model is checked with ``verify_original_model``;
    - the store and IG neuron sets for each step ``t`` are handed to
      ``verify_store_one_step``.
    """
    path = os.path.join(saved_path, 'neuron_token=%d.pickle' % token)
    with open(path, 'rb') as handle:
        # NOTE(review): pickle.load can execute arbitrary code; trusted files only.
        saved = pickle.load(handle)
    store_neuron = saved['store']
    ig_neuron = saved['ig']

    for T, ig_by_step in ig_neuron.items():
        print("-" * 50)
        si1 = sample_getter.get_sample_by_one_condition(
            seq2seq.decoder_in_test, token=token, position=T, N=100)
        sample = seq2seq.encoder_in_test[si1]
        real = evaluator.get_evaluate_real(seq2seq, si1)
        verify_original_model(seq2seq, sample, real, token, T)
        for t, ig_selected in ig_by_step.items():
            verify_store_one_step(T, t, store_neuron[T][t], ig_selected,
                                  seq2seq, sample, real)
def get_ig_neuron(token, T_list):
    """Select important neurons for ``token`` by integrated gradients (IG).

    For every position ``T`` in ``T_list``:
    - up to 1000 samples matching ``token`` at ``T`` are drawn;
    - their decoder states and inputs are collected via
      ``call_ig.get_state_by_sample_index``.

    For each input step ``t`` in ``range(T + 1)``, a decoder model mapping step
    ``t`` to output step ``T`` is built. Its IG scores for ``target_class=token``
    are computed, and ``call_ig.get_important_neurons_by_IG(..., k=target_units)``
    picks the neurons.

    Returns:
        dict mapping ``T`` -> dict mapping ``t`` -> selected neurons.
    """
    print("\tGet IG")
    ig_result = {}
    for T in T_list:
        print("=" * 50 + "\nT = %d" % T)
        sample_index = sample_getter.get_sample_by_one_condition(
            seq2seq.decoder_in_test, token=token, position=T, N=1000)
        per_step = {}
        ig_result[T] = per_step
        decoder_states, decoder_inputs = call_ig.get_state_by_sample_index(seq2seq, sample_index)
        for t in range(T + 1):
            print("=" * 50 + "\nt = %d" % t)
            model_t = call_ig.get_model_without_argmax(seq2seq, input_t=t, output_t=T)
            scores = call_ig.compute_ig_steps(model_t, decoder_states[t], decoder_inputs,
                                              target_class=token)
            chosen = call_ig.get_important_neurons_by_IG(scores, k=target_units)
            per_step[t] = chosen
            print("\tselected =", chosen)
    return ig_result
def test_get_stable_neuron(token=3, T=3):
    """Print stable neurons between consecutive steps of ``result3[T]``.

    Manual check, written for token = 3, T = 3. It draws 10 samples matching
    ``token`` at position ``T`` and computes their hidden states. Then it walks
    the keys ``t1`` of the module-level ``result3[T]`` and, while ``t1 + 1`` is
    also a key, calls ``get_stable_neuron`` at several thresholds. The walk
    stops at the first ``t1`` whose successor is missing.
    """
    import sample_getter
    import state_getter

    si1 = sample_getter.get_sample_by_one_condition(
        seq2seq.decoder_in_test, token=token, position=T, N=10)
    sample_index = sample_getter.get_same_amount_sample([si1])
    state = state_getter.get_hidden_state(seq2seq, sample_index)

    steps = result3[T]
    for t1 in steps:
        t2 = t1 + 1
        if t2 not in steps:
            break
        for threshold in (0.0, 0.01, 0.05, 0.1, 0.5):
            get_stable_neuron(state, steps, t1, t2, threshold=threshold)
        print("-" * 10)
def get_counter_neuron(token, T_list):
    """Select "counter" neurons for ``token`` at each decoder position in ``T_list``.

    For each position ``T``:
    - samples matching ``token`` at position ``T`` are drawn;
    - the hidden states of the first ``T`` time steps are each labelled with
      their time-step index ``0 .. T-1``;
    - ``call_classifier.call_recursive_rfe`` is trained to recover that index
      from the hidden state, and its output is recorded for ``T``.

    Positions whose sample index is ``None`` or has ``shape[1] < 5`` are
    skipped with a message and do not appear in the result. All other
    positions are still processed.

    Returns:
        dict mapping ``T`` -> features returned by ``call_recursive_rfe``.
    """
    print("\tGet Counter")
    result = {}
    for T in T_list:
        print("=" * 50 + "\nT = %d" % T)
        si = sample_getter.get_sample_by_one_condition(seq2seq.decoder_in_test,
                                                       token=token, position=T, N=1000)
        sample_index = sample_getter.get_different_amount_sample([si])
        if sample_index is None or sample_index.shape[1] < 5:
            print("\tToo less samples in this condition.")
            # Skip only this position; results already collected for other
            # positions are kept.
            continue
        state = state_getter.get_hidden_state(seq2seq, sample_index)
        # Keep only the first T time steps (third axis).
        state = state[:, :, :T]

        # Presumably axis 0 indexes the sample groups passed above (only one
        # here) -- TODO confirm. [N, t, units] -> [t, N, units]
        x = state[0].transpose([1, 0, 2])
        n_steps, n_samples = x.shape[0], x.shape[1]
        # Labels [0]*N + [1]*N + ... + [n_steps-1]*N, matching the row order
        # produced by the row-major reshape below.
        y = np.repeat(np.arange(n_steps, dtype=int), n_samples)
        x = np.reshape(x, [-1, seq2seq.units])
        x, y = shuffle(x, y, random_state=42)
        # NOTE(review): for T < 2 there is only one label class; confirm
        # call_recursive_rfe tolerates that.
        result[T] = call_classifier.call_recursive_rfe(x, y, max_count=target_units, one_threshold=0.5)
    return result