Example #1
File: network.py Project: VIM-Lab/AVSMVR
    def step(self, image, real_voxel):
        with tf.GradientTape() as image_enc_tape, tf.GradientTape() as gen_tape:
            self.state = self.image_encoder(image[:, 0, :, :, :], training = True)
            used = [0]

            for _t in range(params.time_steps - 1):
                prob = self.actor(self.state).numpy()

                action = utils.choose_action(prob, used)
                used.append(action)

                append_state = self.image_encoder(image[:, action, :, :, :], training = True)
                self.state = tf.reduce_max(tf.stack([self.state, append_state], axis = 1), axis = 1)

            voxel = self.generator(self.state, training = True)

            fin_loss = self.fin_loss(real_voxel, voxel)

        tf.print('fin_loss = ', fin_loss)

        image_enc_grad = image_enc_tape.gradient(fin_loss, self.image_encoder.trainable_variables)
        gen_grad = gen_tape.gradient(fin_loss, self.generator.trainable_variables)

        self.image_encoder_optimizer.apply_gradients(zip(image_enc_grad, self.image_encoder.trainable_variables))
        self.generator_optimizer.apply_gradients(zip(gen_grad, self.generator.trainable_variables))
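Every example in this listing calls some form of `choose_action`, but the project-specific `utils` modules are not shown. For the view-selection use above, a minimal sketch of what `choose_action(prob, used)` could look like, assuming it samples an unused view index from the actor's probability vector (a hypothetical helper, not the actual VIM-Lab/AVSMVR code):

    import numpy as np

    def choose_action(prob, used):
        # Hypothetical sketch: sample a view index from `prob`, never repeating `used`.
        # Assumes at least one index remains unused.
        p = np.asarray(prob, dtype=np.float64).ravel().copy()
        p[list(used)] = 0.0
        if p.sum() <= 0.0:
            # Fall back to a uniform choice over the remaining views.
            p = np.ones_like(p)
            p[list(used)] = 0.0
        p /= p.sum()
        return int(np.random.choice(len(p), p=p))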
Example #2
File: network.py Project: VIM-Lab/AVSMVR
    def validate(self, image, y, name):
        # temp = tf.map_fn(lambda x : self.image_encoder(x, training = False), image)

        self.state = self.image_encoder(image[:, 0, :, :, :], training = False) # shape = [batch_size, 1024]
        used = [0]
        
        for _t in range(params.time_steps - 1):
            prob = self.actor(self.state).numpy()

            action = utils.choose_action(prob, used)
            used.append(action)

            append_state = self.image_encoder(image[:, action, :, :, :], training = False)
            self.state = tf.reduce_max(tf.stack([self.state, append_state], axis = 1), axis = 1)

        voxel = self.generator(self.state, training = False)

        voxel = utils.dicide_voxel(voxel)
        utils.save_voxel(voxel, '{}_pridict'.format(name))
        utils.save_voxel(y, '{}_true'.format(name))

        y = y[0]
        y = np.argmax(y, -1)
        voxel = np.argmax(voxel, -1)

        iou = utils.cal_iou(y, voxel)
        print('iou = ', iou)
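`utils.cal_iou` is also external to this excerpt. Assuming `y` and `voxel` are integer occupancy grids of the same shape after the `argmax` calls (1 = occupied), a minimal voxel IoU could be computed as follows (a sketch only, not the project's actual helper):

    import numpy as np

    def cal_iou(y_true, y_pred):
        # Hypothetical sketch: intersection over union of occupied voxels.
        occ_true = (y_true == 1)
        occ_pred = (y_pred == 1)
        intersection = np.logical_and(occ_true, occ_pred).sum()
        union = np.logical_or(occ_true, occ_pred).sum()
        return float(intersection) / float(union) if union > 0 else 1.0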
Example #3
File: main.py Project: kamahori/ESP-RL
def explain():
    F_network.load_state_dict(torch.load(F_net_name))
    C_network.load_state_dict(torch.load(C_net_name))

    state = env.reset()
    done = False

    spec = gridspec.GridSpec(ncols=2, nrows=3, height_ratios=[3, 1, 2])
    fig = plt.figure()
    ax1 = fig.add_subplot(spec[0, :])
    ax1.axis("off")
    ax2 = fig.add_subplot(spec[1, 0])
    ax2.set_ylim(0, 500)
    ax2.set_title("Left")
    ax3 = fig.add_subplot(spec[1, 1])
    ax3.set_ylim(0, 500)
    ax3.set_title("Right")
    ax4 = fig.add_subplot(spec[2, :])
    ims = []

    for t in range(1, n_timestep + 1):
        action, _ = choose_action(
            state, F_network, C_network, n_action, n_state, 0.0, device
        )
        next_state, reward, done, _ = env.step(action)

        state_action = concat_state_action(
            torch.tensor(state).unsqueeze(0), n_action, n_state, device
        ).to(torch.float)

        feature = F_network(state_action)

        ig = calc_ig(feature[0], feature[1])
        feature = feature.to("cpu").detach().numpy()

        rend = env.render(mode="rgb_array")
        im = [ax1.imshow(rend)]
        im += ax2.bar(x=range(len(feature[0])), height=feature[0], color="b")
        im += ax3.bar(x=range(len(feature[1])), height=feature[1], color="r")
        im += ax4.bar(x=range(len(ig)), height=ig, color="g")
        ims.append(im)

        state = next_state

        if done:
            break

    ani = animation.ArtistAnimation(
        fig, ims, interval=100, blit=True, repeat_delay=1000
    )

    ani.save(video_name)
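One usage note on the final line: `ArtistAnimation.save` needs a matplotlib movie writer to be available. Depending on what is installed, the call can be made explicit (the file names here are illustrative):

    ani.save("explain.mp4", writer="ffmpeg", fps=10)   # requires ffmpeg on the PATH
    ani.save("explain.gif", writer="pillow", fps=10)   # pure-Python fallback via Pillow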
Example #4
    def step(self, old_obs, old_reward, a_old, obs, reward):
        """
        Want to take an as soon as possible. No latency.
        So we are going to precompute a_t based on obs_t_1

        Args:
            ...
            a_old: the action taken tha resulted in obs,reward
        """
        with tf.GradientTape() as tape:
            x_old = tf.concat([old_obs, old_reward], axis=1)
            h_old = self.encoder(x_old)

            # choose action based on a prediction of the future state
            # where a_t = trans(encoder(x))
            h_approx = self.trans(h_old)
            action = utils.choose_action(self.policy(h_approx), self.temp)

            x = tf.concat([obs, old_reward], axis=1)
            h = self.encoder(x)
            v = self.value(tf.concat([h, action], axis=1))

            # OPTIMIZE implementation here. could write as simply predicting inputs!?
            # predict inputs at t+1 given action taken
            obs_approx = self.decoder(h_old)
            # BUG a_old, where is that coming from!?
            v_approx = self.value(tf.concat([h_old, a_old], axis=1))

            loss_d = tf.losses.mean_squared_error(obs, obs_approx)
            loss_t = tf.losses.mean_squared_error(tf.stop_gradient(h),
                                                  h_approx)
            loss_v = tf.losses.mean_squared_error(
                v_approx, reward + self.discount * tf.stop_gradient(v))

            # maximise reward: use the approximated reward as supervision
            loss_p_exploit = -tf.reduce_mean(v)
            # explore: do things that result in unpredictable inputs
            loss_p_explore = -loss_d - loss_t - loss_v

        loss = self.train_step(tape, loss_d, loss_t, loss_v, loss_p_exploit,
                               loss_p_explore)

        with tf.contrib.summary.record_summaries_every_n_global_steps(10):
            tf.contrib.summary.histogram('a', action)
            tf.contrib.summary.histogram('state', h)
            tf.contrib.summary.histogram('obs', obs)

        return action
Example #5
File: main.py Project: kamahori/ESP-RL
def main():
    eps = eps_start

    for i_episode in range(1, n_episode + 1):
        state = env.reset()
        done = False
        step = 0

        for t in range(1, n_timestep + 1):
            step += 1
            action, _ = choose_action(
                state, F_network, C_network, n_action, n_state, eps, device
            )
            next_state, reward, done, _ = env.step(action)
            feature = get_feature(state)
            action_vector = np.zeros(n_action)
            action_vector[action] = 1.0
            experience = experience_t(
                np.concatenate([state, action_vector]),
                reward,
                feature,
                next_state,
                done and t < n_timestep,
            )
            state = next_state
            memory.append(experience)

            if step % freq_update_network == 0 and len(memory) > batch_size:
                train()

            if t % freq_target_update == 0:
                soft_update(F_network, F_aux, tau)
                soft_update(C_network, C_aux, tau)

            eps = max(eps_end, eps - (eps_start - eps_end) / (eps_decrease_steps - 1))

            if done:
                break

        if i_episode % freq_evaluation == 0:
            evaluation()

    torch.save(F_network.state_dict(), F_net_name)
    torch.save(C_network.state_dict(), C_net_name)
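`soft_update(F_network, F_aux, tau)` is defined elsewhere in the project. In DQN-style code this name usually means a Polyak update that moves the auxiliary (target) network a fraction `tau` toward the online network; a typical sketch under that assumption (the real argument order and details may differ) is:

    import torch

    def soft_update(source, target, tau):
        # Hypothetical sketch: target <- tau * source + (1 - tau) * target.
        with torch.no_grad():
            for src, tgt in zip(source.parameters(), target.parameters()):
                tgt.data.mul_(1.0 - tau).add_(tau * src.data)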
Example #6
File: main.py Project: kamahori/ESP-RL
def evaluation():
    total_reward = 0
    total_GVF_loss = 0

    n_trial = 100

    for _ in range(n_trial):
        state = env.reset()
        done = False

        gt_feature = torch.zeros(n_timestep, n_feature)
        pred_feature = torch.zeros(n_timestep, n_feature)
        discounted_para = torch.zeros(n_timestep, 1)

        step = 0
        for t in range(1, n_timestep + 1):
            step += 1
            action, pred_feature_vector = choose_action(
                state, F_network, C_network, n_action, n_state, 0.0, device
            )

            next_state, reward, done, _ = env.step(action)
            total_reward += reward
            gt_feature_vector = torch.tensor(get_feature(state))

            pred_feature[step - 1] = pred_feature_vector

            gt_feature[:step] += (
                gt_feature_vector * (GVF_discount_factor ** discounted_para)
            )[:step]

            discounted_para[:step] += 1
            state = next_state

            if done:
                break

        with torch.no_grad():
            criterion = nn.MSELoss()
            total_GVF_loss += criterion(gt_feature[:step], pred_feature[:step]).item()

    avg_reward = total_reward / n_trial
    avg_GVF_loss = total_GVF_loss / n_trial
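The in-place accumulation above builds, for each visited timestep i, the discounted sum of future feature vectors G_i = sum_{k >= i} gamma^(k-i) * phi(s_k), which is what the predicted GVF vector is compared against. A small standalone check of that equivalence (all values here are made up for illustration):

    import torch

    gamma = 0.9
    features = [torch.tensor([1.0, 0.0]), torch.tensor([0.0, 2.0]), torch.tensor([3.0, 1.0])]

    gt = torch.zeros(len(features), 2)      # incremental accumulation, as in evaluation()
    disc = torch.zeros(len(features), 1)
    for step, phi in enumerate(features, start=1):
        gt[:step] += (phi * (gamma ** disc))[:step]
        disc[:step] += 1

    direct = sum((gamma ** k) * phi for k, phi in enumerate(features))
    assert torch.allclose(gt[0], direct)    # discounted feature return for timestep 0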
Example #7
File: train.py Project: VIM-Lab/AVSMVR
                used = [0]

                for t in range(params.time_steps - 1):
                    enc_loss = []
                    for i in range(24):
                        loss = 999999
                        if i not in used:
                            loss = Combiner.encoder_step(x[:, i, :, :, :], y_true)
                        enc_loss.append(loss)
                    act_true = np.zeros(24)
                    act_true[np.argmin(enc_loss)] = 1
                    act_true = act_true[np.newaxis, :]

                    prob = Combiner.actor_step(Combiner.state, act_true).numpy()

                    a = utils.choose_action(prob, used)
                    used.append(a)

                    Combiner.step(x[:, a, :, :, :])

                print('epoch ', epoch, ' iteration ', iters, '  com_loss = ', Combiner.encoder_loss(y_true, Combiner.state))
                    

                if (iters % 50 == 0):
                    Combiner.validate(x, y, '{}_{}'.format(epoch, iters))
        
            Combiner.save_model()

    # if this is the final training stage
    elif params.cur_step == 'fin':
        model = network.Fin(training = True)
Example #8
def performance(doc, worker_net, manager_net=None):

    test_document = []

    score_softmax = nn.Softmax()

    predict_cluster = []
    new_cluster_num = 1
    predict_cluster.append(0)
    mid = 0
    cluster_info = {0: [0]}

    worker = worker_net
    if manager_net is None:
        manager = worker_net
    else:
        manager = manager_net

    for data in doc.rl_case_generater(shuffle=False):

        this_doc = doc

        candi_ids_all = data["candi_ids_all"]
        rl = data["rl"]

        mention_index = autograd.Variable(
            torch.from_numpy(data["mention_word_index"]).type(
                torch.cuda.LongTensor))
        mention_spans = autograd.Variable(
            torch.from_numpy(data["mention_span"]).type(
                torch.cuda.FloatTensor))
        candi_index = autograd.Variable(
            torch.from_numpy(data["candi_word_index"]).type(
                torch.cuda.LongTensor))
        candi_spans = autograd.Variable(
            torch.from_numpy(data["candi_span"]).type(torch.cuda.FloatTensor))
        pair_feature = autograd.Variable(
            torch.from_numpy(data["pair_features"]).type(
                torch.cuda.FloatTensor))
        anaphors = autograd.Variable(
            torch.from_numpy(data["pair_anaphors"]).type(
                torch.cuda.LongTensor))
        antecedents = autograd.Variable(
            torch.from_numpy(data["pair_antecedents"]).type(
                torch.cuda.LongTensor))

        anaphoricity_index = autograd.Variable(
            torch.from_numpy(data["mention_word_index"]).type(
                torch.cuda.LongTensor))
        anaphoricity_span = autograd.Variable(
            torch.from_numpy(data["mention_span"]).type(
                torch.cuda.FloatTensor))
        anaphoricity_feature = autograd.Variable(
            torch.from_numpy(data["anaphoricity_feature"]).type(
                torch.cuda.FloatTensor))

        target = data["pair_target"]
        anaphoricity_target = data["anaphoricity_target"]

        output_manager, pair_score_manager, mention_pair_representations_manager = manager.forward_all_pair(
            nnargs["word_embedding_dimention"], mention_index, mention_spans,
            candi_index, candi_spans, pair_feature, anaphors, antecedents)
        ana_output_manager, ana_score_manager, ana_pair_representations_manager = manager.forward_anaphoricity(
            nnargs["word_embedding_dimention"], anaphoricity_index,
            anaphoricity_span, anaphoricity_feature)

        mention_pair_representations_manager = autograd.Variable(
            torch.from_numpy(mention_pair_representations_manager).type(
                torch.cuda.FloatTensor))
        ana_pair_representations_manager = autograd.Variable(
            torch.from_numpy(ana_pair_representations_manager).type(
                torch.cuda.FloatTensor))

        reindex = autograd.Variable(
            torch.from_numpy(rl["reindex"]).type(torch.cuda.LongTensor))

        scores_manager = torch.transpose(
            torch.cat((pair_score_manager, ana_score_manager), 1), 0,
            1)[reindex]
        representations_manager = torch.cat(
            (mention_pair_representations_manager,
             ana_pair_representations_manager), 0)[reindex]

        output_worker, pair_score_worker, mention_pair_representations_worker = worker.forward_all_pair(
            nnargs["word_embedding_dimention"], mention_index, mention_spans,
            candi_index, candi_spans, pair_feature, anaphors, antecedents)
        ana_output_worker, ana_score_worker, ana_pair_representations_worker = worker.forward_anaphoricity(
            nnargs["word_embedding_dimention"], anaphoricity_index,
            anaphoricity_span, anaphoricity_feature)
        mention_pair_representations_worker = autograd.Variable(
            torch.from_numpy(mention_pair_representations_worker).type(
                torch.cuda.FloatTensor))
        ana_pair_representations_worker = autograd.Variable(
            torch.from_numpy(ana_pair_representations_worker).type(
                torch.cuda.FloatTensor))

        scores_worker = torch.transpose(
            torch.cat((pair_score_worker, ana_score_worker), 1), 0, 1)[reindex]
        representations_worker = torch.cat(
            (mention_pair_representations_worker,
             ana_pair_representations_worker), 0)[reindex]

        for s, e in zip(rl["starts"], rl["ends"]):
            manager_action_embeddings = representations_manager[s:e]
            worker_action_embeddings = representations_worker[s:e]
            #score = score_softmax(torch.transpose(scores_manager[s:e],0,1)).data.cpu().numpy()[0]
            #score = score_softmax(torch.transpose(scores_worker[s:e],0,1)).data.cpu().numpy()[0]
            score = F.softmax(torch.squeeze(scores_worker[s:e]),
                              dim=0).data.cpu().numpy()
            this_action = utils.choose_action(score)

            #if this_action == len(score)-1:
            #    cluster_indexs = torch.cuda.LongTensor([this_action])
            #else:
            #    should_cluster = predict_cluster[this_action]
            #    cluster_indexs = torch.cuda.LongTensor(cluster_info[should_cluster]+[this_action])

            #action_embedding_choose = torch.mean(manager_action_embeddings[cluster_indexs],0,keepdim=True)
            #similarities = torch.sum(torch.abs(worker_action_embeddings - action_embedding_choose),1)
            #similarities = similarities.data.cpu().numpy()
            #real_action = numpy.argmin(similarities)

            real_action = this_action
            if real_action == len(score) - 1:
                should_cluster = new_cluster_num
                cluster_info[should_cluster] = []
                new_cluster_num += 1
            else:
                should_cluster = predict_cluster[real_action]

            cluster_info[should_cluster].append(mid)
            predict_cluster.append(should_cluster)
            mid += 1

        if rl["end"] == True:
            ev_document = utils.get_evaluation_document(
                predict_cluster, this_doc.gold_chain[rl["did"]], candi_ids_all,
                new_cluster_num)
            test_document.append(ev_document)
            predict_cluster = []
            new_cluster_num = 1
            predict_cluster.append(0)
            cluster_info = {0: [0]}
            mid = 0

    metrics = evaluation.Output_Result(test_document)
    r, p, f = metrics["muc"]
    print "MUC: recall: %f precision: %f  f1: %f" % (r, p, f)
    r, p, f = metrics["b3"]
    print "B3: recall: %f precision: %f  f1: %f" % (r, p, f)
    r, p, f = metrics["ceaf"]
    print "CEAF: recall: %f precision: %f  f1: %f" % (r, p, f)
    print "AVE", metrics["average"]

    return metrics
Example #9
File: reinforce.py Project: yqy/acl2018
def main():

    DIR = args.DIR
    embedding_file = args.embedding_dir

    best_network_file = "./model/network_model_pretrain.best.top"
    print >> sys.stderr, "Read model from ", best_network_file
    best_network_model = torch.load(best_network_file)

    embedding_matrix = numpy.load(embedding_file)
    "Building torch model"
    worker = network.Network(
        nnargs["pair_feature_dimention"], nnargs["mention_feature_dimention"],
        nnargs["word_embedding_dimention"], nnargs["span_dimention"], 1000,
        nnargs["embedding_size"], nnargs["embedding_dimention"],
        embedding_matrix).cuda()
    net_copy(worker, best_network_model)

    best_network_file = "./model/network_model_pretrain.best.top"
    print >> sys.stderr, "Read model from ", best_network_file
    best_network_model = torch.load(best_network_file)

    manager = network.Network(
        nnargs["pair_feature_dimention"], nnargs["mention_feature_dimention"],
        nnargs["word_embedding_dimention"], nnargs["span_dimention"], 1000,
        nnargs["embedding_size"], nnargs["embedding_dimention"],
        embedding_matrix).cuda()
    net_copy(manager, best_network_model)

    reduced = ""
    if args.reduced == 1:
        reduced = "_reduced"

    print >> sys.stderr, "prepare data for train ..."
    train_docs_iter = DataReader.DataGnerater("train" + reduced)
    #train_docs_iter = DataReader.DataGnerater("dev"+reduced)
    print >> sys.stderr, "prepare data for dev and test ..."
    dev_docs_iter = DataReader.DataGnerater("dev" + reduced)
    test_docs_iter = DataReader.DataGnerater("test" + reduced)
    '''
    print "Performance after pretraining..."
    print "DEV"
    metric = performance.performance(dev_docs_iter,worker,manager) 
    print "Average:",metric["average"]
    print "TEST"
    metric = performance.performance(test_docs_iter,worker,manager) 
    print "Average:",metric["average"]
    print "***"
    print
    sys.stdout.flush()
    '''

    lr = nnargs["lr"]
    top_k = nnargs["top_k"]

    model_save_dir = "./model/reinforce/"
    utils.mkdir(model_save_dir)

    score_softmax = nn.Softmax()

    optimizer_manager = optim.RMSprop(manager.parameters(), lr=lr, eps=1e-6)
    optimizer_worker = optim.RMSprop(worker.parameters(), lr=lr, eps=1e-6)

    MAX_AVE = 2048

    for echo in range(nnargs["epoch"]):

        start_time = timeit.default_timer()
        print "Pretrain Epoch:", echo

        reward_log = Logger(Tensorboard + args.tb +
                            "/acl2018/%d/reward/" % echo,
                            flush_secs=3)
        entropy_log_manager = Logger(Tensorboard + args.tb +
                                     "/acl2018/%d/entropy/worker" % echo,
                                     flush_secs=3)
        entropy_log_worker = Logger(Tensorboard + args.tb +
                                    "/acl2018/%d/entropy/manager" % echo,
                                    flush_secs=3)

        train_docs = utils.load_pickle(args.DOCUMENT + 'train_docs.pkl')
        #train_docs = utils.load_pickle(args.DOCUMENT + 'dev_docs.pkl')
        docs_by_id = {doc.did: doc for doc in train_docs}

        ave_reward = []
        ave_manager_entropy = []
        ave_worker_entropy = []

        print >> sys.stderr, "Link docs ..."
        tmp_data = []
        cluster_info = {0: [0]}
        cluster_list = [0]
        current_new_cluster = 1
        predict_action_embedding = []
        choose_action = []
        mid = 1

        step = 0

        statistic = {
            "worker_hits": 0,
            "manager_hits": 0,
            "total": 0,
            "manager_predict_last": 0,
            "worker_predict_last": 0
        }

        for data in train_docs_iter.rl_case_generater(shuffle=True):

            rl = data["rl"]

            scores_manager, representations_manager = get_score_representations(
                manager, data)
            scores_worker, representations_worker = get_score_representations(
                worker, data)

            for s, e in zip(rl["starts"], rl["ends"]):
                #action_embeddings = representations_manager[s:e]
                #probs = F.softmax(torch.squeeze(scores_manager[s:e]))
                action_embeddings = representations_worker[s:e]
                probs = F.softmax(torch.squeeze(
                    scores_worker[s:e])).data.cpu().numpy()

                #m = Categorical(F.softmax(torch.squeeze(scores_worker[s:e]))[:-1])
                #a = m.sample()
                #this_action = m.sample()
                #index = this_action.data.cpu().numpy()[0]

                index = utils.choose_action(probs)

                if index == (e - s - 1):
                    should_cluster = current_new_cluster
                    cluster_info[should_cluster] = []
                    current_new_cluster += 1
                else:
                    should_cluster = cluster_list[index]

                choose_action.append(index)
                cluster_info[should_cluster].append(mid)
                cluster_list.append(should_cluster)
                mid += 1

                cluster_indexs = torch.cuda.LongTensor(
                    cluster_info[should_cluster])
                action_embedding_predict_ave = torch.mean(
                    action_embeddings[cluster_indexs], 0, keepdim=True)
                action_embedding_predict_max, max_index = torch.max(
                    action_embeddings[cluster_indexs], dim=0, keepdim=True)

                action_embedding_predict = torch.cat(
                    (action_embedding_predict_ave,
                     action_embedding_predict_max), 1)
                predict_action_embedding.append(action_embedding_predict)

            tmp_data.append(data)

            if rl["end"] == True:

                inside_index = 0
                manager_path = []
                worker_path = []

                doc = docs_by_id[rl["did"]]

                for data in tmp_data:

                    rl = data["rl"]
                    pair_target = data["pair_target"]
                    anaphoricity_target = 1 - data["anaphoricity_target"]
                    target = numpy.concatenate(
                        (pair_target, anaphoricity_target))[rl["reindex"]]
                    scores_worker, representations_worker = get_score_representations(
                        worker, data)

                    for s, e in zip(rl["starts"], rl["ends"]):
                        action_embeddings = representations_worker[s:e]
                        probs = F.softmax(
                            torch.squeeze(scores_worker[s:e])
                        ).data.cpu().numpy(
                        )  #print probs.data.cpu().numpy() -> [  3.51381488e-04   9.99648571e-01]
                        action_embedding_predicted = predict_action_embedding[
                            inside_index]
                        combine_embedding = torch.cat(
                            (action_embeddings, action_embeddings), 1)
                        similarities = torch.sum(
                            torch.abs(combine_embedding -
                                      action_embedding_predicted), 1)
                        similarities = similarities.data.cpu().numpy()

                        action_probabilities = []
                        action_list = []
                        similarity_candidates = heapq.nlargest(
                            top_k, -similarities)
                        for similarity in similarity_candidates:
                            action_index = numpy.argwhere(
                                similarities == -similarity)[0][0]
                            action_probabilities.append(probs[action_index])
                            action_list.append(action_index)

                        manager_action = choose_action[inside_index]

                        if not manager_action in action_list:
                            action_list.append(manager_action)
                            action_probabilities.append(probs[manager_action])
                        sample_action = utils.sample_action(
                            numpy.array(action_probabilities))
                        worker_action = action_list[sample_action]

                        this_target = target[s:e]

                        if this_target[worker_action] == 1:
                            statistic["worker_hits"] += 1
                        if this_target[manager_action] == 1:
                            statistic["manager_hits"] += 1
                        if worker_action == (e - s - 1):
                            statistic["worker_predict_last"] += 1
                        if manager_action == (e - s - 1):
                            statistic["manager_predict_last"] += 1
                        statistic["total"] += 1

                        inside_index += 1

                        #link = manager_action
                        link = worker_action
                        m1, m2 = rl['ids'][s + link]
                        doc.link(m1, m2)

                        manager_path.append(manager_action)
                        worker_path.append(worker_action)

                reward = doc.get_f1()
                for data in tmp_data:
                    rl = data["rl"]
                    for s, e in zip(rl["starts"], rl["ends"]):
                        ids = rl['ids'][s:e]
                        ana = ids[0, 1]
                        old_ant = doc.ana_to_ant[ana]
                        doc.unlink(ana)
                        costs = rl['costs'][s:e]
                        for ant_ind in range(e - s):
                            costs[ant_ind] = doc.link(ids[ant_ind, 0],
                                                      ana,
                                                      hypothetical=True,
                                                      beta=1)
                        doc.link(old_ant, ana)
                        #costs = autograd.Variable(torch.from_numpy(costs).type(torch.cuda.FloatTensor))

                inside_index = 0
                worker_entropy = 0.0
                for data in tmp_data:
                    rl = data["rl"]
                    new_step = step
                    # worker
                    scores_worker, representations_worker = get_score_representations(
                        worker, data, dropout=nnargs["dropout_rate"])
                    optimizer_worker.zero_grad()
                    worker_loss = None
                    for s, e in zip(rl["starts"], rl["ends"]):
                        costs = rl['costs'][s:e]
                        costs = autograd.Variable(
                            torch.from_numpy(costs).type(
                                torch.cuda.FloatTensor))
                        action = worker_path[inside_index]
                        score = F.softmax(torch.squeeze(scores_worker[s:e]))
                        if not score.size() == costs.size():
                            continue

                        baseline = torch.sum(costs * score)
                        this_cost = torch.log(
                            score[action]) * -1.0 * (reward - baseline)

                        if worker_loss is None:
                            worker_loss = this_cost
                        else:
                            worker_loss += this_cost
                        worker_entropy += torch.sum(
                            score * torch.log(score + 1e-7)
                        ).data.cpu().numpy()[
                            0]  #+ 0.001*torch.sum(score*torch.log(score+1e-7))
                        inside_index += 1

                    worker_loss.backward()
                    torch.nn.utils.clip_grad_norm(worker.parameters(),
                                                  nnargs["clip"])
                    optimizer_worker.step()

                    ave_worker_entropy.append(worker_entropy)
                    if len(ave_worker_entropy) >= MAX_AVE:
                        ave_worker_entropy = ave_worker_entropy[1:]
                    entropy_log_worker.log_value(
                        'entropy',
                        float(sum(ave_worker_entropy)) /
                        float(len(ave_worker_entropy)), new_step)
                    new_step += 1

                inside_index = 0
                manager_entropy = 0.0
                for data in tmp_data:
                    new_step = step
                    rl = data["rl"]

                    ave_reward.append(reward)
                    if len(ave_reward) >= MAX_AVE:
                        ave_reward = ave_reward[1:]
                    reward_log.log_value(
                        'reward',
                        float(sum(ave_reward)) / float(len(ave_reward)),
                        new_step)

                    scores_manager, representations_manager = get_score_representations(
                        manager, data, dropout=nnargs["dropout_rate"])

                    #optimizer_manager.zero_grad
                    #manager_loss = None
                    for s, e in zip(rl["starts"], rl["ends"]):
                        #costs = rl['costs'][s:e]
                        #costs = autograd.Variable(torch.from_numpy(costs).type(torch.cuda.FloatTensor))
                        score = F.softmax(torch.squeeze(scores_manager[s:e]))
                        action = manager_path[inside_index]

                        if not score.size() == costs.size():
                            continue

                        #baseline = torch.sum(costs*score)

                        #this_cost = torch.log(score[action])*-1.0*(reward-baseline)# + 0.001*torch.sum(score*torch.log(score+1e-7))
                        #if manager_loss is None:
                        #    manager_loss = this_cost
                        #else:
                        #    manager_loss += this_cost

                        manager_entropy += torch.sum(
                            score *
                            torch.log(score + 1e-7)).data.cpu().numpy()[0]
                        inside_index += 1

                    #manager_loss.backward()
                    #torch.nn.utils.clip_grad_norm(manager.parameters(), nnargs["clip"])
                    #optimizer_manager.step()

                    ave_manager_entropy.append(manager_entropy)
                    if len(ave_manager_entropy) >= MAX_AVE:
                        ave_manager_entropy = ave_manager_entropy[1:]
                    entropy_log_manager.log_value(
                        'entropy',
                        float(sum(ave_manager_entropy)) /
                        float(len(ave_manager_entropy)), new_step)
                    new_step += 1

                step = new_step
                tmp_data = []
                cluster_info = {0: [0]}
                cluster_list = [0]
                current_new_cluster = 1
                mid = 1
                predict_action_embedding = []
                choose_action = []

        end_time = timeit.default_timer()
        print >> sys.stderr, "TRAINING Use %.3f seconds" % (end_time -
                                                            start_time)
        print >> sys.stderr, "save model ..."
        #print "Top k",top_k
        print "Worker Hits", statistic[
            "worker_hits"], "Manager Hits", statistic[
                "manager_hits"], "Total", statistic["total"]
        print "Worker predict last", statistic[
            "worker_predict_last"], "Manager predict last", statistic[
                "manager_predict_last"]
        #torch.save(network_model, model_save_dir+"network_model_rl_worker.%d"%echo)
        #torch.save(ana_network, model_save_dir+"network_model_rl_manager.%d"%echo)

        print "DEV"
        metric = performance.performance(dev_docs_iter, worker, manager)
        print "Average:", metric["average"]
        #print "DEV manager"
        #metric = performance_manager.performance(dev_docs_iter,worker,manager)
        #print "Average:",metric["average"]
        print "TEST"
        metric = performance.performance(test_docs_iter, worker, manager)
        print "Average:", metric["average"]
        print
        sys.stdout.flush()
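A small portability note on the training loop above: `torch.nn.utils.clip_grad_norm` has since been deprecated in favour of the in-place variant, so on a recent PyTorch install the clipping call would read:

    torch.nn.utils.clip_grad_norm_(worker.parameters(), nnargs["clip"])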
Example #10
def main():
    if os.path.exists(config.use_output_path):
        os.system('rm ' + config.use_output_path)
    with open(config.use_output_path, 'a') as g:
        g.write(str(config) + '\n\n')
    sim = config.sim
    # sta_vec=list(np.zeros([config.num_steps-1]))
    config.shuffle = False
    #original sentence input
    use_data = dataset_str(config.use_data_path)
    config.batch_size = 1
    step_size = config.step_size

    start_time = time.time()
    proposal_cnt = 0
    accept_cnt = 0
    all_samples = []
    all_acc_samples = []
    all_chosen_samples = []
    for sen_id in range(use_data.length):
        sent_ids = use_data.token_ids[sen_id]
        keys = use_data.keys[sen_id]
        searcher = ConstraintSearch(keys)
        sequence_length = len(sent_ids)
        #generate for each sentence
        sta_vec = np.zeros(sequence_length)
        input_ids = np.array(sent_ids)
        input_original = use_data.tokens[sen_id]
        prev_inds = []
        old_prob = def_sent_scorer(tokenizer.decode(input_ids))
        old_prob_pen = penalty_constraint(
            searcher.count_unsafisfied_constraint(
                searcher.sent2tag(input_ids)))
        if sim != None:
            old_prob *= similarity(input_ids, input_original, sta_vec)

        outputs = []
        output_p = []
        for iter in range(config.sample_time):
            pos_set = np.array(
                get_sample_positions(sequence_length, prev_inds, step_size))
            prev_inds = pos_set
            proposal_cnt += 1

            search_cands, constr_num = searcher.search_template(
                input_ids, pos_set)
            group_prob = 1.0
            new_prob_pen = penalty_constraint(constr_num)
            original_temp = searcher.sent2tag(input_ids)
            original_constr_num = searcher.count_unsafisfied_constraint(
                original_temp)
            input_ids_old = np.array(input_ids)
            if len(search_cands) == 0:
                print('No candidate satisfies constraints. Continue.', pos_set)
            else:
                candidates = []
                candidate_probs = []
                for cand_template, action_set in search_cands:
                    masked_sent, adjusted_pos_set = mask_sentence(
                        input_ids, pos_set, action_set)
                    proposal_prob, input_ids_tmp = eval_template(
                        searcher,
                        input_original,
                        cand_template,
                        masked_sent,
                        adjusted_pos_set,
                        action_set,
                        sim=None)
                    input_text_tmp = tokenizer.decode(input_ids_tmp)
                    new_prob = def_sent_scorer(input_text_tmp)
                    if sim != None:
                        sim_constr = similarity(input_ids_tmp, input_original,
                                                sta_vec)
                        new_prob *= sim_constr
                    candidates.append(
                        (input_ids_tmp, proposal_prob, cand_template,
                         action_set, adjusted_pos_set))
                    candidate_probs.append(new_prob)

                candidate_probs_norm = normalize(np.array(candidate_probs))
                cand_idx = sample_from_candidate(
                    np.array(candidate_probs_norm))
                input_ids_tmp, proposal_prob, cand_template, action_set, adjusted_pos_set = candidates[
                    cand_idx]
                new_prob = candidate_probs[cand_idx]
                input_ids_new = np.array(input_ids_tmp)
                new_pos_set = np.array(adjusted_pos_set)
                print(cand_template)
                print(
                    tokenizer.decode(input_ids_new).encode('utf8',
                                                           errors='ignore'))

                # evaluate reverse proposal
                reverse_action_set = get_reverse_action_set(action_set)
                reverse_search_cands, reverse_min_constr_num, = searcher.search_template(
                    input_ids_new, new_pos_set, prune=False)
                reverse_group_prob = penalty_constraint(original_constr_num -
                                                        reverse_min_constr_num)
                reverse_search_cands_pruned = [(x[0], x[2])
                                               for x in reverse_search_cands
                                               if x[1] == original_constr_num]

                # check reverse search cand
                reverse_search_cand_str = [
                    ','.join(x[0]) for x in reverse_search_cands
                ]
                original_temp_str = ','.join(original_temp)
                if original_temp_str not in reverse_search_cand_str:
                    print('Warning', original_temp, cand_template, pos_set,
                          action_set, new_pos_set)
                if len(reverse_search_cands_pruned) == 0:
                    print('Warning')
                    reverse_search_cands_pruned = [original_temp]

                # evaluate reverse_candidate_probs_norm
                reverse_cand_idx = -1
                reverse_candidate_probs = []
                for c_idx, (reverse_cand_template, r_action_set
                            ) in enumerate(reverse_search_cands_pruned):
                    if ','.join(reverse_cand_template) == original_temp_str:
                        reverse_candidate_probs.append(old_prob)
                        reverse_cand_idx = c_idx
                    else:
                        masked_sent, new_adjusted_pos_set = mask_sentence(
                            input_ids_new, new_pos_set, r_action_set)
                        _, r_input_ids_tmp = eval_template(
                            searcher,
                            input_original,
                            reverse_cand_template,
                            masked_sent,
                            new_adjusted_pos_set,
                            r_action_set,
                            sim=None)
                        r_input_text_tmp = tokenizer.decode(r_input_ids_tmp)
                        r_new_prob = def_sent_scorer(r_input_text_tmp)
                        if sim != None:
                            sim_constr = similarity(input_ids_tmp,
                                                    input_original, sta_vec)
                            r_new_prob *= sim_constr
                        # candidates.append((input_ids_tmp, proposal_prob))
                        reverse_candidate_probs.append(r_new_prob)
                reverse_candidate_probs_norm = normalize(
                    np.array(reverse_candidate_probs))

                # evaluate proposal_prob_reverse
                r_masked_sent, pos_set_ = mask_sentence(
                    input_ids_new, new_pos_set, reverse_action_set)
                assert (pos_set == pos_set_).all()
                proposal_prob_reverse, input_ids_tmp_0 = \
                 eval_reverse_proposal(input_original, r_masked_sent, input_ids_old, pos_set, reverse_action_set, sim=None)

                if (input_ids_tmp_0 != input_ids_old).any():
                    print('Warning, ', input_ids_old, input_ids_new,
                          input_ids_tmp_0)
                assert (input_ids_tmp_0 == input_ids_old).all()

                # decide acceptance
                sequence_length_new = len(input_ids_new)
                input_text_new = tokenizer.decode(input_ids_new)
                if proposal_prob == 0.0 or old_prob == 0.0:
                    alpha_star = 1.0
                else:
                    alpha_star = (comb(sequence_length_new, 3) * proposal_prob_reverse * reverse_group_prob *
                                  reverse_candidate_probs_norm[reverse_cand_idx] * new_prob * new_prob_pen) / \
                                 (comb(sequence_length, 3) * proposal_prob  * group_prob *
                                  candidate_probs_norm[cand_idx] * old_prob * old_prob_pen)
                alpha = min(1, alpha_star)

                all_samples.append([
                    input_text_new, new_prob * new_prob_pen, new_prob,
                    constr_num,
                    bert_scorer.sent_score(input_ids_new, log_prob=True),
                    gpt2_scorer.sent_score(input_text_new, ppl=True)
                ])
                if tokenizer.decode(input_ids_new) not in output_p:
                    outputs.append(all_samples[-1])
                if outputs != []:
                    output_p.append(outputs[-1][0])
                print(alpha, old_prob, proposal_prob, new_prob,
                      new_prob * new_prob_pen, proposal_prob_reverse)
                if choose_action([
                        alpha, 1 - alpha
                ]) == 0 and (new_prob > old_prob * config.threshold
                             or just_acc() == 0):
                    if tokenizer.decode(input_ids_new) != tokenizer.decode(
                            input_ids):
                        accept_cnt += 1
                        print('Accept')
                        all_acc_samples.append(all_samples[-1])
                    input_ids = input_ids_new
                    sequence_length = sequence_length_new
                    assert sequence_length == len(input_ids)
                    old_prob = new_prob
                print('')

        # choose output from samples
        for num in range(config.min_length, 0, -1):
            outputss = [x for x in outputs if len(x[0].split()) >= num]
            print(num, outputss)
            if outputss != []:
                break
        if outputss == []:
            outputss.append([tokenizer.decode(input_ids), 0])
        outputss = sorted(outputss, key=lambda x: x[1])[::-1]
        with open(config.use_output_path, 'a') as g:
            g.write(outputss[0][0] + '\t' + str(outputss[0][1]) + '\n')
        all_chosen_samples.append(outputss[0])

        print('Sentence %d, used time %.2f\n' %
              (sen_id, time.time() - start_time))
    print(proposal_cnt, accept_cnt, float(accept_cnt / proposal_cnt))

    print("All samples:")
    all_samples_ = list(zip(*all_samples))
    for metric in all_samples_[1:]:
        print(np.mean(np.array(metric)))

    print("All accepted samples:")
    all_samples_ = list(zip(*all_acc_samples))
    for metric in all_samples_[1:]:
        print(np.mean(np.array(metric)))

    print("All chosen samples:")
    all_samples_ = list(zip(*all_chosen_samples))
    for metric in all_samples_[1:]:
        print(np.mean(np.array(metric)))

    with open(config.use_output_path + '-result.csv', 'w', newline='') as f:
        csv_writer = csv.writer(f, delimiter='\t')
        csv_writer.writerow(
            ['Sentence', 'Prob_sim', 'Constraint_num', 'Log_prob', 'PPL'])
        csv_writer.writerows(all_samples)
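The `alpha_star` expression above is a Metropolis-Hastings acceptance ratio: the score of the proposed sentence times the probability of proposing the reverse move, divided by the same quantities for the current sentence, and `choose_action([alpha, 1 - alpha]) == 0` simply accepts the move with probability `alpha`. An equivalent, self-contained form of that decision (hypothetical helper, not the project's code):

    import numpy as np

    def accept_move(alpha_star):
        # Metropolis-Hastings acceptance: accept with probability min(1, alpha_star).
        return np.random.rand() < min(1.0, alpha_star)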
Example #11
File: learner.py Project: act65/team-barca
    def get_losses(self, old_obs, old_reward, old_action, obs, action, reward):
        """
        Can be used with/wo eager.

        Args:
            obs_new (tf.tensor): input from t+1.
                shape [batch, n_inputs] dtype tf.float32
            obs_old (tf.tensor): input from t.
                shape [batch, n_inputs] dtype tf.float32
            reward (tf.tensor): reward received at t
                shape [batch, 1] dtype tf.float32
            action (tf.tensor): the action taken at t
                shape [batch, n_outputs] dtype tf.float32

        High level pattern.
        - Use inputs (obs_t, r_t) to build an internal state representation.
        - Use internal state at t to predict inputs at t+1 (obs_t+1, r_t+1).
        - Use the learned v(s, a), t(s, a) to evaluate actions chosen

        Returns:
            (tuple): transition_loss, value_loss, policy_loss
        """
        # TODO would like to see a graph of this part. just for sanity

        # TODO want enc to be recurrent and receive:
        # the old action taken and the old reward received
        x_old = tf.concat([old_obs, old_reward, old_action], axis=1)
        h_old = self.encoder(x_old)
        x = tf.concat([obs, reward, old_action], axis=1)
        h = self.encoder(x)

        # need differentiable actions.
        a = utils.choose_action(
            self.policy(h_old),
            self.temp)  # it bugs me that I need to recompute this
        a_new = utils.choose_action(self.policy(h), self.temp)
        # NOTE PROBLEM. old_action is not differentiable so cannot get
        # grads to the true action chosen. instead of old_action use a
        # should work out in expectation, but will just be slower learning for now
        # solution is ??? online learning, partial evaluation, predict given the dist

        v_old = self.value(tf.concat([h_old, a], axis=1))
        v = self.value(tf.concat([h, a_new], axis=1))

        # predict inputs at t+1 given action taken
        y = self.trans(tf.concat([h_old, a], axis=1))

        loss_t = tf.losses.mean_squared_error(x, y)
        loss_v = tf.losses.mean_squared_error(
            v_old, reward + self.discount * tf.stop_gradient(v))

        # maximise reward: use the approximated reward as supervision
        loss_p_exploit = -tf.reduce_mean(v)
        # explore: do things that result in unpredictable inputs
        loss_p_explore = -loss_t - loss_v
        # NOTE no gradients propagate back through the policy to the encoder.
        # good or bad? good. the losses are just the inverses so good?
        # NOTE not sure it makes sense to train the same fn on both loss_p_explore
        # and loss_p_exploit???

        # # A3C: policy gradients with learned variance adjustment
        # A = 1+tf.stop_gradient(v_old - reward+self.discount*v)
        # p = tf.concat([tf.reshape(dis.prob(a_dis), [-1, 1]), cts.prob(a_cts)], axis=1)
        # loss_a = tf.reduce_mean(-tf.log(p)*A)

        return loss_t, loss_v, loss_p_exploit, loss_p_explore
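The comment about needing differentiable actions, together with the temperature argument, suggests that `utils.choose_action` here returns a relaxed (e.g. Gumbel-softmax) action rather than a hard index. That helper is not part of the excerpt; a plausible sketch under that assumption (not the actual team-barca implementation):

    import tensorflow as tf

    def choose_action(logits, temp):
        # Hypothetical sketch: Gumbel-softmax sample, differentiable w.r.t. `logits`.
        u = tf.random.uniform(tf.shape(logits), minval=1e-8, maxval=1.0)
        gumbel = -tf.math.log(-tf.math.log(u))
        return tf.nn.softmax((logits + gumbel) / temp, axis=-1)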
Example #12
def main():
    if os.path.exists(config.use_output_path):
        os.system('rm ' + config.use_output_path)
    with open(config.use_output_path, 'a') as g:
        g.write(str(config) + '\n\n')
    # for item in config.record_time:
    # 	if os.path.exists(config.use_output_path + str(item)):
    # 		os.system('rm ' + config.use_output_path + str(item))
    #CGMH sampling for paraphrase
    sim = config.sim
    # sta_vec=list(np.zeros([config.num_steps-1]))
    config.shuffle = False
    #original sentence input
    use_data = dataset_str(config.use_data_path)
    config.batch_size = 1
    step_size = config.step_size

    start_time = time.time()
    proposal_cnt = 0
    accept_cnt = 0
    all_samples = []
    all_acc_samples = []
    all_chosen_samples = []
    for sen_id in range(use_data.length):
        sent_ids = use_data.token_ids[sen_id]
        keys = use_data.keys[sen_id]
        searcher = ConstraintSearch(keys)
        sequence_length = len(sent_ids)
        #generate for each sentence
        sta_vec = np.zeros(sequence_length)
        input_ids = np.array(sent_ids)
        input_original = use_data.tokens[sen_id]
        prev_inds = []
        old_prob = def_sent_scorer(tokenizer.decode(input_ids))
        old_prob *= penalty_constraint(
            searcher.count_unsafisfied_constraint(
                searcher.sent2tag(input_ids)))
        if sim != None:
            old_prob *= similarity(input_ids, input_original, sta_vec)

        outputs = []
        output_p = []
        for iter in range(config.sample_time):
            # if iter in config.record_time:
            # 	with open(config.use_output_path, 'a', encoding='utf-8') as g:
            # 		g.write(bert_scorer.tokenizer.decode(input_ids)+'\n')
            # print(bert_scorer.tokenizer.decode(input_ids).encode('utf8', errors='ignore'))
            pos_set = get_sample_positions(sequence_length, prev_inds,
                                           step_size)
            action_set = [
                choose_action(config.action_prob) for i in range(len(pos_set))
            ]
            # if not check_constraint(input_ids):
            # 	if 0 not in pos_set:
            # 		pos_set[-1] = 0
            keep_non = config.keep_non
            masked_sent, adjusted_pos_set = mask_sentence(
                input_ids, pos_set, action_set)
            prev_inds = pos_set

            proposal_prob = 1.0  # Q(x'|x)
            proposal_prob_reverse = 1.0  # Q(x|x')
            input_ids_tmp = np.array(masked_sent)  # copy
            sequence_length_tmp = sequence_length

            for step_i in range(len(pos_set)):

                ind = adjusted_pos_set[step_i]
                ind_old = pos_set[step_i]
                action = action_set[step_i]
                if config.restrict_constr:
                    if step_i == len(pos_set) - 1:
                        use_constr = True
                    else:
                        use_constr = False
                else:
                    use_constr = True
                #word replacement (action: 0)
                if action == 0:
                    prob_mask = bert_scorer.mask_score(input_ids_tmp,
                                                       ind,
                                                       mode=0)
                    input_candidate, prob_candidate, reverse_candidate_idx, _ = \
                     generate_candidate_input_with_mask(input_ids_tmp, sequence_length_tmp, ind, prob_mask, config.search_size,
                                                        old_tok=input_ids[ind_old], mode=action)
                    if sim is not None and use_constr:
                        similarity_candidate = similarity_batch(
                            input_candidate, input_original, sta_vec)
                        prob_candidate = prob_candidate * similarity_candidate
                    prob_candidate_norm = normalize(prob_candidate)
                    prob_candidate_ind = sample_from_candidate(
                        prob_candidate_norm)
                    input_ids_tmp = input_candidate[
                        prob_candidate_ind]  # changed
                    proposal_prob *= prob_candidate_norm[
                        prob_candidate_ind]  # Q(x'|x)
                    proposal_prob_reverse *= prob_candidate_norm[
                        reverse_candidate_idx]  # Q(x|x')
                    sequence_length_tmp += 0
                    print('action:0', prob_candidate_norm[prob_candidate_ind],
                          prob_candidate_norm[reverse_candidate_idx])

                #word insertion(action:1)
                if action == 1:
                    prob_mask = bert_scorer.mask_score(input_ids_tmp,
                                                       ind,
                                                       mode=0)

                    input_candidate, prob_candidate, reverse_candidate_idx, non_idx = \
                     generate_candidate_input_with_mask(input_ids_tmp, sequence_length_tmp, ind, prob_mask, config.search_size,
                                                        mode=action, old_tok=input_ids[ind_old], keep_non=keep_non)

                    if sim is not None and use_constr:
                        similarity_candidate = similarity_batch(
                            input_candidate, input_original, sta_vec)
                        prob_candidate = prob_candidate * similarity_candidate
                    prob_candidate_norm = normalize(prob_candidate)
                    prob_candidate_ind = sample_from_candidate(
                        prob_candidate_norm)
                    input_ids_tmp = input_candidate[prob_candidate_ind]
                    if prob_candidate_ind == non_idx:
                        if input_ids_tmp[-1] == PAD_IDX:
                            input_ids_tmp = input_ids_tmp[:-1]
                        print('action:1 insert non', 1.0, 1.0)
                    else:
                        proposal_prob *= prob_candidate_norm[
                            prob_candidate_ind]  # Q(x'|x)
                        proposal_prob_reverse *= 1.0  # Q(x|x'), reverse action is deleting
                        sequence_length_tmp += 1
                        print('action:1',
                              prob_candidate_norm[prob_candidate_ind], 1.0)

                #word deletion(action: 2)
                if action == 2:
                    input_ids_for_del = np.concatenate(
                        [input_ids_tmp[:ind], [MASK_IDX], input_ids_tmp[ind:]])
                    if keep_non:
                        non_cand = np.array(input_ids_for_del)
                        non_cand[ind] = input_ids[ind_old]
                        input_candidate = np.array([input_ids_tmp, non_cand])
                        prob_candidate = np.array([
                            bert_scorer.sent_score(x) for x in input_candidate
                        ])
                        non_idx = 1
                        if sim is not None and use_constr:
                            similarity_candidate = similarity_batch(
                                input_candidate, input_original, sta_vec)
                            prob_candidate = prob_candidate * similarity_candidate
                        prob_candidate_norm = normalize(prob_candidate)
                        prob_candidate_ind = sample_from_candidate(
                            prob_candidate_norm)
                        input_ids_tmp = input_candidate[prob_candidate_ind]
                    else:
                        non_idx = -1
                        prob_candidate_ind = 0
                        input_ids_tmp = input_ids_tmp  # already deleted

                    if prob_candidate_ind == non_idx:
                        print('action:2 delete non', 1.0, 1.0)
                    else:
                        # add mask, for evaluating reverse probability
                        prob_mask = bert_scorer.mask_score(input_ids_for_del,
                                                           ind,
                                                           mode=0)
                        input_candidate, prob_candidate, reverse_candidate_idx, _ = \
                         generate_candidate_input_with_mask(input_ids_for_del, sequence_length_tmp, ind, prob_mask,
                                                            config.search_size, mode=0, old_tok=input_ids[ind_old])

                        if sim != None:
                            similarity_candidate = similarity_batch(
                                input_candidate, input_original, sta_vec)
                            prob_candidate = prob_candidate * similarity_candidate
                        prob_candidate_norm = normalize(prob_candidate)

                        proposal_prob *= 1.0  # Q(x'|x)
                        proposal_prob_reverse *= prob_candidate_norm[
                            reverse_candidate_idx]  # Q(x|x'), reverse action is inserting
                        sequence_length_tmp -= 1

                        print('action:2', 1.0,
                              prob_candidate_norm[reverse_candidate_idx])

            new_prob = def_sent_scorer(tokenizer.decode(input_ids_tmp))
            new_prob *= penalty_constraint(
                searcher.count_unsafisfied_constraint(
                    searcher.sent2tag(input_ids_tmp)))
            if sim != None:
                sim_constr = similarity(input_ids_tmp, input_original, sta_vec)
                new_prob *= sim_constr
            input_text_tmp = tokenizer.decode(input_ids_tmp)
            all_samples.append([
                input_text_tmp, new_prob,
                searcher.count_unsafisfied_constraint(
                    searcher.sent2tag(input_ids_tmp)),
                bert_scorer.sent_score(input_ids_tmp, log_prob=True),
                gpt2_scorer.sent_score(input_text_tmp, ppl=True)
            ])
            if tokenizer.decode(input_ids_tmp) not in output_p:
                outputs.append(all_samples[-1])
            if outputs != []:
                output_p.append(outputs[-1][0])
            if proposal_prob == 0.0 or old_prob == 0.0:
                alpha_star = 1.0
            else:
                alpha_star = (proposal_prob_reverse *
                              new_prob) / (proposal_prob * old_prob)
            alpha = min(1, alpha_star)
            print(
                tokenizer.decode(input_ids_tmp).encode('utf8',
                                                       errors='ignore'))
            print(alpha, old_prob, proposal_prob, new_prob,
                  proposal_prob_reverse)
            proposal_cnt += 1
            if choose_action([alpha, 1 - alpha]) == 0 and (
                    new_prob > old_prob * config.threshold or just_acc() == 0):
                if tokenizer.decode(input_ids_tmp) != tokenizer.decode(
                        input_ids):
                    accept_cnt += 1
                    print('Accept')
                    all_acc_samples.append(all_samples[-1])
                input_ids = input_ids_tmp
                sequence_length = sequence_length_tmp
                old_prob = new_prob

        # choose output from samples
        for num in range(config.min_length, 0, -1):
            outputss = [x for x in outputs if len(x[0].split()) >= num]
            print(num, outputss)
            if outputss != []:
                break
        if outputss == []:
            outputss.append([tokenizer.decode(input_ids), 0])
        outputss = sorted(outputss, key=lambda x: x[1])[::-1]
        with open(config.use_output_path, 'a') as g:
            g.write(outputss[0][0] + '\t' + str(outputss[0][1]) + '\n')
        all_chosen_samples.append(outputss[0])

        print('Sentence %d, used time %.2f\n' %
              (sen_id, time.time() - start_time))
    print(proposal_cnt, accept_cnt, float(accept_cnt / proposal_cnt))

    print("All samples:")
    all_samples_ = list(zip(*all_samples))
    for metric in all_samples_[1:]:
        print(np.mean(np.array(metric)))

    print("All accepted samples:")
    all_samples_ = list(zip(*all_acc_samples))
    for metric in all_samples_[1:]:
        print(np.mean(np.array(metric)))

    print("All chosen samples:")
    all_samples_ = list(zip(*all_chosen_samples))
    for metric in all_samples_[1:]:
        print(np.mean(np.array(metric)))

    with open(config.use_output_path + '-result.csv', 'w', newline='') as f:
        csv_writer = csv.writer(f, delimiter='\t')
        csv_writer.writerow(
            ['Sentence', 'Prob_sim', 'Constraint_num', 'Log_prob', 'PPL'])
        csv_writer.writerows(all_samples)