Exemplo n.º 1
0
def main():
    """Report the average rank of ground-truth utterances over a dataset.

    Command-line options are parsed with argparse, ``utils.set_seed`` is
    called with ``--seed`` and the model from ``--model_file`` is wrapped in
    an ``LstmAgent``.  For every ``(ctx, dialog)`` pair returned by
    ``read_dataset`` the agent is fed the context; each turn flagged ``you``
    is ranked with ``compute_rank`` and the agent state is then advanced with
    the ground-truth sentence, while every other turn is handed to
    ``ai.read``.  A running average rank is logged after each dialogue and a
    final average at the end.

    While no turn has been ranked the average is logged as ``nan`` instead of
    raising ``ZeroDivisionError``.
    """
    parser = argparse.ArgumentParser(description='Negotiator')
    parser.add_argument('--dataset', type=str, default='./data/negotiate/val.txt',
        help='location of the dataset')
    parser.add_argument('--model_file', type=str,
        help='model file')
    parser.add_argument('--smart_ai', action='store_true', default=False,
        help='to use rollouts')
    parser.add_argument('--seed', type=int, default=1,
        help='random seed')
    parser.add_argument('--temperature', type=float, default=1.0,
        help='temperature')
    parser.add_argument('--domain', type=str, default='object_division',
        help='domain for the dialogue')
    parser.add_argument('--log_file', type=str, default='',
        help='log file')
    args = parser.parse_args()

    utils.set_seed(args.seed)

    model = utils.load_model(args.model_file)
    ai = LstmAgent(model, args)
    logger = DialogLogger(verbose=True, log_file=args.log_file)
    domain = get_domain(args.domain)

    # --smart_ai switches the scoring function from likelihood to rollout
    score_func = rollout if args.smart_ai else likelihood

    dataset, sents = read_dataset(args.dataset)
    ranks, n = 0, 0

    def avg_rank():
        # mean rank over the turns ranked so far; nan while there are none,
        # so a dialogue without any "you" turn cannot divide by zero
        return 1. * ranks / n if n else float('nan')

    for k, (ctx, dialog) in enumerate(dataset, 1):
        start_time = time.time()
        # start new conversation
        ai.feed_context(ctx)
        for sent, you in dialog:
            if you:
                # if it is your turn to say, take the target word and compute its rank
                rank = compute_rank(sent, sents, ai, domain, args.temperature, score_func)
                # compute lang_h for the groundtruth sentence
                enc = ai._encode(sent, ai.model.word_dict)
                _, ai.lang_h, lang_hs = ai.model.score_sent(enc, ai.lang_h, ai.ctx_h, args.temperature)
                # save hidden states and the utterance
                ai.lang_hs.append(lang_hs)
                ai.words.append(ai.model.word2var('YOU:'))
                ai.words.append(Variable(enc))
                ranks += rank
                n += 1
            else:
                ai.read(sent)
        time_elapsed = time.time() - start_time
        logger.dump('dialogue %d | avg rank %.3f | raw %d/%d | time %.3f' % (k, avg_rank(), ranks, n, time_elapsed))

    logger.dump('final avg rank %.3f' % avg_rank())
Exemplo n.º 2
0
 def __init__(self, dialog, ctx_gen, args, engine, corpus, logger=None):
     """Keep references to the collaborators passed in by the caller.

     When ``logger`` is falsy (including the default ``None``) a new
     ``DialogLogger()`` is created and used instead.
     """
     self.dialog, self.ctx_gen = dialog, ctx_gen
     self.args = args
     self.engine, self.corpus = engine, corpus
     # fall back to a default-constructed logger when none was supplied
     self.logger = logger or DialogLogger()
Exemplo n.º 3
0
 def __init__(self, dialog, ctx_gen, args, engine, corpus, logger=None):
     """Set up an RL learning instance.
     @param dialog: a dialog generator with policy update functionality
     @param ctx_gen: context (i.e. token amounts and values) generator
     @param args: stored unchanged on the instance as ``self.args``
     @param engine: supervised language model training engine
     @param corpus: training corpus
     @param logger: optional logger; a new ``DialogLogger()`` is used when
         this is falsy
     """
     self.dialog, self.ctx_gen = dialog, ctx_gen
     self.args = args
     self.engine, self.corpus = engine, corpus
     # fall back to a default-constructed logger when none was supplied
     self.logger = logger or DialogLogger()
Exemplo n.º 4
0
def main():
    """Report the average rank of ground-truth utterances over a dataset.

    Parses the command line, calls ``utils.set_seed`` with ``--seed``, loads
    the model from ``--model_file`` into an ``LstmAgent`` and walks over the
    ``(ctx, dialog)`` pairs returned by ``read_dataset``.  Turns flagged
    ``you`` are ranked with ``compute_rank`` and then replayed into the agent
    state using the ground-truth sentence; all other turns go through
    ``ai.read``.  The running average rank is logged per dialogue, followed
    by the final average.

    The average is logged as ``nan`` while no turn has been ranked, instead
    of raising ``ZeroDivisionError``.
    """
    parser = argparse.ArgumentParser(description='Negotiator')
    parser.add_argument('--dataset',
                        type=str,
                        default='./data/negotiate/val.txt',
                        help='location of the dataset')
    parser.add_argument('--model_file', type=str, help='model file')
    parser.add_argument('--smart_ai',
                        action='store_true',
                        default=False,
                        help='to use rollouts')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--temperature',
                        type=float,
                        default=1.0,
                        help='temperature')
    parser.add_argument('--domain',
                        type=str,
                        default='object_division',
                        help='domain for the dialogue')
    parser.add_argument('--log_file', type=str, default='', help='log file')
    args = parser.parse_args()

    utils.set_seed(args.seed)

    model = utils.load_model(args.model_file)
    ai = LstmAgent(model, args)
    logger = DialogLogger(verbose=True, log_file=args.log_file)
    domain = get_domain(args.domain)

    # --smart_ai switches the scoring function from likelihood to rollout
    score_func = rollout if args.smart_ai else likelihood

    dataset, sents = read_dataset(args.dataset)
    ranks, n = 0, 0
    for k, (ctx, dialog) in enumerate(dataset, 1):
        start_time = time.time()
        ai.feed_context(ctx)
        for sent, you in dialog:
            if you:
                rank = compute_rank(sent, sents, ai, domain, args.temperature,
                                    score_func)
                # Compute lang_h for the groundtruth sentence
                enc = ai._encode(sent, ai.model.word_dict)
                _, ai.lang_h, lang_hs = ai.model.score_sent(
                    enc, ai.lang_h, ai.ctx_h, args.temperature)
                # Record the hidden states and the utterance on the agent
                ai.lang_hs.append(lang_hs)
                ai.words.append(ai.model.word2var('YOU:'))
                ai.words.append(Variable(enc))
                ranks += rank
                n += 1
            else:
                ai.read(sent)
        time_elapsed = time.time() - start_time
        # A dialogue without any "you" turn leaves n at 0; report nan
        # rather than dividing by zero.
        avg_rank = 1. * ranks / n if n else float('nan')
        logger.dump('dialogue %d | avg rank %.3f | raw %d/%d | time %.3f' %
                    (k, avg_rank, ranks, n, time_elapsed))

    final_avg = 1. * ranks / n if n else float('nan')
    logger.dump('final avg rank %.3f' % final_avg)
Exemplo n.º 5
0
def main():
    """Run a chat session between a ``HumanAgent`` and a model-backed agent.

    Options are read from the command line.  ``--smart_ai`` selects
    ``LstmRolloutAgent`` instead of ``LstmAgent``; ``--ai_starts`` puts the
    AI first in the agent list.  Contexts come from
    ``ManualContextGenerator`` when ``--context_file`` is empty and from
    ``ContextGenerator`` otherwise.
    """
    cli = argparse.ArgumentParser(description='chat utility')

    def option(flag, kind, help_text, **extra):
        # value-taking option; ``extra`` carries the default when there is one
        cli.add_argument(flag, type=kind, help=help_text, **extra)

    def switch(flag, help_text):
        # boolean flag that is off unless given
        cli.add_argument(flag, action='store_true', default=False, help=help_text)

    option('--model_file', str, 'model file')
    option('--domain', str, 'domain for the dialogue', default='object_division')
    option('--context_file', str, 'context file', default='')
    option('--temperature', float, 'temperature', default=1.0)
    option('--num_types', int, 'number of object types', default=3)
    option('--num_objects', int, 'total number of objects', default=6)
    option('--max_score', int, 'max score per object', default=10)
    option('--score_threshold', int,
           'successful dialog should have more than score_threshold in score',
           default=6)
    option('--seed', int, 'random seed', default=1)
    switch('--smart_ai', 'make AI smart again')
    switch('--ai_starts', 'allow AI to start the dialog')
    option('--ref_text', str, 'file with the reference text')
    args = cli.parse_args()

    utils.set_seed(args.seed)

    human = HumanAgent(domain.get_domain(args.domain))

    if args.smart_ai:
        agent_cls = LstmRolloutAgent
    else:
        agent_cls = LstmAgent
    ai = agent_cls(utils.load_model(args.model_file), args)

    # speaking order: the first agent in the list opens the dialog
    if args.ai_starts:
        agents = [ai, human]
    else:
        agents = [human, ai]

    dialog = Dialog(agents, args)
    logger = DialogLogger(verbose=True)
    if args.context_file != '':
        ctx_gen = ContextGenerator(args.context_file)
    else:
        ctx_gen = ManualContextGenerator(args.num_types, args.num_objects,
                                         args.max_score)

    Chat(dialog, ctx_gen, logger).run()
Exemplo n.º 6
0
def main():
    """Run reinforcement training of Alice against Bob and save Alice's model.

    Options are read from the command line.  Both models are loaded with
    ``utils.load_model`` and switched to ``eval()``; Alice is wrapped in an
    ``RlAgent``, Bob in ``LstmRolloutAgent`` when ``--smart_bob`` is given and
    in ``LstmAgent`` otherwise.  After ``Reinforce.run()`` returns, Alice's
    model is written to ``--output_model_file``.
    """
    cli = argparse.ArgumentParser(description='Reinforce')

    def option(flag, kind, help_text, **extra):
        # value-taking option; ``extra`` carries the default when there is one
        cli.add_argument(flag, type=kind, help=help_text, **extra)

    def switch(flag, help_text):
        # boolean flag that is off unless given
        cli.add_argument(flag, action='store_true', default=False, help=help_text)

    option('--data', str, 'location of the data corpus', default='./data/negotiate')
    option('--unk_threshold', int, 'minimum word frequency to be in dictionary',
           default=20)
    option('--alice_model_file', str, 'Alice model file')
    option('--bob_model_file', str, 'Bob model file')
    option('--output_model_file', str, 'output model file')
    option('--context_file', str, 'context file')
    option('--temperature', float, 'temperature', default=1.0)
    switch('--cuda', 'use CUDA')
    switch('--verbose', 'print out converations')
    option('--seed', int, 'random seed', default=1)
    option('--score_threshold', int,
           'successful dialog should have more than score_threshold in score',
           default=6)
    option('--log_file', str, 'log successful dialogs to file for training',
           default='')
    switch('--smart_bob', 'make Bob smart again')
    option('--gamma', float, 'discount factor', default=0.99)
    option('--eps', float, 'eps greedy', default=0.5)
    switch('--nesterov', 'enable nesterov momentum')
    option('--momentum', float, 'momentum for sgd', default=0.0)
    option('--lr', float, 'learning rate', default=0.1)
    option('--clip', float, 'gradient clip', default=0.1)
    option('--rl_lr', float, 'RL learning rate', default=0.1)
    option('--rl_clip', float, 'RL gradient clip', default=0.1)
    option('--ref_text', str, 'file with the reference text')
    option('--bsz', int, 'batch size', default=8)
    option('--sv_train_freq', int, 'supervision train frequency', default=-1)
    option('--nepoch', int, 'number of epochs', default=1)
    switch('--visual', 'plot graphs')
    option('--domain', str, 'domain for the dialogue', default='object_division')
    args = cli.parse_args()

    device_id = utils.use_cuda(args.cuda)
    utils.set_seed(args.seed)

    alice_model = utils.load_model(args.alice_model_file)
    # eval mode: Dropout must not be active during RL
    alice_model.eval()
    alice = RlAgent(alice_model, args, name='Alice')

    if args.smart_bob:
        bob_cls = LstmRolloutAgent
    else:
        bob_cls = LstmAgent
    bob_model = utils.load_model(args.bob_model_file)
    bob_model.eval()
    bob = bob_cls(bob_model, args, name='Bob')

    dialog = Dialog([alice, bob], args)
    logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    corpus = data.WordCorpus(args.data, freq_cutoff=args.unk_threshold)
    engine = Engine(alice_model, args, device_id, verbose=False)

    Reinforce(dialog, ctx_gen, args, engine, corpus, logger).run()

    utils.save_model(alice.model, args.output_model_file)
0
 def __init__(self, dialog, ctx_gen, logger=None):
     """Keep references to the dialog, the context generator and a logger.

     When ``logger`` is falsy (including the default ``None``) a new
     ``DialogLogger()`` is created and used instead.
     """
     self.dialog, self.ctx_gen = dialog, ctx_gen
     # fall back to a default-constructed logger when none was supplied
     self.logger = logger or DialogLogger()
Exemplo n.º 8
0
def main():
    """Run reinforcement training of Alice against Bob and save Alice's model.

    Options are read from the command line.  The agent class for each model
    is obtained from ``get_agent_type``; Alice is built with ``train=True``
    and Bob with ``train=False``.  The corpus and engine are created through
    ``alice_model.corpus_ty`` and ``alice_model.engine_ty``.  After
    ``Reinforce.run()`` returns, Alice's model is written to
    ``--output_model_file``.
    """
    cli = argparse.ArgumentParser(description='Reinforce')

    def option(flag, kind, help_text, **extra):
        # value-taking option; ``extra`` carries the default when there is one
        cli.add_argument(flag, type=kind, help=help_text, **extra)

    def switch(flag, help_text):
        # boolean flag that is off unless given
        cli.add_argument(flag, action='store_true', default=False, help=help_text)

    option('--alice_model_file', str, 'Alice model file')
    option('--bob_model_file', str, 'Bob model file')
    option('--output_model_file', str, 'output model file')
    option('--context_file', str, 'context file')
    option('--temperature', float, 'temperature', default=1.0)
    option('--pred_temperature', float, 'temperature', default=1.0)
    switch('--cuda', 'use CUDA')
    switch('--verbose', 'print out converations')
    option('--seed', int, 'random seed', default=1)
    option('--score_threshold', int,
           'successful dialog should have more than score_threshold in score',
           default=6)
    option('--log_file', str, 'log successful dialogs to file for training',
           default='')
    switch('--smart_bob', 'make Bob smart again')
    option('--gamma', float, 'discount factor', default=0.99)
    option('--eps', float, 'eps greedy', default=0.5)
    option('--momentum', float, 'momentum for sgd', default=0.1)
    option('--lr', float, 'learning rate', default=0.1)
    option('--clip', float, 'gradient clip', default=0.1)
    option('--rl_lr', float, 'RL learning rate', default=0.002)
    option('--rl_clip', float, 'RL gradient clip', default=2.0)
    option('--ref_text', str, 'file with the reference text')
    option('--sv_train_freq', int, 'supervision train frequency', default=-1)
    option('--nepoch', int, 'number of epochs', default=1)
    switch('--hierarchical', 'use hierarchical training')
    switch('--visual', 'plot graphs')
    option('--domain', str, 'domain for the dialogue', default='object_division')
    option('--selection_model_file', str, 'path to save the final model',
           default='')
    option('--data', str, 'location of the data corpus', default='data/negotiate')
    option('--unk_threshold', int, 'minimum word frequency to be in dictionary',
           default=20)
    option('--bsz', int, 'batch size', default=16)
    switch('--validate', 'plot graphs')
    switch('--scratch', 'erase prediciton weights')
    switch('--sep_sel', 'use separate classifiers for selection')

    args = cli.parse_args()

    utils.use_cuda(args.cuda)
    utils.set_seed(args.seed)

    # Alice: the agent built with train=True
    alice_model = utils.load_model(args.alice_model_file)
    alice_cls = get_agent_type(alice_model)
    alice = alice_cls(alice_model, args, name='Alice', train=True)
    alice.vis = args.visual

    # Bob: built with train=False
    bob_model = utils.load_model(args.bob_model_file)
    bob_cls = get_agent_type(bob_model)
    bob = bob_cls(bob_model, args, name='Bob', train=False)

    dialog = Dialog([alice, bob], args)
    logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    domain = get_domain(args.domain)
    corpus_options = dict(freq_cutoff=args.unk_threshold,
                          verbose=True,
                          sep_sel=args.sep_sel)
    corpus = alice_model.corpus_ty(domain, args.data, **corpus_options)
    engine = alice_model.engine_ty(alice_model, args)

    Reinforce(dialog, ctx_gen, args, engine, corpus, logger).run()

    utils.save_model(alice.model, args.output_model_file)
Exemplo n.º 9
0
def main():
    """Run reinforcement training of Alice against a frozen Bob.

    Command-line options (defaults taken from ``config``) are parsed, CUDA is
    set up through ``utils.use_cuda`` and ``utils.set_seed`` is called with
    ``--seed``.  Both models are loaded and switched to ``eval()``; Alice is
    wrapped in an ``RlAgent``, Bob in ``LstmRolloutAgent`` when
    ``--smart_bob`` is given and in ``LstmAgent`` otherwise.  After
    ``Reinforce.run()`` returns, Alice's model is written to
    ``--output_model_file``.  Progress is reported through ``logging.info``.
    """
    parser = argparse.ArgumentParser(description='Reinforce')
    parser.add_argument('--data',
                        type=str,
                        default=config.data_dir,
                        help='location of the data corpus')
    parser.add_argument('--unk_threshold',
                        type=int,
                        default=config.unk_threshold,
                        help='minimum word frequency to be in dictionary')
    parser.add_argument('--alice_model_file',
                        type=str,
                        help='Alice model file')
    parser.add_argument('--bob_model_file', type=str, help='Bob model file')
    parser.add_argument('--output_model_file',
                        type=str,
                        help='output model file')
    parser.add_argument('--context_file', type=str, help='context file')
    parser.add_argument('--temperature',
                        type=float,
                        default=config.rl_temperature,
                        help='temperature')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=config.cuda,
                        help='use CUDA')
    parser.add_argument('--verbose',
                        action='store_true',
                        default=config.verbose,
                        help='print out converations')
    parser.add_argument('--seed',
                        type=int,
                        default=config.seed,
                        help='random seed')
    parser.add_argument(
        '--score_threshold',
        type=int,
        default=config.rl_score_threshold,
        help='successful dialog should have more than score_threshold in score'
    )
    parser.add_argument('--log_file',
                        type=str,
                        default='',
                        help='log successful dialogs to file for training')
    parser.add_argument('--smart_bob',
                        action='store_true',
                        default=False,
                        help='make Bob smart again')
    parser.add_argument('--gamma',
                        type=float,
                        default=config.rl_gamma,
                        help='discount factor')
    parser.add_argument('--eps',
                        type=float,
                        default=config.rl_eps,
                        help='eps greedy')
    parser.add_argument('--nesterov',
                        action='store_true',
                        default=config.nesterov,
                        help='enable nesterov momentum')
    parser.add_argument('--momentum',
                        type=float,
                        default=config.rl_momentum,
                        help='momentum for sgd')
    parser.add_argument('--lr',
                        type=float,
                        default=config.rl_lr,
                        help='learning rate')
    parser.add_argument('--clip',
                        type=float,
                        default=config.rl_clip,
                        help='gradient clip')
    parser.add_argument('--rl_lr',
                        type=float,
                        default=config.rl_reinforcement_lr,
                        help='RL learning rate')
    parser.add_argument('--rl_clip',
                        type=float,
                        default=config.rl_reinforcement_clip,
                        help='RL gradient clip')
    parser.add_argument('--ref_text',
                        type=str,
                        help='file with the reference text')
    parser.add_argument('--bsz',
                        type=int,
                        default=config.rl_bsz,
                        help='batch size')
    parser.add_argument('--sv_train_freq',
                        type=int,
                        default=config.rl_sv_train_freq,
                        help='supervision train frequency')
    parser.add_argument('--nepoch',
                        type=int,
                        default=config.rl_nepoch,
                        help='number of epochs')
    parser.add_argument('--visual',
                        action='store_true',
                        default=config.plot_graphs,
                        help='plot graphs')
    parser.add_argument('--domain',
                        type=str,
                        default=config.domain,
                        help='domain for the dialogue')
    parser.add_argument('--reward',
                        type=str,
                        choices=['margin', 'fair', 'length'],
                        default='margin',
                        help='reward function')
    args = parser.parse_args()

    device_id = utils.use_cuda(args.cuda)
    # --seed was parsed but never applied; apply it before any model is
    # loaded so that runs are repeatable.
    utils.set_seed(args.seed)
    logging.info("Starting training using pytorch version:%s",
                 torch.__version__)
    if args.cuda:
        # only query the GPU when CUDA was requested
        cuda_status = ("enabled. Using device_id:" + str(device_id) +
                       " version:" + str(torch.version.cuda) +
                       " on gpu:" + torch.cuda.get_device_name(0))
    else:
        cuda_status = "disabled"
    logging.info("CUDA is %s", cuda_status)

    alice_model = utils.load_model(args.alice_model_file)
    # we don't want to use Dropout during RL
    alice_model.eval()
    # Alice is a RL based agent, meaning that she will be learning while selfplaying
    logging.info("Creating RlAgent from alice_model: %s",
                 args.alice_model_file)
    alice = RlAgent(alice_model, args, name='Alice')

    # we keep Bob frozen, i.e. we don't update his parameters
    logging.info("Creating Bob's (--smart_bob) LstmRolloutAgent" if args.smart_bob \
        else "Creating Bob's (not --smart_bob) LstmAgent" )
    bob_ty = LstmRolloutAgent if args.smart_bob else LstmAgent
    bob_model = utils.load_model(args.bob_model_file)
    bob_model.eval()
    bob = bob_ty(bob_model, args, name='Bob')

    logging.info("Initializing communication dialogue between Alice and Bob")
    dialog = Dialog([alice, bob], args)
    logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    logging.info(
        "Building word corpus, requiring minimum word frequency of %d for dictionary",
        args.unk_threshold)
    corpus = data.WordCorpus(args.data, freq_cutoff=args.unk_threshold)
    engine = Engine(alice_model, args, device_id, verbose=False)

    logging.info("Starting Reinforcement Learning")
    reinforce = Reinforce(dialog, ctx_gen, args, engine, corpus, logger)
    reinforce.run()

    logging.info("Saving updated Alice model to %s", args.output_model_file)
    utils.save_model(alice.model, args.output_model_file)
Exemplo n.º 10
0
def main():
    """Run self-play between two loaded models, Alice and Bob.

    Options are read from the command line.  For each side the model is
    loaded with ``utils.load_model`` and its agent class is chosen by
    ``get_agent_type`` from the model, the side's ``--smart_*`` flag and
    ``--fast_rollout``.  The two agents are then handed to ``Dialog`` and
    ``SelfPlay.run()`` is called.
    """
    cli = argparse.ArgumentParser(description='selfplaying script')

    def option(flag, kind, help_text, **extra):
        # value-taking option; ``extra`` carries the default when there is one
        cli.add_argument(flag, type=kind, help=help_text, **extra)

    def switch(flag, help_text):
        # boolean flag that is off unless given
        cli.add_argument(flag, action='store_true', default=False, help=help_text)

    option('--alice_model_file', str, 'Alice model file')
    option('--bob_model_file', str, 'Bob model file')
    option('--context_file', str, 'context file')
    option('--temperature', float, 'temperature', default=1.0)
    switch('--verbose', 'print out converations')
    option('--seed', int, 'random seed', default=1)
    option('--score_threshold', int,
           'successful dialog should have more than score_threshold in score',
           default=6)
    option('--max_turns', int, 'maximum number of turns in a dialog', default=20)
    option('--log_file', str, 'log successful dialogs to file for training',
           default='')
    switch('--smart_alice', 'make Alice smart again')
    switch('--fast_rollout', 'to use faster rollouts')
    option('--rollout_bsz', int, 'rollout batch size', default=100)
    option('--rollout_count_threshold', int, 'rollout count threshold', default=3)
    switch('--smart_bob', 'make Bob smart again')
    option('--ref_text', str, 'file with the reference text')
    option('--domain', str, 'domain for the dialogue', default='object_division')
    args = cli.parse_args()

    utils.set_seed(args.seed)

    def build_agent(model_path, smart, name):
        # load the model, pick its agent class, then instantiate the agent
        model = utils.load_model(model_path)
        agent_cls = get_agent_type(model, smart, args.fast_rollout)
        return agent_cls(model, args, name=name)

    alice = build_agent(args.alice_model_file, args.smart_alice, 'Alice')
    bob = build_agent(args.bob_model_file, args.smart_bob, 'Bob')

    dialog = Dialog([alice, bob], args)
    logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    SelfPlay(dialog, ctx_gen, args, logger).run()
Exemplo n.º 11
0
def main():
    """Run self-play between two loaded models, Alice and Bob.

    Options are read from the command line.  CUDA and the seed are set up
    through ``utils``.  For each side the model is loaded and its agent class
    is chosen by ``get_agent_type`` from the model and the side's
    ``--smart_*`` flag; both agents are built with ``train=False`` and the
    side's ``--diverse_*`` flag.  Alice's ``vis`` attribute follows
    ``--visual`` and Bob's is always ``False``.  ``SelfPlay.run()`` is then
    called.
    """
    cli = argparse.ArgumentParser(description='selfplaying script')

    def option(flag, kind, help_text, **extra):
        # value-taking option; ``extra`` carries the default when there is one
        cli.add_argument(flag, type=kind, help=help_text, **extra)

    def switch(flag, help_text):
        # boolean flag that is off unless given
        cli.add_argument(flag, action='store_true', default=False, help=help_text)

    option('--alice_model_file', str, 'Alice model file')
    option('--alice_forward_model_file', str, 'Alice forward model file')
    option('--bob_model_file', str, 'Bob model file')
    option('--context_file', str, 'context file')
    option('--temperature', float, 'temperature', default=1.0)
    option('--pred_temperature', float, 'temperature', default=1.0)
    switch('--verbose', 'print out converations')
    option('--seed', int, 'random seed', default=1)
    option('--score_threshold', int,
           'successful dialog should have more than score_threshold in score',
           default=6)
    option('--max_turns', int, 'maximum number of turns in a dialog', default=20)
    option('--log_file', str, 'log successful dialogs to file for training',
           default='')
    switch('--smart_alice', 'make Alice smart again')
    switch('--diverse_alice', 'make Alice smart again')
    option('--rollout_bsz', int, 'rollout batch size', default=3)
    option('--rollout_count_threshold', int, 'rollout count threshold', default=3)
    switch('--smart_bob', 'make Bob smart again')
    option('--selection_model_file', str, 'path to save the final model',
           default='')
    option('--rollout_model_file', str, 'path to save the final model',
           default='')
    switch('--diverse_bob', 'make Alice smart again')
    option('--ref_text', str, 'file with the reference text')
    switch('--cuda', 'use CUDA')
    option('--domain', str, 'domain for the dialogue', default='object_division')
    switch('--visual', 'plot graphs')
    option('--eps', float, 'eps greedy', default=0.0)
    option('--data', str, 'location of the data corpus', default='data/negotiate')
    option('--unk_threshold', int, 'minimum word frequency to be in dictionary',
           default=20)
    option('--bsz', int, 'batch size', default=16)
    switch('--validate', 'plot graphs')

    args = cli.parse_args()

    utils.use_cuda(args.cuda)
    utils.set_seed(args.seed)

    def build_agent(model_path, smart, name, diverse):
        # load the model, pick its agent class, then instantiate the agent
        model = utils.load_model(model_path)
        agent_cls = get_agent_type(model, smart)
        return agent_cls(model, args, name=name, train=False, diverse=diverse)

    alice = build_agent(args.alice_model_file, args.smart_alice, 'Alice',
                        args.diverse_alice)
    alice.vis = args.visual

    bob = build_agent(args.bob_model_file, args.smart_bob, 'Bob',
                      args.diverse_bob)
    bob.vis = False

    dialog = Dialog([alice, bob], args)
    logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    SelfPlay(dialog, ctx_gen, args, logger).run()
Exemplo n.º 12
0
def main():
    """Run self-play between two saved agents, once per seed, and report scores.

    Parses the command line, then for every seed (``range(10)`` when
    ``--repeat_selfplay`` is given, otherwise just ``--seed``):

    * loads the Alice and Bob models from ``<model_file>_<seed>.th``;
    * optionally (``--record_markables``) loads the markable detector from
      ``<markable_detector_file>_<seed>.th`` and builds its corpus;
    * reads contexts from ``<data>/<context_file>.txt`` and scenarios from
      ``<data>/<context_file>.json``;
    * runs ``SelfPlay`` and collects the value returned by ``run()``.

    Afterwards it dumps ``selfplay_markables.json`` and
    ``selfplay_referents.json`` from the last ``Dialog`` built, and prints the
    mean and standard deviation of the collected results.

    Raises:
        SystemExit: via ``parser.error`` when ``--alice_model_file``,
            ``--bob_model_file`` or ``--context_file`` is not supplied.
        FileNotFoundError: when ``--record_markables`` is set and the markable
            detector checkpoint for a seed does not exist.
    """
    parser = argparse.ArgumentParser(description='selfplaying script')
    parser.add_argument('--alice_model_file',
                        type=str,
                        help='Alice model file')
    parser.add_argument('--alice_forward_model_file',
                        type=str,
                        help='Alice forward model file')
    parser.add_argument('--bob_model_file', type=str, help='Bob model file')
    parser.add_argument('--context_file', type=str, help='context file')
    parser.add_argument('--temperature',
                        type=float,
                        default=1.0,
                        help='temperature')
    parser.add_argument('--pred_temperature',
                        type=float,
                        default=1.0,
                        help='temperature')
    parser.add_argument('--log_attention',
                        action='store_true',
                        default=False,
                        help='log attention')
    parser.add_argument('--verbose',
                        action='store_true',
                        default=False,
                        help='print out converations')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--max_turns',
                        type=int,
                        default=20,
                        help='maximum number of turns in a dialog')
    parser.add_argument('--log_file',
                        type=str,
                        default='selfplay.log',
                        help='log dialogs to file')
    parser.add_argument('--smart_alice',
                        action='store_true',
                        default=False,
                        help='make Alice smart again')
    parser.add_argument('--rollout_bsz',
                        type=int,
                        default=3,
                        help='rollout batch size')
    parser.add_argument('--rollout_count_threshold',
                        type=int,
                        default=3,
                        help='rollout count threshold')
    parser.add_argument('--smart_bob',
                        action='store_true',
                        default=False,
                        help='make Bob smart again')
    parser.add_argument('--selection_model_file',
                        type=str,
                        default='',
                        help='path to save the final model')
    parser.add_argument('--rollout_model_file',
                        type=str,
                        default='',
                        help='path to save the final model')
    parser.add_argument('--ref_text',
                        type=str,
                        help='file with the reference text')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=False,
                        help='use CUDA')
    parser.add_argument('--domain',
                        type=str,
                        default='one_common',
                        help='domain for the dialogue')
    parser.add_argument('--visual',
                        action='store_true',
                        default=False,
                        help='plot graphs')
    parser.add_argument('--eps', type=float, default=0.0, help='eps greedy')
    parser.add_argument('--data',
                        type=str,
                        default='data/onecommon',
                        help='location of the data corpus')
    parser.add_argument('--unk_threshold',
                        type=int,
                        default=10,
                        help='minimum word frequency to be in dictionary')
    parser.add_argument('--bsz', type=int, default=16, help='batch size')
    parser.add_argument('--plot_metrics',
                        action='store_true',
                        default=False,
                        help='plot metrics')
    parser.add_argument('--markable_detector_file',
                        type=str,
                        default="markable_detector",
                        help='visualize referents')
    parser.add_argument('--record_markables',
                        action='store_true',
                        default=False,
                        help='record markables and referents')
    parser.add_argument('--repeat_selfplay',
                        action='store_true',
                        default=False,
                        help='repeat selfplay')

    args = parser.parse_args()

    # These options have no default and are concatenated with strings to build
    # file paths below; report their absence as a usage error instead of
    # letting `None + str` fail with a TypeError.
    path_options = ('alice_model_file', 'bob_model_file', 'context_file')
    missing = ['--' + name for name in path_options
               if getattr(args, name) is None]
    if missing:
        parser.error('missing required argument(s): ' + ', '.join(missing))

    seeds = list(range(10)) if args.repeat_selfplay else [args.seed]

    repeat_results = []

    for seed in seeds:
        utils.use_cuda(args.cuda)
        # Seed with the loop's seed so repeated runs are not all seeded with
        # the single --seed value.
        utils.set_seed(seed)

        # Every checkpoint is stored per seed as "<prefix>_<seed>.th".
        checkpoint_suffix = '_' + str(seed) + '.th'

        if args.record_markables:
            detector_path = args.markable_detector_file + checkpoint_suffix
            if not os.path.exists(detector_path):
                # Explicit exception: a bare assert is removed under -O and
                # gives no hint about which file is missing.
                raise FileNotFoundError(
                    'markable detector checkpoint not found: ' +
                    detector_path)
            markable_detector = utils.load_model(detector_path)
            if args.cuda:
                markable_detector.cuda()
            else:
                markable_detector.to(torch.device("cpu"))
            markable_detector.eval()
            # NOTE(review): `domain` is not assigned anywhere in this
            # function, so it must resolve to a module-level name; confirm,
            # otherwise this call raises NameError.
            markable_detector_corpus = markable_detector.corpus_ty(
                domain,
                args.data,
                train='train_markable_{}.txt'.format(seed),
                valid='valid_markable_{}.txt'.format(seed),
                test='test_markable_{}.txt'.format(seed),
                freq_cutoff=args.unk_threshold,
                verbose=True)
        else:
            markable_detector = None
            markable_detector_corpus = None

        alice_model = utils.load_model(args.alice_model_file +
                                       checkpoint_suffix)
        alice_ty = get_agent_type(alice_model, args.smart_alice)
        alice = alice_ty(alice_model, args, name='Alice', train=False)

        bob_model = utils.load_model(args.bob_model_file + checkpoint_suffix)
        bob_ty = get_agent_type(bob_model, args.smart_bob)
        bob = bob_ty(bob_model, args, name='Bob', train=False)

        dialog = Dialog([alice, bob], args, markable_detector,
                        markable_detector_corpus)
        ctx_gen = ContextGenerator(
            os.path.join(args.data, args.context_file + '.txt'))
        with open(os.path.join(args.data, args.context_file + '.json'),
                  "r") as f:
            scenario_list = json.load(f)
        # Index scenarios by uuid for the logger.
        scenarios = {scenario['uuid']: scenario for scenario in scenario_list}
        logger = DialogLogger(verbose=args.verbose,
                              log_file=args.log_file,
                              scenarios=scenarios)

        selfplay = SelfPlay(dialog, ctx_gen, args, logger)
        result = selfplay.run()
        repeat_results.append(result)

    # `dialog` here is the one built in the final loop iteration.
    print("dump selfplay_markables.json")
    dump_json(dialog.selfplay_markables, "selfplay_markables.json")
    print("dump selfplay_referents.json")
    dump_json(dialog.selfplay_referents, "selfplay_referents.json")

    print("repeat selfplay results %.8f ( %.8f )" %
          (np.mean(repeat_results), np.std(repeat_results)))