# Shared imports for the Dialog variants below. The `data`, `domain`, and
# `metrics` modules are assumed to come from the surrounding negotiation
# codebase (MetricsContainer is assumed to live in `metrics`).
import copy
import logging

import numpy as np
import torch

import data
import domain
from metrics import MetricsContainer


class Dialog(object):
    """Dialogue runner."""
    def __init__(self, agents, args):
        # for now we only support a dialogue between 2 agents
        assert len(agents) == 2
        self.agents = agents
        self.args = args
        self.domain = domain.get_domain(args.domain)
        self.metrics = MetricsContainer()
        self._register_metrics()

    def _register_metrics(self):
        """Registers valuable metrics."""
        self.metrics.register_average('dialog_len')
        self.metrics.register_average('sent_len')
        self.metrics.register_percentage('agree')
        self.metrics.register_average('advantage')
        self.metrics.register_time('time')
        self.metrics.register_average('comb_rew')
        for agent in self.agents:
            self.metrics.register_average('%s_rew' % agent.name)
            self.metrics.register_percentage('%s_sel' % agent.name)
            self.metrics.register_uniqueness('%s_unique' % agent.name)
        # text metrics
        ref_text = ' '.join(data.read_lines(self.args.ref_text))
        self.metrics.register_ngram('full_match', text=ref_text)

    def _is_selection(self, out):
        return len(out) == 1 and out[0] == '<selection>'

    def show_metrics(self):
        return ' '.join(['%s=%s' % (k, v) for k, v in self.metrics.dict().items()])

    def run(self, ctxs, logger):
        """Runs one instance of the dialogue."""
        assert len(self.agents) == len(ctxs)
        # initialize the agents by feeding in the contexts
        for agent, ctx in zip(self.agents, ctxs):
            agent.feed_context(ctx)
            logger.dump_ctx(agent.name, ctx)
        logger.dump('-' * 80)

        # randomly choose who goes first
        if np.random.rand() < 0.5:
            writer, reader = self.agents
        else:
            reader, writer = self.agents

        conv = []
        # reset metrics
        self.metrics.reset()

        while True:
            # produce an utterance
            out = writer.write()

            self.metrics.record('sent_len', len(out))
            self.metrics.record('full_match', out)
            self.metrics.record('%s_unique' % writer.name, out)

            # append the utterance to the conversation
            conv.append(out)
            # make the other agent read it
            reader.read(out)
            if not writer.human:
                logger.dump_sent(writer.name, out)
            # check if the end of the conversation was generated
            if self._is_selection(out):
                self.metrics.record('%s_sel' % writer.name, 1)
                self.metrics.record('%s_sel' % reader.name, 0)
                break
            writer, reader = reader, writer

        choices = []
        # generate choices for each of the agents
        for agent in self.agents:
            choice = agent.choose()
            choices.append(choice)
            logger.dump_choice(agent.name, choice[: self.domain.selection_length() // 2])

        logging.debug(choices)  # final choices from both agents
        # evaluate the choices, produce agreement and a reward
        agree, rewards = self.domain.score_choices(choices, ctxs)
        logger.dump('-' * 80)
        logger.dump_agreement(agree)
        # perform an update in case any of the agents is learnable
        for agent, reward in zip(self.agents, rewards):
            logger.dump_reward(agent.name, agree, reward)
            logging.debug("%s : %s : %s" % (str(agent.name), str(agree), str(rewards)))
            agent.update(agree, reward)

        if agree:
            self.metrics.record('advantage', rewards[0] - rewards[1])
        self.metrics.record('time')
        self.metrics.record('dialog_len', len(conv))
        self.metrics.record('agree', int(agree))
        self.metrics.record('comb_rew', np.sum(rewards) if agree else 0)
        for agent, reward in zip(self.agents, rewards):
            self.metrics.record('%s_rew' % agent.name, reward if agree else 0)

        logger.dump('-' * 80)
        logger.dump(self.show_metrics())
        logger.dump('-' * 80)
        for ctx, choice in zip(ctxs, choices):
            logger.dump('debug: %s %s' % (' '.join(ctx), ' '.join(choice)))

        return conv, agree, rewards
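
The runner above assumes agent objects exposing feed_context/write/read/choose/update
and a logger with dump_* helpers. A minimal self-play driver could look like the
sketch below; LstmAgent, FileLogger and ContextGenerator are hypothetical names
standing in for whatever the surrounding codebase actually provides.

def selfplay_sketch(args, alice_model, bob_model):
    # hypothetical agent/logger/context classes -- adjust to the real codebase
    alice = LstmAgent(alice_model, args, name='Alice')
    bob = LstmAgent(bob_model, args, name='Bob')
    dialog = Dialog([alice, bob], args)
    logger = FileLogger(args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    for ctxs in ctx_gen.iter(nepoch=1):
        # each ctxs holds one context per agent, e.g. item counts and values
        conv, agree, rewards = dialog.run(ctxs, logger)
    logger.dump(dialog.show_metrics())
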
Example 2
class Dialog(object):
    def __init__(self, agents, args):
        # For now we only support a dialogue between 2 agents
        assert len(agents) == 2
        self.agents = agents
        self.args = args
        self.domain = domain.get_domain(args.domain)
        self.metrics = MetricsContainer()
        self._register_metrics()

    def _register_metrics(self):
        self.metrics.register_average('dialog_len')
        self.metrics.register_average('sent_len')
        self.metrics.register_percentage('agree')
        self.metrics.register_moving_percentage('moving_agree')
        self.metrics.register_average('advantage')
        self.metrics.register_moving_average('moving_advantage')
        self.metrics.register_time('time')
        self.metrics.register_average('comb_rew')
        self.metrics.register_average('agree_comb_rew')
        for agent in self.agents:
            self.metrics.register_average('%s_rew' % agent.name)
            self.metrics.register_moving_average('%s_moving_rew' % agent.name)
            self.metrics.register_average('agree_%s_rew' % agent.name)
            self.metrics.register_percentage('%s_sel' % agent.name)
            self.metrics.register_uniqueness('%s_unique' % agent.name)
        # text metrics
        if self.args.ref_text:
            ref_text = ' '.join(data.read_lines(self.args.ref_text))
            self.metrics.register_ngram('full_match', text=ref_text)

    def _is_selection(self, out):
        """if dialog end"""
        return len(out) == 1 and (out[0] in ['<selection>', '<no_agreement>'])

    def show_metrics(self):
        return ' '.join(
            ['%s=%s' % (k, v) for k, v in self.metrics.dict().items()])

    def run(self, ctxs, logger, max_words=5000):
        assert len(self.agents) == len(ctxs)
        for agent, ctx, partner_ctx in zip(self.agents, ctxs, reversed(ctxs)):
            agent.feed_context(ctx)
            agent.feed_partner_context(partner_ctx)
            logger.dump_ctx(agent.name, ctx)
        logger.dump('-' * 80)

        # Randomly choose who goes first
        if np.random.rand() < 0.5:
            writer, reader = self.agents
        else:
            reader, writer = self.agents

        conv = []
        self.metrics.reset()

        #words_left = np.random.randint(50, 200)
        words_left = max_words  # max 5000 words
        length = 0
        expired = False

        turn_num = 0
        while True:
            # print('dialog turn [{}]'.format(turn_num))
            turn_num += 1
            # print('\twrite')
            out = writer.write(max_words=20)  # was: max_words=words_left
            # print('\twrite done')
            words_left -= len(out)
            length += len(out)

            self.metrics.record('sent_len', len(out))
            if 'full_match' in self.metrics.metrics:
                self.metrics.record('full_match', out)
            self.metrics.record('%s_unique' % writer.name, out)

            conv.append(out)
            # print('\tread')
            reader.read(out)
            if not writer.human:
                logger.dump_sent(writer.name, out)

            if self._is_selection(out):
                self.metrics.record('%s_sel' % writer.name, 1)
                self.metrics.record('%s_sel' % reader.name, 0)
                break

            if words_left <= 1:
                break

            writer, reader = reader, writer
        # print('turn_num:{}'.format(turn_num))

        choices = []
        for agent in self.agents:
            choice = agent.choose()
            choices.append(choice)
            logger.dump_choice(agent.name,
                               choice[:self.domain.selection_length() // 2])

        agree, rewards = self.domain.score_choices(choices, ctxs)
        if expired:
            agree = False
        logger.dump('-' * 80)
        logger.dump_agreement(agree)
        for i, (agent, reward) in enumerate(zip(self.agents, rewards)):
            logger.dump_reward(agent.name, agree, reward)
            j = 1 if i == 0 else 0
            agent.update(agree,
                         reward,
                         choice=choices[i],
                         partner_choice=choices[j],
                         partner_input=ctxs[j],
                         partner_reward=rewards[j])

        if agree:
            self.metrics.record('advantage', rewards[0] - rewards[1])
            self.metrics.record('moving_advantage', rewards[0] - rewards[1])
            self.metrics.record('agree_comb_rew', np.sum(rewards))
            for agent, reward in zip(self.agents, rewards):
                self.metrics.record('agree_%s_rew' % agent.name, reward)

        self.metrics.record('time')
        self.metrics.record('dialog_len', len(conv))
        self.metrics.record('agree', int(agree))
        self.metrics.record('moving_agree', int(agree))
        self.metrics.record('comb_rew', np.sum(rewards) if agree else 0)
        for agent, reward in zip(self.agents, rewards):
            self.metrics.record('%s_rew' % agent.name, reward if agree else 0)
            self.metrics.record('%s_moving_rew' % agent.name,
                                reward if agree else 0)

        logger.dump('-' * 80)
        logger.dump(self.show_metrics())
        logger.dump('-' * 80)
        for ctx, choice in zip(ctxs, choices):
            logger.dump('debug: %s %s' % (' '.join(ctx), ' '.join(choice)))

        return conv, agree, rewards
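
This variant additionally tracks moving statistics (moving_agree, moving_advantage,
the *_moving_rew metrics). The MetricsContainer internals are not shown in these
examples; the sketch below shows one plausible exponential-moving-average metric
of the kind register_moving_average suggests -- an assumption, not the actual
implementation.

class MovingAverageMetricSketch(object):
    """Possible shape of a moving-average metric (assumed, not from the codebase)."""
    def __init__(self, alpha=0.1):
        self.alpha = alpha
        self.value = None

    def record(self, x):
        # exponential moving average: recent dialogues weigh more than old ones
        self.value = x if self.value is None else \
            (1 - self.alpha) * self.value + self.alpha * x

    def show(self):
        return '%.2f' % (self.value if self.value is not None else 0.0)
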
Example 3
class Dialog(object):
    def __init__(self, agents, args):
        # For now we only support a dialogue between 2 agents
        assert len(agents) == 2
        self.agents = agents
        self.args = args
        self.domain = domain.get_domain(args.domain)
        self.metrics = MetricsContainer()
        self._register_metrics()

    def _register_metrics(self):
        self.metrics.register_average('dialog_len')
        self.metrics.register_average('sent_len')
        self.metrics.register_percentage('agree')
        self.metrics.register_average('advantage')
        self.metrics.register_time('time')
        self.metrics.register_average('comb_rew')
        for agent in self.agents:
            self.metrics.register_average('%s_rew' % agent.name)
            self.metrics.register_percentage('%s_sel' % agent.name)
            self.metrics.register_uniqueness('%s_unique' % agent.name)
        # text metrics
        ref_text = ' '.join(data.read_lines(self.args.ref_text))
        self.metrics.register_ngram('full_match', text=ref_text)

    def _is_selection(self, out):
        return len(out) == 1 and out[0] == '<selection>'

    def show_metrics(self):
        return ' '.join(['%s=%s' % (k, v) for k, v in self.metrics.dict().items()])

    def run(self, ctxs, logger):
        assert len(self.agents) == len(ctxs)
        for agent, ctx in zip(self.agents, ctxs):
            agent.feed_context(ctx)
            logger.dump_ctx(agent.name, ctx)
        logger.dump('-' * 80)

        # Randomly choose who goes first
        if np.random.rand() < 0.5:
            writer, reader = self.agents
        else:
            reader, writer = self.agents

        conv = []
        self.metrics.reset()

        while True:
            out = writer.write()

            self.metrics.record('sent_len', len(out))
            self.metrics.record('full_match', out)
            self.metrics.record('%s_unique' % writer.name, out)

            conv.append(out)
            reader.read(out)
            if not writer.human:
                logger.dump_sent(writer.name, out)
            if self._is_selection(out):
                self.metrics.record('%s_sel' % writer.name, 1)
                self.metrics.record('%s_sel' % reader.name, 0)
                break
            writer, reader = reader, writer


        choices = []
        for agent in self.agents:
            choice = agent.choose()
            choices.append(choice)
            logger.dump_choice(agent.name, choice[: self.domain.selection_length() // 2])

        agree, rewards = self.domain.score_choices(choices, ctxs)
        logger.dump('-' * 80)
        logger.dump_agreement(agree)
        for agent, reward in zip(self.agents, rewards):
            logger.dump_reward(agent.name, agree, reward)
            agent.update(agree, reward)

        if agree:
            self.metrics.record('advantage', rewards[0] - rewards[1])
        self.metrics.record('time')
        self.metrics.record('dialog_len', len(conv))
        self.metrics.record('agree', int(agree))
        self.metrics.record('comb_rew', np.sum(rewards) if agree else 0)
        for agent, reward in zip(self.agents, rewards):
            self.metrics.record('%s_rew' % agent.name, reward if agree else 0)

        logger.dump('-' * 80)
        logger.dump(self.show_metrics())
        logger.dump('-' * 80)
        for ctx, choice in zip(ctxs, choices):
            logger.dump('debug: %s %s' % (' '.join(ctx), ' '.join(choice)))

        return conv, agree, rewards
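
All of these runners delegate agreement checking and rewards to domain.score_choices.
The domain module itself is not part of this listing; for an object-division style
negotiation the contract is roughly the sketch below, which assumes contexts of the
form "cnt1 val1 cnt2 val2 cnt3 val3" and well-formed choices like ['item0=1', ...].

def score_choices_sketch(choices, ctxs, n_items=3):
    # hedged sketch of the score_choices contract, not the real domain code
    cnts = [int(x) for x in ctxs[0][0::2]]                # item counts (shared)
    vals = [[int(x) for x in ctx[1::2]] for ctx in ctxs]  # per-agent item values
    picks = [[int(c.split('=')[1]) for c in choice[:n_items]] for choice in choices]
    # the agents agree iff their selections split every item count exactly
    agree = all(picks[0][i] + picks[1][i] == cnts[i] for i in range(n_items))
    rewards = [sum(p * v for p, v in zip(picks[a], vals[a])) for a in range(2)]
    return agree, rewards
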
class Dialog(object):
    """Dialogue runner."""
    def __init__(self, agents, args):
        # for now we only support a dialogue between 2 agents
        assert len(agents) == 2
        self.agents = agents
        self.args = args
        self.domain = domain.get_domain(args.domain)
        self.metrics = MetricsContainer()
        self._register_metrics()

    def _register_metrics(self):
        """Registers valuable metrics."""
        self.metrics.register_average('dialog_len')
        self.metrics.register_average('sent_len')
        self.metrics.register_percentage('agree')
        self.metrics.register_average('advantage')
        self.metrics.register_time('time')
        self.metrics.register_average('comb_rew')
        for agent in self.agents:
            self.metrics.register_average('%s_rew' % agent.name)
            self.metrics.register_percentage('%s_sel' % agent.name)
            self.metrics.register_uniqueness('%s_unique' % agent.name)
        # text metrics
        ref_text = ' '.join(data.read_lines(self.args.ref_text))
        self.metrics.register_ngram('full_match', text=ref_text)

    def _is_selection(self, out):
        return len(out) == 1 and out[0] == '<selection>'

    def show_metrics(self):
        return ' '.join(['%s=%s' % (k, v) for k, v in self.metrics.dict().items()])

    def get_loss(self, inpt, words, lang_hs, lang_h, c=1, bob_ends=True):
        """Replays the candidate input through Bob, returns the attack loss,
        and restores Bob's language state afterwards."""
        bob = self.agents[1]
        inpt_emb = bob.model.get_embedding(inpt,bob.lang_h,bob.ctx_h)
        bob.read_emb(inpt_emb, inpt)
        if bob_ends:
            loss1, bob_out, _ = bob.write_selection(wb_attack=True)
            bob.words = words.copy()
            bob_choice, classify_loss, _ = bob.choose(inpt_emb=inpt_emb,wb_attack=True)
            t_loss = c*loss1 + classify_loss
        else:
            bob_out = bob.write(bob_ends)
            out = self.agents[0].write_selection()
            bob.read(out)
            bob.words = words.copy()
            bob_choice, classify_loss, _ = bob.choose(inpt_emb=inpt_emb,bob_ends=bob_ends, bob_out=bob_out, wb_attack=True)
            t_loss = classify_loss                      
        #t_loss.backward(retain_graph=True)
        bob.lang_hs = lang_hs.copy()
        bob.lang_h = lang_h.clone()
        if bob_ends:
            return t_loss.item(), loss1, classify_loss
        else:
            return t_loss.item()


    def run(self, ctxs, logger):
        """Runs one instance of the dialogue."""
        assert len(self.agents) == len(ctxs)
        # initialize the agents by feeding in the contexts
        #for agent, ctx in zip(self.agents, ctxs):
        #    agent.feed_context(ctx)
        #   logger.dump_ctx(agent.name, ctx)
        self.agents[0].feed_context(ctxs[0])
        logger.dump_ctx(self.agents[0].name, ctxs[0])
        self.agents[1].feed_context(ctxs[1],ctxs[0])
        logger.dump_ctx(self.agents[1].name, ctxs[1])

        logger.dump('-' * 80)

        # randomly choose who goes first
        if np.random.rand() < 0.5:
            writer, reader = self.agents
        else:
            reader, writer = self.agents

        # for the attack, Alice (agents[0]) is forced to write first, overriding the draw
        writer, reader = self.agents

        conv = []
        # reset metrics
        self.metrics.reset()

         #### Minhao ####
        count_turns = 0       

        while True:
            # produce an utterance
            if count_turns > self.args.max_turns-1:
                if writer == self.agents[0]:
                    inpt_emb, inpt, lang_hs, lang_h, words = writer.write_white(reader)
                    #print(writer.words[-1][0].grad)
                    ### need padding in the input_emb
                    break
                #else:

            else:
                out = writer.write()

            self.metrics.record('sent_len', len(out))
            self.metrics.record('full_match', out)
            self.metrics.record('%s_unique' % writer.name, out)

            # append the utterance to the conversation
            conv.append(out)
            # make the other agent read it
            reader.read(out)
            if not writer.human:
                logger.dump_sent(writer.name, out)
            # check if the end of the conversation was generated
            if self._is_selection(out):
                self.metrics.record('%s_sel' % writer.name, 1)
                self.metrics.record('%s_sel' % reader.name, 0)
                break
            writer, reader = reader, writer
            count_turns += 1
            ##### add selection mark if exceeding the max_turns

        ### Minhao: need to design a loss focusing on the choices
        ### No evaluation during the conversation????

        bob = self.agents[1]

        choices = []
        #class_losses = []
        # generate choices for each of the agents
        #flag_loss=False
        
        ####
        # get the final loss and do the back-propagation
        c=1
        step_size = 5e-2   
        #lang_hs[0].retain_grad()
        #print(words)
        all_index_n = len(self.agents[0].model.word_dict)
        all_index = torch.LongTensor(range(all_index_n)).cuda()
        all_word_emb = self.agents[0].model.word_encoder(all_index)
        threshold = 10
        #print(all_word_emb.size())
        print(inpt_emb.size(),inpt)
        
        bob_ends = False

        #fixed_lang_h = bob.lang_h.copy()
        fixed_ctx_h = bob.ctx_h.clone() 

        if inpt_emb.size()[0]>3:
    
            iterations = 1000
            #mask= [0] * (inpt_emb.size()[0]-1)
            Flag=False
            for iter_idx in range(iterations):
                #print(inpt,len(bob.lang_hs),bob.lang_h.size())
                #print(len(bob.lang_hs))
                if (iter_idx+1)%1==0 and Flag:                                
                    inpt_emb = bob.model.get_embedding(inpt,lang_h,fixed_ctx_h)
                    #changed = False
                inpt_emb.retain_grad()
                #bob.lang_hs[-1].retain_grad()
                bob.read_emb(inpt_emb, inpt)
                #print(len(bob.lang_hs))

                if bob_ends:
                    loss1, bob_out, _ = bob.write_selection(wb_attack=True)
                else:
                    bob_out = bob.write(bob_ends)
                    #bob_out = bob._encode(bob_out, bob.model.word_dict)
        # then append the utterance
                    #
                #print(len(bob.lang_hs))
                #print(len(lang_hs))
                #print(len(lang_hs))
                if not bob_ends:
                    out = self.agents[0].write_selection()
                    bob.read(out) #????

                #print(len(bob.lang_hs))
                #if bob_ends:
                bob.words = words.copy()

                if bob_ends:
                    bob_choice, classify_loss, _ = bob.choose(inpt_emb=inpt_emb,wb_attack=True)
                else:
                    bob_choice, classify_loss, _ = bob.choose(inpt_emb=inpt_emb,bob_ends=bob_ends, bob_out=bob_out, wb_attack=True)
                #print(len(bob.lang_hs))
                if bob_ends:
                    t_loss = c*loss1 + classify_loss
                else:
                    t_loss = classify_loss

                if (iter_idx+1)%1==0:
                    if bob_ends:
                        print(t_loss.item(), loss1.item(), classify_loss.item())
                        if loss1==0.0 and classify_loss<=0.0:
                            print("get legimate adversarial example")
                            print(self.agents[0]._decode(inpt,bob.model.word_dict))      ### bug still?????
                            break
                    else:
                        print(t_loss.item())
                        if t_loss.item()<=-3.0:
                            print("get legimate adversarial example")
                            print(self.agents[0]._decode(inpt,bob.model.word_dict))      ### bug still?????
                            break
                #t_loss = loss1
                #bob.lang_hs[2].retain_grad()
                #logits.retain_grad()
                #t_loss = classify_loss
                t_loss.backward(retain_graph=True)
                #print(len(bob.lang_hs))
                
                #print(logits.grad)
                #print(t_loss.item(),loss1.item(),classify_loss.item())
            #print(inpt_emb.size())
                #print(inpt_emb.grad.size())
                inpt_emb.grad[:,:,256:] = 0
                inpt_emb.grad[0,:,:] = 0
                #print(inpt_emb.grad[2])
                #inpt_emb.grad[0][:][:]=0
                inpt_emb = inpt_emb - step_size * inpt_emb.grad
                bob.lang_hs = lang_hs.copy()
                bob.lang_h = lang_h.clone()
                bob.ctx_h = fixed_ctx_h.clone()
                # projection
                min_inpt = None
                temp_inpt = inpt.clone()
                if iter_idx%1==0:
                    for emb_idx in range(1,inpt_emb.size()[0]-1):
                    #for emb_idx in range(1,4):
                        rep_candidate = []
                        dis_a=[] 
                        for r_idx in range(1,all_index_n): # excluding <eos>
                            if r_idx==inpt[emb_idx-1].item():
                                rep_candidate.append(r_idx)
                                continue
                            dis=torch.norm(inpt_emb[emb_idx][:,:256]-all_word_emb[r_idx]).item()
                            if dis< threshold:
                                rep_candidate.append(r_idx)
                                if not dis_a:
                                    continue
                                elif dis<min(dis_a):
                                    min_idx = r_idx

                            dis_a.append(dis)
                        #print(np.argmin(dis_a),min(dis_a))
                        if rep_candidate:
                            #mask[emb_idx-1]=1
                            #temp= random.choice(rep_candidate)
                            min_loss = t_loss.item()
                            for candi in rep_candidate:
                                temp_inpt[emb_idx-1]=candi
                                if bob_ends:
                                    loss,_,_ = self.get_loss(temp_inpt, words, lang_hs, lang_h)
                                else:
                                    loss = self.get_loss(temp_inpt, words, lang_hs, lang_h, bob_ends=bob_ends)
                                if loss<min_loss:
                                    min_loss = loss
                                    min_inpt = temp_inpt.clone()
                        else:
                            continue
                    if min_inpt is not None:
                        inpt = min_inpt.clone()
                        print(inpt)
                        Flag=True
                    else:
                        Flag=False
                        #break
                #print(rep_candidate)
            #print(t_loss,lang_hs[0].grad)
            print("attack finished")
        else:
            if bob_ends:
                bob.read_emb(inpt_emb, inpt)
                _, bob_out, _ = bob.write_selection(wb_attack=True)
                bob.words = words.copy()
                bob_choice, _, _ = bob.choose(inpt_emb=inpt_emb,wb_attack=True)
            else:
                bob.read_emb(inpt_emb, inpt)
                bob_out = bob.write(bob_ends)
                out = self.agents[0].write_selection()
                bob.read(out)
                bob.words = words.copy()
                bob_choice, _, _ = bob.choose(inpt_emb=inpt_emb,bob_ends=bob_ends, bob_out=bob_out, wb_attack=True)


        if bob_ends:
            logger.dump_sent(self.agents[0].name,self.agents[0]._decode(inpt,bob.model.word_dict))
            logger.dump_sent(bob.name,['<selection>'])
        else:
            logger.dump_sent(self.agents[0].name, self.agents[0]._decode(inpt,self.agents[0].model.word_dict))
            logger.dump_sent(bob.name, bob._decode(bob_out, bob.model.word_dict))
            logger.dump_sent(self.agents[0].name, ['<selection>'])
        #####
        choices.append(bob_choice)
        alice_choice = bob_choice[:]
        for indx in range(3):
           alice_choice[indx+3], alice_choice[indx] = alice_choice[indx], alice_choice[indx+3]
        choices.append(alice_choice) ######## always agree
        choices[1], choices[0] = choices[0], choices[1]
        #print(choices)
        #for agent in self.agents:
        #    choice, class_loss = agent.choose(flag=flag_loss)
        #    class_losses.append(class_loss)
        #    choices.append(choice)
        #    logger.dump_choice(agent.name, choice[: self.domain.selection_length() // 2])
        #    flag_loss=True

        # evaluate the choices, produce agreement and a reward
        #print(choices,ctxs)
        agree, rewards = self.domain.score_choices(choices, ctxs)
        logger.dump('-' * 80)
        logger.dump_agreement(agree)
        #print(rewards)
        # perform an update in case any of the agents is learnable
        # let the difference become the new reward
        ## how to combine the loss into the reward

        '''
        diff = rewards[0]-rewards[1] 
        flag = True
        agree = 1
        #print(5 - classify_loss.item())
        for agent, reward in zip(self.agents, rewards):           
            if flag:
                logger.dump_reward(agent.name, agree, reward)
                #agent.update(agree, 50-class_losses[1].item())
                agent.update(agree, 5-classify_loss.item())
                #agent.update(agree, diff - 0.05 * class_losses[1].data[0])
                #agent.update(agree, diff)
            else:
                logger.dump_reward(agent.name, agree, reward)
                if not self.args.fixed_bob:
                    agent.update(agree, reward)
            flag=False
        '''
        agree = 1
        for agent, reward in zip(self.agents, rewards):
            logger.dump_reward(agent.name, agree, reward)
            logging.debug("%s : %s : %s" % (str(agent.name), str(agree), str(rewards)))
            #agent.update(agree, 5-classify_loss.item())


        if agree:
            self.metrics.record('advantage', rewards[0] - rewards[1])
        self.metrics.record('time')
        self.metrics.record('dialog_len', len(conv))
        self.metrics.record('agree', int(agree))
        self.metrics.record('comb_rew', np.sum(rewards) if agree else 0)
        for agent, reward in zip(self.agents, rewards):
            self.metrics.record('%s_rew' % agent.name, reward if agree else 0)

        logger.dump('-' * 80)
        logger.dump(self.show_metrics())
        logger.dump('-' * 80)
        for ctx, choice in zip(ctxs, choices):
            logger.dump('debug: %s %s' % (' '.join(ctx), ' '.join(choice)))

        return conv, agree, rewards
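
The attack in the variant above perturbs Alice's input embeddings by gradient descent
(inpt_emb - step_size * inpt_emb.grad) and then snaps each perturbed position back onto
a nearby vocabulary word. Stripped of the surrounding bookkeeping, that projection step
is a nearest-neighbour lookup in the word-embedding table, sketched below under the
assumption that only the first 256 dimensions carry the word embedding (as the slicing
in the loop above suggests).

def project_to_vocab_sketch(inpt_emb, all_word_emb, threshold=10.0):
    # snap each perturbed embedding to the closest vocabulary word within `threshold`;
    # the real loop also re-scores every candidate with get_loss before accepting it
    projected = []
    for pos in range(inpt_emb.size(0)):
        emb = inpt_emb[pos, 0, :256]
        dists = torch.norm(all_word_emb - emb.unsqueeze(0), dim=1)
        min_dist, min_idx = torch.min(dists, dim=0)
        projected.append(min_idx.item() if min_dist.item() < threshold else None)
    return projected  # None marks positions left unchanged
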
class Dialog(object):
    """Dialogue runner."""
    def __init__(self, agents, args):
        # for now we only support a dialogue between 2 agents
        assert len(agents) == 2
        self.agents = agents
        self.args = args
        self.domain = domain.get_domain(args.domain)
        self.metrics = MetricsContainer()
        self._register_metrics()

    def _register_metrics(self):
        """Registers valuable metrics."""
        self.metrics.register_average('dialog_len')
        self.metrics.register_average('sent_len')
        self.metrics.register_percentage('agree')
        self.metrics.register_average('advantage')
        self.metrics.register_time('time')
        self.metrics.register_average('comb_rew')
        for agent in self.agents:
            self.metrics.register_average('%s_rew' % agent.name)
            self.metrics.register_percentage('%s_sel' % agent.name)
            self.metrics.register_uniqueness('%s_unique' % agent.name)
        # text metrics
        ref_text = ' '.join(data.read_lines(self.args.ref_text))
        self.metrics.register_ngram('full_match', text=ref_text)

    def _is_selection(self, out):
        return len(out) == 1 and out[0] == '<selection>'

    def show_metrics(self):
        return ' '.join(['%s=%s' % (k, v) for k, v in self.metrics.dict().items()])

    def get_loss(self, inpt, words, lang_hs, lang_h, c=1, bob_ends=True, bob_out=None):
        """Scores a candidate utterance by replaying it through Bob and returning
        the attack loss; Bob's state is restored afterwards."""
        bob = self.agents[1]
        bob.read(inpt, f_encode=False)
        if bob_ends:
            loss1, bob_out, _ = bob.write_selection(wb_attack=True)
            bob_choice, classify_loss, _ = bob.choose()
            t_loss = c*loss1 + classify_loss
        else:
            #if bob_out is None:
            bob_out = bob.write(bob_ends)
            _,out,_ = self.agents[0].write_selection(wb_attack=True, alice=True)
            bob.read(out)
            bob_choice, classify_loss, _ = bob.choose()
            t_loss = classify_loss                      
        #t_loss.backward(retain_graph=True)
        bob.words = copy.copy(words)
        bob.lang_hs = copy.copy(lang_hs)
        bob.lang_h = lang_h.clone()
        if bob_ends:
            return t_loss.item(), loss1.item(), classify_loss.item(), bob_out, bob_choice
        else:
            return t_loss.item(), bob_out, bob_choice
            #return t_loss, None, bob_choice

    def attack(self, inpt, lang_hs, lang_h, words, bob_ends):
        """Greedy word-substitution search over Alice's utterance that tries to
        minimize the attack loss returned by get_loss."""
        bob = self.agents[1]


        #class_losses = []
        # generate choices for each of the agents
        #flag_loss=False
        
        c=1
        #print(words)
        all_index_n = len(self.agents[0].model.word_dict)
    
        #print(inpt)
        
        #fixed_lang_h = bob.lang_h.copy()
        fixed_ctx_h = bob.ctx_h.clone() 

        if True:
            iterations = 3
            #mask= [0] * (inpt_emb.size()[0]-1)
            for iter_idx in range(iterations):
                # projection
                min_inpt = None
                #temp_inpt = inpt.clone()
                min_loss_a = []
                min_inpt_a = []

                if bob_ends:
                    t_loss,loss1,classify_loss, bob_out, bob_choice = self.get_loss(inpt, words, lang_hs, lang_h)
                else:
                    #bob_out = bob.write(bob_ends)
                    t_loss, bob_out, bob_choice = self.get_loss(inpt, words, lang_hs, lang_h, bob_ends=bob_ends, bob_out=None)
                if bob_ends:
                    print(iter_idx,t_loss, loss1, classify_loss)
                else:
                    print(iter_idx,t_loss)
                if bob_ends:
                    if loss1==0.0 and t_loss<=-5.0:
                        print("get legimate adversarial example")
                        print(self.agents[0]._decode(inpt,bob.model.word_dict))      ### bug still?????
                        print("bob attack finished")
                    
                        break
                else:
                    if t_loss<=-3.0:
                        print("get legimate adversarial example")
                        print(self.agents[0]._decode(inpt,bob.model.word_dict))      ### bug still?????
                        print("alice attack finished")
                        break                    
                for emb_idx in range(1,inpt.size()[0]-1):                   
                    min_loss = t_loss
                    for candi in range(1,all_index_n):
                        temp_inpt = inpt.clone()
                        temp_inpt[emb_idx]=candi
                        if bob_ends:
                            loss,_,_,_,_= self.get_loss(temp_inpt, words, lang_hs, lang_h)
                        else:
                            #bob_out = bob.write(bob_ends)
                            loss,bob_out,_ = self.get_loss(temp_inpt, words, lang_hs, lang_h, bob_ends=bob_ends, bob_out=None)
                            if loss<0:
                                sum_loss=0
                                for _ in range(10):
                                    loss,_,_ = self.get_loss(temp_inpt, words, lang_hs, lang_h, bob_ends=bob_ends, bob_out=None)
                                    sum_loss += loss
                                    #print(loss)
                                loss = sum_loss/10
                        #if loss<0:
                        #    print("first loss",loss, "bob_choice", bob_choice, "bob_out", bob_out)
                            #print(temp_inpt,bob.words,bob.lang_hs,bob.lang_h.size())
                        #    print("sec loss",self.get_loss(temp_inpt, words, lang_hs, lang_h, bob_ends=bob_ends,bob_out=bob_out))
                            #print(temp_inpt,bob.words,bob.lang_hs,bob.lang_h.size())
                        #    print("third loss",self.get_loss(temp_inpt, words, lang_hs, lang_h, bob_ends=bob_ends,bob_out=bob_out))
                            #print(temp_inpt,bob.words,bob.lang_hs,bob.lang_h.size())
                        if loss<min_loss:
                            min_loss = loss
                            min_inpt = temp_inpt.clone()
                            #print(min_loss)
                            
    
    
                    min_loss_a.append(min_loss)
                    min_inpt_a.append(min_inpt)

                if len(min_loss_a)!=0:
                    min_idx_in_a = np.argmin(min_loss_a)
                    if min_inpt_a[min_idx_in_a] is not None:
                        inpt = min_inpt_a[min_idx_in_a].clone()
                    else:
                        print(min_inpt_a)
                    #print(min_inpt_a)
                    #print(min_loss_a)
                    #print(inpt)
                    #print(loss)


            
        #else:

            """
            if bob_ends:
                bob.read_emb(inpt_emb, inpt)
                _, bob_out, _ = bob.write_selection(wb_attack=True)
                bob.words = words.copy()
                bob_choice, _, _ = bob.choose(inpt_emb=inpt_emb,wb_attack=True)
            else:
                bob.read_emb(inpt_emb, inpt)
                bob_out = bob.write(bob_ends)
                out = self.agents[0].write_selection()
                bob.read(out)
                bob.words = words.copy()
                bob_choice, _, _ = bob.choose(inpt_emb=inpt_emb,bob_ends=bob_ends, bob_out=bob_out, wb_attack=True)
            """
        return bob_choice, bob_out, t_loss, inpt


    def run(self, ctxs, logger):
        """Runs one instance of the dialogue."""
        assert len(self.agents) == len(ctxs)
        # initialize the agents by feeding in the contexts
        #for agent, ctx in zip(self.agents, ctxs):
        #    agent.feed_context(ctx)
        #   logger.dump_ctx(agent.name, ctx)
        self.agents[0].feed_context(ctxs[0])
        logger.dump_ctx(self.agents[0].name, ctxs[0])
        self.agents[1].feed_context(ctxs[1],ctxs[0])
        logger.dump_ctx(self.agents[1].name, ctxs[1])

        logger.dump('-' * 80)

        # randomly choose who goes first
        if np.random.rand() < 0.5:
            writer, reader = self.agents
        else:
            reader, writer = self.agents

        #reader, writer = self.agents

        conv = []
        # reset metrics
        self.metrics.reset()

         #### Minhao ####
        count_turns = 0       

        #bob_ends = False
        with torch.no_grad():
            while True:
                # produce an utterance
                bob_out= None
                if count_turns > self.args.max_turns:
                    print("Failed")
                    out = writer.write_selection()
                    logger.dump_sent(writer.name, out)
                    break
                if writer == self.agents[0]:
                    inpt, lang_hs, lang_h, words = writer.write_white(reader)
                    if inpt.size()[0]>3:
                        print("try to let bob select")
                        bob_ends = True
                        bob_choice, bob_out, loss, inpt = self.attack(inpt, lang_hs, lang_h, words, bob_ends)
                        #continue
                        if loss<=-5.0 and self._is_selection(bob_out):
                            break
                        else:
                            print("try to let alice select")
                            bob_ends=False
                            inpt, lang_hs, lang_h, words = writer.write_white(reader)
                            bob_choice, bob_out, loss, inpt = self.attack(inpt, lang_hs, lang_h, words, bob_ends)
                            if loss<=-2.0:
                                break
                            else:
                                print("enlong the dialogue")
                                out = writer.write()
                                #if count_turns>3:
                                #    print("using RL sentence")
                                #    out = writer.write_rl()
                                #print(out)
                    else:
                        out = writer.write()
                else:
                    out = writer.write()

                self.metrics.record('sent_len', len(out))
                self.metrics.record('full_match', out)
                self.metrics.record('%s_unique' % writer.name, out)

                # append the utterance to the conversation
                conv.append(out)
                # make the other agent read it
                reader.read(out)
                if not writer.human:
                    logger.dump_sent(writer.name, out)
                # check if the end of the conversation was generated
                print(out)
                if self._is_selection(out):
                    self.metrics.record('%s_sel' % writer.name, 1)
                    self.metrics.record('%s_sel' % reader.name, 0)
                    break
                writer, reader = reader, writer
                count_turns += 1
            ##### add selection mark if exceeding the max_turns

        ### Minhao: need to design a loss focusing on the choices
        ### No evaluation during the conversation????
        #bob_ends = False
        #bob_choice, bob_out = self.attack(inpt, lang_hs, lang_h, words, bob_ends)
        bob = self.agents[1]

        if bob_out is not None:
            if bob_ends:
                logger.dump_sent(self.agents[0].name,self.agents[0]._decode(inpt,self.agents[0].model.word_dict))
                logger.dump_sent(bob.name, bob_out)
            else:
                logger.dump_sent(self.agents[0].name, self.agents[0]._decode(inpt,self.agents[0].model.word_dict))
                logger.dump_sent(bob.name, bob._decode(bob_out, bob.model.word_dict))
                logger.dump_sent(self.agents[0].name, ['<selection>'])
        else:
            bob_choice, _, _ = bob.choose()
        #####
        choices = []
        choices.append(bob_choice)
        #print(choices)
        alice_choice = bob_choice[:]
        for indx in range(3):
           alice_choice[indx+3], alice_choice[indx] = alice_choice[indx], alice_choice[indx+3]
        choices.append(alice_choice) ######## always agree
        choices[1], choices[0] = choices[0], choices[1]
        #print(choices)
        #for agent in self.agents:
        #    choice, class_loss = agent.choose(flag=flag_loss)
        #    class_losses.append(class_loss)
        #    choices.append(choice)
        #    logger.dump_choice(agent.name, choice[: self.domain.selection_length() // 2])
        #    flag_loss=True

        # evaluate the choices, produce agreement and a reward
        #print(choices,ctxs)
        agree, rewards = self.domain.score_choices(choices, ctxs)
        logger.dump('-' * 80)
        logger.dump_agreement(agree)
        #print(rewards)
        # perform an update in case any of the agents is learnable
        # let the difference become the new reward
        ## how to combine the loss into the reward

        '''
        diff = rewards[0]-rewards[1] 
        flag = True
        agree = 1
        #print(5 - classify_loss.item())
        for agent, reward in zip(self.agents, rewards):           
            if flag:
                logger.dump_reward(agent.name, agree, reward)
                #agent.update(agree, 50-class_losses[1].item())
                agent.update(agree, 5-classify_loss.item())
                #agent.update(agree, diff - 0.05 * class_losses[1].data[0])
                #agent.update(agree, diff)
            else:
                logger.dump_reward(agent.name, agree, reward)
                if not self.args.fixed_bob:
                    agent.update(agree, reward)
            flag=False
        '''
        agree = 1
        for agent, reward in zip(self.agents, rewards):
            logger.dump_reward(agent.name, agree, reward)
            logging.debug("%s : %s : %s" % (str(agent.name), str(agree), str(rewards)))
            #agent.update(agree, 5-classify_loss.item())


        if agree:
            self.metrics.record('advantage', rewards[0] - rewards[1])
        self.metrics.record('time')
        self.metrics.record('dialog_len', len(conv))
        self.metrics.record('agree', int(agree))
        self.metrics.record('comb_rew', np.sum(rewards) if agree else 0)
        for agent, reward in zip(self.agents, rewards):
            self.metrics.record('%s_rew' % agent.name, reward if agree else 0)

        logger.dump('-' * 80)
        logger.dump(self.show_metrics())
        logger.dump('-' * 80)
        for ctx, choice in zip(ctxs, choices):
            logger.dump('debug: %s %s' % (' '.join(ctx), ' '.join(choice)))

        return conv, agree, rewards
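
Unlike the previous variant, this attack never edits embeddings directly: attack()
runs a greedy coordinate search over the vocabulary, swapping one token of Alice's
utterance at a time and keeping whichever substitution lowers get_loss the most. The
core of that search, with the dialogue state restoration omitted, is roughly the
sketch below (score() stands in for get_loss).

def greedy_substitution_sketch(inpt, vocab_size, score, iterations=3):
    # greedy token substitution: try every word at every interior position,
    # keep the single best swap per iteration, stop when nothing improves
    best_loss = score(inpt)
    for _ in range(iterations):
        best_inpt = None
        for pos in range(1, inpt.size(0) - 1):      # keep boundary tokens fixed
            for cand in range(1, vocab_size):       # index 0 assumed reserved
                trial = inpt.clone()
                trial[pos] = cand
                loss = score(trial)
                if loss < best_loss:
                    best_loss, best_inpt = loss, trial.clone()
        if best_inpt is None:
            break
        inpt = best_inpt
    return inpt, best_loss
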
Example 6
class Dialog(object):
    def __init__(self, agents, args, markable_detector,
                 markable_detector_corpus):
        # For now we only support a dialogue between 2 agents
        assert len(agents) == 2
        self.agents = agents
        self.args = args
        self.domain = domain.get_domain(args.domain)
        self.metrics = MetricsContainer()
        self._register_metrics()
        self.markable_detector = markable_detector
        self.markable_detector_corpus = markable_detector_corpus
        self.selfplay_markables = {}
        self.selfplay_referents = {}

    def _register_metrics(self):
        self.metrics.register_average('dialog_len')
        self.metrics.register_average('sent_len')
        self.metrics.register_percentage('agree')
        self.metrics.register_moving_percentage('moving_agree')
        self.metrics.register_average('advantage')
        self.metrics.register_moving_average('moving_advantage')
        self.metrics.register_time('time')
        self.metrics.register_average('comb_rew')
        self.metrics.register_average('agree_comb_rew')
        for agent in self.agents:
            self.metrics.register_average('%s_rew' % agent.name)
            self.metrics.register_moving_average('%s_moving_rew' % agent.name)
            self.metrics.register_average('agree_%s_rew' % agent.name)
            self.metrics.register_percentage('%s_make_sel' % agent.name)
            self.metrics.register_uniqueness('%s_unique' % agent.name)
            if "plot_metrics" in self.args and self.args.plot_metrics:
                self.metrics.register_select_frequency('%s_sel_bias' %
                                                       agent.name)
        # text metrics
        if self.args.ref_text:
            ref_text = ' '.join(data.read_lines(self.args.ref_text))
            self.metrics.register_ngram('full_match', text=ref_text)

    def _is_selection(self, out):
        return '<selection>' in out

    def show_metrics(self):
        return ' '.join(
            ['%s=%s' % (k, v) for k, v in self.metrics.dict().items()])

    def plot_metrics(self):
        self.metrics.plot()

    def run(self, ctxs, logger, max_words=5000):
        scenario_id = ctxs[0][0]

        for agent, agent_id, ctx, real_ids in zip(self.agents, [0, 1], ctxs[1],
                                                  ctxs[2]):
            agent.feed_context(ctx)
            agent.real_ids = real_ids
            agent.agent_id = agent_id

        # Randomly choose who goes first
        if np.random.rand() < 0.5:
            writer, reader = self.agents
        else:
            reader, writer = self.agents

        conv = []
        speaker = []
        self.metrics.reset()

        words_left = max_words
        length = 0
        expired = False

        while True:
            out = writer.write(max_words=words_left)
            words_left -= len(out)
            length += len(out)

            self.metrics.record('sent_len', len(out))
            if 'full_match' in self.metrics.metrics:
                self.metrics.record('full_match', out)
            self.metrics.record('%s_unique' % writer.name, out)

            conv.append(out)
            speaker.append(writer.agent_id)
            reader.read(out)
            if not writer.human:
                logger.dump_sent(writer.name, out)

            if logger.scenarios and self.args.log_attention:
                attention = writer.get_attention()
                if attention is not None:
                    logger.dump_attention(writer.name, writer.agent_id,
                                          scenario_id, attention)

            if self._is_selection(out):
                self.metrics.record('%s_make_sel' % writer.name, 1)
                self.metrics.record('%s_make_sel' % reader.name, 0)
                break

            if words_left <= 1:
                break

            writer, reader = reader, writer

        choices = []
        for agent in self.agents:
            choice = agent.choose()
            choices.append(choice)
        if logger.scenarios:
            logger.dump_choice(scenario_id, choices)
            if "plot_metrics" in self.args and self.args.plot_metrics:
                for agent in [0, 1]:
                    for obj in logger.scenarios[scenario_id]['kbs'][agent]:
                        if obj['id'] == choices[agent]:
                            self.metrics.record(
                                '%s_sel_bias' % writer.name, obj,
                                logger.scenarios[scenario_id]['kbs'][agent])

        agree, rewards = self.domain.score_choices(choices, ctxs)
        if expired:
            agree = False
        logger.dump('-' * 80)
        logger.dump_agreement(agree)
        for i, (agent, reward) in enumerate(zip(self.agents, rewards)):
            j = 1 if i == 0 else 0
            agent.update(agree, reward, choice=choices[i])

        if agree:
            self.metrics.record('advantage', rewards[0] - rewards[1])
            self.metrics.record('moving_advantage', rewards[0] - rewards[1])
            self.metrics.record('agree_comb_rew', np.sum(rewards))
            for agent, reward in zip(self.agents, rewards):
                self.metrics.record('agree_%s_rew' % agent.name, reward)

        self.metrics.record('time')
        self.metrics.record('dialog_len', len(conv))
        self.metrics.record('agree', int(agree))
        self.metrics.record('moving_agree', int(agree))
        self.metrics.record('comb_rew', np.sum(rewards) if agree else 0)
        for agent, reward in zip(self.agents, rewards):
            self.metrics.record('%s_rew' % agent.name, reward if agree else 0)
            self.metrics.record('%s_moving_rew' % agent.name,
                                reward if agree else 0)

        if self.markable_detector is not None and self.markable_detector_corpus is not None:
            markable_list = []
            referents_dict = {}

            markable_starts = []
            for agent in [0, 1]:
                dialog_tokens = []
                dialog_text = ""
                markables = []
                for spkr, uttr in zip(speaker, conv):
                    if spkr == agent:
                        dialog_tokens.append("YOU:")
                    else:
                        dialog_tokens.append("THEM:")
                    dialog_tokens += uttr
                    dialog_text += str(spkr) + ": " + " ".join(
                        uttr[:-1]) + "\n"

                    words = self.markable_detector_corpus.word_dict.w2i(
                        dialog_tokens)
                    words = torch.Tensor(words).long().cuda()
                    score, tag_seq = self.markable_detector(words)
                    referent_inpt = []
                    markable_ids = []
                    my_utterance = None
                    current_text = ""
                    for i, word in enumerate(words):
                        if word.item(
                        ) == self.markable_detector_corpus.word_dict.word2idx[
                                "YOU:"]:
                            my_utterance = True
                            current_speaker = agent
                        elif word.item(
                        ) == self.markable_detector_corpus.word_dict.word2idx[
                                "THEM:"]:
                            my_utterance = False
                            current_speaker = 1 - agent
                        if my_utterance:
                            if tag_seq[i].item(
                            ) == self.markable_detector_corpus.bio_dict["B"]:
                                start_idx = i
                                for j in range(i + 1, len(tag_seq)):
                                    if tag_seq[j].item(
                                    ) != self.markable_detector_corpus.bio_dict[
                                            "I"]:
                                        end_idx = j - 1
                                        break
                                for j in range(i + 1, len(tag_seq)):
                                    if tag_seq[j].item(
                                    ) in self.markable_detector_corpus.word_dict.w2i(
                                        ["<eos>", "<selection>"]):
                                        end_uttr = j
                                        break

                                markable_start = len(current_text + " ")
                                if markable_start not in markable_starts:
                                    referent_inpt.append(
                                        [start_idx, end_idx, end_uttr])
                                    markable_ids.append(len(markable_starts))

                                    # add markable
                                    markable = {}
                                    markable["start"] = markable_start
                                    markable["end"] = len(
                                        current_text + " " + " ".join(
                                            dialog_tokens[start_idx:end_idx +
                                                          1]))
                                    #markable["start"] = len(str(spkr) + ": " + " ".join(dialog_tokens[1:start_idx]) + " ")
                                    #markable["end"] = len(str(spkr) + ": " + " ".join(dialog_tokens[1:end_idx + 1]))
                                    markable["markable_id"] = len(
                                        markable_starts)
                                    markable["speaker"] = current_speaker
                                    markable["text"] = " ".join(
                                        dialog_tokens[start_idx:end_idx + 1])
                                    markable_starts.append(markable["start"])
                                    markable_list.append(markable)

                        if word.item(
                        ) == self.markable_detector_corpus.word_dict.word2idx[
                                "YOU:"]:
                            current_text += "{}:".format(current_speaker)
                        elif word.item(
                        ) == self.markable_detector_corpus.word_dict.word2idx[
                                "THEM:"]:
                            current_text += "{}:".format(current_speaker)
                        elif word.item(
                        ) in self.markable_detector_corpus.word_dict.w2i(
                            ["<eos>", "<selection>"]):
                            current_text += "\n"
                        else:
                            current_text += " " + self.markable_detector_corpus.word_dict.idx2word[
                                word.item()]

                    assert len(current_text) == len(dialog_text)

                    ref_out = self.agents[agent].predict_referents(
                        referent_inpt)

                    if ref_out is not None:
                        for i, markable_id in enumerate(markable_ids):
                            ent_ids = [
                                ent["id"] for ent in
                                logger.scenarios[scenario_id]['kbs'][agent]
                            ]
                            referents = []
                            for j, is_referent in enumerate(
                                (ref_out[i] > 0).tolist()):
                                if is_referent:
                                    referents.append("agent_" + str(agent) +
                                                     "_" + ent_ids[j])

                            referents_dict[markable_id] = referents

            #markable_starts = list(set(markable_starts))
            # reindex markable ids
            markable_id_and_start = [
                (markable_id, markable_start)
                for markable_id, markable_start in zip(
                    range(len(markable_starts)), markable_starts)
            ]
            reindexed_markable_ids = [
                markable_id for markable_id, _ in sorted(markable_id_and_start,
                                                         key=lambda x: x[1])
            ]

            self.selfplay_markables[scenario_id] = {}
            self.selfplay_referents[scenario_id] = {}

            # add markables
            self.selfplay_markables[scenario_id]["markables"] = []
            for new_markable_id, old_markable_id in enumerate(
                    reindexed_markable_ids):
                markable = markable_list[old_markable_id]
                markable["markable_id"] = "M{}".format(new_markable_id + 1)
                self.selfplay_markables[scenario_id]["markables"].append(
                    markable)

            # add dialogue text
            self.selfplay_markables[scenario_id]["text"] = dialog_text

            # add final selections
            self.selfplay_markables[scenario_id]["selections"] = choices

            # add referents
            for new_markable_id, old_markable_id in enumerate(
                    reindexed_markable_ids):
                referents = referents_dict[old_markable_id]
                self.selfplay_referents[scenario_id]["M{}".format(
                    new_markable_id + 1)] = referents

        logger.dump('-' * 80)
        logger.dump(self.show_metrics())
        logger.dump('-' * 80)
        #for ctx, choice in zip(ctxs, choices):
        #    logger.dump('debug: %s %s' % (' '.join(ctx), ' '.join(choice)))

        return conv, agree, rewards
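
The markable extraction above walks the BIO tag sequence produced by the detector and
turns every "B" tag plus its trailing "I" tags into a span before predicting referents.
Separated from the text bookkeeping, that step is standard BIO decoding, sketched below
(bio_dict is assumed to map the tags "B", "I", "O" to integer ids).

def extract_bio_spans_sketch(tag_seq, bio_dict):
    # collect (start, end) index pairs for every B followed by consecutive I tags
    spans = []
    i = 0
    while i < len(tag_seq):
        if tag_seq[i] == bio_dict["B"]:
            end = i
            while end + 1 < len(tag_seq) and tag_seq[end + 1] == bio_dict["I"]:
                end += 1
            spans.append((i, end))
            i = end + 1
        else:
            i += 1
    return spans
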
Example 7
class Dialog(object):
    """Dialogue runner."""
    def __init__(self, agents, args):
        # for now we only support a dialogue between 2 agents
        assert len(agents) == 2
        self.agents = agents
        self.args = args
        self.domain = domain.get_domain(args.domain)
        self.metrics = MetricsContainer()
        self._register_metrics()

    def _register_metrics(self):
        """Registers valuable metrics."""
        self.metrics.register_average('dialog_len')
        self.metrics.register_average('sent_len')
        self.metrics.register_percentage('agree')
        self.metrics.register_average('advantage')
        self.metrics.register_time('time')
        self.metrics.register_average('comb_rew')
        for agent in self.agents:
            self.metrics.register_average('%s_rew' % agent.name)
            self.metrics.register_percentage('%s_sel' % agent.name)
            self.metrics.register_uniqueness('%s_unique' % agent.name)
        # text metrics
        ref_text = ' '.join(data.read_lines(self.args.ref_text))
        self.metrics.register_ngram('full_match', text=ref_text)

    def _is_selection(self, out):
        return len(out) == 1 and out[0] == '<selection>'

    def show_metrics(self):
        return ' '.join(
            ['%s=%s' % (k, v) for k, v in self.metrics.dict().items()])

    def run(self, ctxs, logger):
        """Runs one instance of the dialogue."""
        assert len(self.agents) == len(ctxs)
        # initialize agents by feeding in the contexts
        #for agent, ctx in zip(self.agents, ctxs):
        #    agent.feed_context(ctx)
        #   logger.dump_ctx(agent.name, ctx)
        self.agents[0].feed_context(ctxs[0])
        logger.dump_ctx(self.agents[0].name, ctxs[0])
        self.agents[1].feed_context(ctxs[1], ctxs[0])
        logger.dump_ctx(self.agents[1].name, ctxs[1])

        logger.dump('-' * 80)

        # randomly choose who goes first
        if np.random.rand() < 0.5:
            writer, reader = self.agents
        else:
            reader, writer = self.agents

        #writer, reader = self.agents

        conv = []
        # reset metrics
        self.metrics.reset()

        #### Minhao ####
        count_turns = 0

        while True:
            # produce an utterance
            if count_turns > self.args.max_turns:
                out = writer.write_selection()
            else:
                out = writer.write()

            self.metrics.record('sent_len', len(out))
            self.metrics.record('full_match', out)
            self.metrics.record('%s_unique' % writer.name, out)

            # append the utterance to the conversation
            conv.append(out)
            # have the other agent read it
            reader.read(out)
            if not writer.human:
                logger.dump_sent(writer.name, out)
            # check if the end of the conversation was generated
            if self._is_selection(out):
                self.metrics.record('%s_sel' % writer.name, 1)
                self.metrics.record('%s_sel' % reader.name, 0)
                break
            writer, reader = reader, writer
            count_turns += 1
            ##### a selection mark is forced at the top of the loop once max_turns is exceeded

        ### Minhao: need to design a loss focusing on the choices
        ### No evaluation during the conversation?

        choices = []
        #class_losses = []
        # generate choices for each of the agents
        #flag_loss=False

        # Bob (agents[1]) picks his selection and also returns a classification loss
        bob_choice, classify_loss = self.agents[1].choose(flag=True)
        choices.append(bob_choice)
        # build Alice's choice as the complement of Bob's by swapping the two halves
        alice_choice = bob_choice[:]
        for indx in range(3):
            alice_choice[indx + 3], alice_choice[indx] = \
                alice_choice[indx], alice_choice[indx + 3]
        choices.append(alice_choice)  ######## always agree
        # reorder so that choices[i] lines up with self.agents[i]
        choices[1], choices[0] = choices[0], choices[1]
        #print(choices)
        #for agent in self.agents:
        #    choice, class_loss = agent.choose(flag=flag_loss)
        #    class_losses.append(class_loss)
        #    choices.append(choice)
        #    logger.dump_choice(agent.name, choice[: self.domain.selection_length() // 2])
        #    flag_loss=True

        # evaluate the choices, produce agreement and a reward
        #print(choices,ctxs)
        agree, rewards = self.domain.score_choices(choices, ctxs)
        logger.dump('-' * 80)
        logger.dump_agreement(agree)
        #print(rewards)
        # perform an update in case any of the agents is learnable
        # let the difference become the new reward
        ## how to combine the loss with the reward
        '''
        diff = rewards[0]-rewards[1] 
        flag = True
        agree = 1
        #print(5 - classify_loss.item())
        for agent, reward in zip(self.agents, rewards):           
            if flag:
                logger.dump_reward(agent.name, agree, reward)
                #agent.update(agree, 50-class_losses[1].item())
                agent.update(agree, 5-classify_loss.item())
                #agent.update(agree, diff - 0.05 * class_losses[1].data[0])
                #agent.update(agree, diff)
            else:
                logger.dump_reward(agent.name, agree, reward)
                if not self.args.fixed_bob:
                    agent.update(agree, reward)
            flag=False
        '''
        agree = 1  # force agreement regardless of what score_choices returned
        for agent, reward in zip(self.agents, rewards):
            logger.dump_reward(agent.name, agree, reward)
            logging.debug("%s : %s : %s" %
                          (str(agent.name), str(agree), str(rewards)))
            agent.update(agree, 5 - classify_loss.item())

        if agree:
            self.metrics.record('advantage', rewards[0] - rewards[1])
        self.metrics.record('time')
        self.metrics.record('dialog_len', len(conv))
        self.metrics.record('agree', int(agree))
        self.metrics.record('comb_rew', np.sum(rewards) if agree else 0)
        for agent, reward in zip(self.agents, rewards):
            self.metrics.record('%s_rew' % agent.name, reward if agree else 0)

        logger.dump('-' * 80)
        logger.dump(self.show_metrics())
        logger.dump('-' * 80)
        for ctx, choice in zip(ctxs, choices):
            logger.dump('debug: %s %s' % (' '.join(ctx), ' '.join(choice)))

        return conv, agree, rewards
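A detail worth calling out in the selection step above: Alice's choice is built as the mirror image of Bob's by swapping the first three and last three slots, which (together with the hard-coded agree = 1) makes the pair always count as agreeing. Below is a minimal, self-contained sketch of that swap on a made-up six-slot choice; the item strings are placeholders, not the repository's actual token format.

# Hypothetical six-slot choice: the first half is what Bob keeps,
# the second half is what goes to the partner.
bob_choice = ["item0=2", "item1=0", "item2=1", "item0=1", "item1=3", "item2=0"]

alice_choice = bob_choice[:]
for indx in range(3):
    # swap each slot with its counterpart in the other half
    alice_choice[indx + 3], alice_choice[indx] = \
        alice_choice[indx], alice_choice[indx + 3]

print(alice_choice)
# -> ['item0=1', 'item1=3', 'item2=0', 'item0=2', 'item1=0', 'item2=1']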
Example #8
class Dialog(object):
    """Dialogue runner."""
    def __init__(self, agents, args):
        # for now we only support a dialogue of 2 agents
        assert len(agents) == 2
        self.agents = agents
        self.args = args
        self.domain = domain.get_domain(args.domain)
        self.metrics = MetricsContainer()
        self._register_metrics()

    def _register_metrics(self):
        """Registers valuable metrics."""
        self.metrics.register_average('dialog_len')
        self.metrics.register_average('sent_len')
        self.metrics.register_percentage('agree')
        self.metrics.register_average('advantage')
        self.metrics.register_time('time')
        self.metrics.register_average('comb_rew')
        for agent in self.agents:
            self.metrics.register_average('%s_rew' % agent.name)
            self.metrics.register_percentage('%s_sel' % agent.name)
            self.metrics.register_uniqueness('%s_unique' % agent.name)
        # text metrics
        ref_text = ' '.join(data.read_lines(self.args.ref_text))
        self.metrics.register_ngram('full_match', text=ref_text)

    def _is_selection(self, out):
        return len(out) == 1 and out[0] == '<selection>'

    def show_metrics(self):
        return ' '.join(['%s=%s' % (k, v) for k, v in self.metrics.dict().items()])

    def run(self, ctxs, logger):
        """Runs one instance of the dialogue."""
        assert len(self.agents) == len(ctxs)
        # initialize agents by feeding in the contexts
        for agent, ctx in zip(self.agents, ctxs):
            agent.feed_context(ctx)
            logger.dump_ctx(agent.name, ctx)
        logger.dump('-' * 80)

        # randomly choose who goes first
        if np.random.rand() < 0.5:
            writer, reader = self.agents
        else:
            reader, writer = self.agents

        conv = []
        # reset metrics
        self.metrics.reset()

        while True:
            # produce an utterance
            out = writer.write()

            self.metrics.record('sent_len', len(out))
            self.metrics.record('full_match', out)
            self.metrics.record('%s_unique' % writer.name, out)

            # append the utterance to the conversation
            conv.append(out)
            # have the other agent read it
            reader.read(out)
            if not writer.human:
                logger.dump_sent(writer.name, out)
            # check if the end of the conversation was generated
            if self._is_selection(out):
                self.metrics.record('%s_sel' % writer.name, 1)
                self.metrics.record('%s_sel' % reader.name, 0)
                break
            writer, reader = reader, writer


        choices = []
        # generate choices for each of the agents
        for agent in self.agents:
            choice = agent.choose()
            choices.append(choice)
            logger.dump_choice(agent.name, choice[: self.domain.selection_length() // 2])

        # evaluate the choices, produce agreement and a reward
        agree, rewards = self.domain.score_choices(choices, ctxs)
        logger.dump('-' * 80)
        logger.dump_agreement(agree)
        # perform an update in case any of the agents is learnable
        for agent, reward in zip(self.agents, rewards):
            logger.dump_reward(agent.name, agree, reward)
            agent.update(agree, reward)

        if agree:
            self.metrics.record('advantage', rewards[0] - rewards[1])
        self.metrics.record('time')
        self.metrics.record('dialog_len', len(conv))
        self.metrics.record('agree', int(agree))
        self.metrics.record('comb_rew', np.sum(rewards) if agree else 0)
        for agent, reward in zip(self.agents, rewards):
            self.metrics.record('%s_rew' % agent.name, reward if agree else 0)

        logger.dump('-' * 80)
        logger.dump(self.show_metrics())
        logger.dump('-' * 80)
        for ctx, choice in zip(ctxs, choices):
            logger.dump('debug: %s %s' % (' '.join(ctx), ' '.join(choice)))

        return conv, agree, rewards
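To make the control flow of run() concrete, here is a minimal, self-contained sketch of its turn-taking loop with two stub agents. StubAgent, is_selection, and the canned utterances are invented for illustration and are not the Agent API used by the classes above.

import random

class StubAgent:
    """Toy stand-in for an agent: emits canned utterances, ignores input."""
    def __init__(self, name, utterances):
        self.name = name
        self.human = False
        self._utterances = iter(utterances)

    def write(self):
        return next(self._utterances)

    def read(self, out):
        pass  # a real agent would update its state on the partner's words

def is_selection(out):
    return len(out) == 1 and out[0] == '<selection>'

alice = StubAgent('Alice', [['i', 'take', 'the', 'hats', '<eos>'], ['<selection>']])
bob = StubAgent('Bob', [['deal', '<eos>'], ['<selection>']])

# randomly choose who goes first, then alternate until a selection token appears
writer, reader = (alice, bob) if random.random() < 0.5 else (bob, alice)
conv = []
while True:
    out = writer.write()
    conv.append(out)
    reader.read(out)
    print('%s: %s' % (writer.name, ' '.join(out)))
    if is_selection(out):
        break
    writer, reader = reader, writer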