Beispiel #1
0
def main():
    """Run self-play dialogues between two loaded models, Alice and Bob.

    Parses the command line, seeds the RNGs with ``--seed``, loads both
    models, wraps each one in the agent class returned by
    ``get_agent_type`` and hands a ``Dialog`` plus a ``ContextGenerator``
    built from ``--context_file`` to ``SelfPlay.run()``.
    """
    parser = argparse.ArgumentParser(description='selfplaying script')
    parser.add_argument('--alice_model_file', type=str,
        help='Alice model file')
    parser.add_argument('--bob_model_file', type=str,
        help='Bob model file')
    parser.add_argument('--context_file', type=str,
        help='context file')
    parser.add_argument('--temperature', type=float, default=1.0,
        help='temperature')
    parser.add_argument('--verbose', action='store_true', default=False,
        help='print out conversations')
    parser.add_argument('--seed', type=int, default=1,
        help='random seed')
    parser.add_argument('--score_threshold', type=int, default=6,
        help='successful dialog should have more than score_threshold in score')
    parser.add_argument('--max_turns', type=int, default=20,
        help='maximum number of turns in a dialog')
    parser.add_argument('--log_file', type=str, default='',
        help='log successful dialogs to file for training')
    parser.add_argument('--smart_alice', action='store_true', default=False,
        help='make Alice smart again')
    parser.add_argument('--fast_rollout', action='store_true', default=False,
        help='to use faster rollouts')
    parser.add_argument('--rollout_bsz', type=int, default=100,
        help='rollout batch size')
    parser.add_argument('--rollout_count_threshold', type=int, default=3,
        help='rollout count threshold')
    parser.add_argument('--smart_bob', action='store_true', default=False,
        help='make Bob smart again')
    parser.add_argument('--ref_text', type=str,
        help='file with the reference text')
    parser.add_argument('--domain', type=str, default='object_division',
        help='domain for the dialogue')
    parser.add_argument('--fixed_bob', action='store_true', default=False,
        help='use a fixed Bob agent')
    args = parser.parse_args()

    utils.set_seed(args.seed)

    alice_model = utils.load_model(args.alice_model_file)
    alice_ty = get_agent_type(alice_model, args.smart_alice, args.fast_rollout)
    alice = alice_ty(alice_model, args, name='Alice')
    # NOTE(review): presumably a torch nn.Module; train() keeps dropout active
    # during self-play, whereas the RL scripts call eval() to disable it.
    # Confirm this is intended.
    alice_model.train()

    bob_model = utils.load_model(args.bob_model_file)
    bob_ty = get_agent_type(bob_model, args.smart_bob, args.fast_rollout)
    bob = bob_ty(bob_model, args, name='Bob')
    bob_model.train()
    dialog = Dialog([alice, bob], args)
    logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    selfplay = SelfPlay(dialog, ctx_gen, args, logger)
    selfplay.run()
Beispiel #2
0
def main():
    """Chat interactively with a trained model from the terminal.

    A ``HumanAgent`` for ``--domain`` is paired with a model agent built
    from ``--model_file``. ``--smart_ai`` picks ``LstmRolloutAgent`` over
    ``LstmAgent``, and ``--ai_starts`` puts the model first in the agent
    list. With an empty ``--context_file`` the contexts are produced by a
    ``ManualContextGenerator``; otherwise they are read from that file.
    """
    parser = argparse.ArgumentParser(description='chat utility')
    parser.add_argument('--model_file',
                        type=str,
                        help='model file')
    parser.add_argument('--domain',
                        type=str,
                        default='object_division',
                        help='domain for the dialogue')
    parser.add_argument('--context_file',
                        type=str,
                        default='',
                        help='context file')
    parser.add_argument('--temperature',
                        type=float,
                        default=1.0,
                        help='temperature')
    parser.add_argument('--num_types',
                        type=int,
                        default=3,
                        help='number of object types')
    parser.add_argument('--num_objects',
                        type=int,
                        default=6,
                        help='total number of objects')
    parser.add_argument('--max_score',
                        type=int,
                        default=10,
                        help='max score per object')
    parser.add_argument(
        '--score_threshold',
        type=int,
        default=6,
        help='successful dialog should have more than score_threshold in score'
    )
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--smart_ai',
                        action='store_true',
                        default=False,
                        help='make AI smart again')
    parser.add_argument('--ai_starts',
                        action='store_true',
                        default=False,
                        help='allow AI to start the dialog')
    parser.add_argument('--ref_text',
                        type=str,
                        help='file with the reference text')
    opts = parser.parse_args()

    utils.set_seed(opts.seed)

    human = HumanAgent(domain.get_domain(opts.domain))

    if opts.smart_ai:
        agent_cls = LstmRolloutAgent
    else:
        agent_cls = LstmAgent
    ai = agent_cls(utils.load_model(opts.model_file), opts)

    # --ai_starts places the model agent first in the list given to Dialog.
    if opts.ai_starts:
        agents = [ai, human]
    else:
        agents = [human, ai]

    dialog = Dialog(agents, opts)
    logger = DialogLogger(verbose=True)
    if opts.context_file != '':
        ctx_gen = ContextGenerator(opts.context_file)
    else:
        ctx_gen = ManualContextGenerator(opts.num_types, opts.num_objects,
                                         opts.max_score)

    Chat(dialog, ctx_gen, logger).run()
def main():
    """Fine-tune Alice with REINFORCE self-play against Bob, then save her.

    Alice is an ``RlAgent`` and Bob an ``LstmAgent`` (or ``LstmRolloutAgent``
    with ``--smart_bob``). Both models are put in eval mode. A ``WordCorpus``
    over ``--data`` and an ``Engine`` around Alice's model are passed to
    ``Reinforce.run()``. Alice's model is then written to
    ``--output_model_file``.
    """
    parser = argparse.ArgumentParser(description='Reinforce')
    parser.add_argument('--data',
                        type=str,
                        default='./data/negotiate',
                        help='location of the data corpus')
    parser.add_argument('--unk_threshold',
                        type=int,
                        default=20,
                        help='minimum word frequency to be in dictionary')
    parser.add_argument('--alice_model_file',
                        type=str,
                        help='Alice model file')
    parser.add_argument('--bob_model_file', type=str, help='Bob model file')
    parser.add_argument('--output_model_file',
                        type=str,
                        help='output model file')
    parser.add_argument('--context_file', type=str, help='context file')
    parser.add_argument('--temperature',
                        type=float,
                        default=1.0,
                        help='temperature')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=False,
                        help='use CUDA')
    parser.add_argument('--verbose',
                        action='store_true',
                        default=False,
                        help='print out conversations')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument(
        '--score_threshold',
        type=int,
        default=6,
        help='successful dialog should have more than score_threshold in score'
    )
    parser.add_argument('--log_file',
                        type=str,
                        default='',
                        help='log successful dialogs to file for training')
    parser.add_argument('--smart_bob',
                        action='store_true',
                        default=False,
                        help='make Bob smart again')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.99,
                        help='discount factor')
    parser.add_argument('--eps', type=float, default=0.5, help='eps greedy')
    parser.add_argument('--nesterov',
                        action='store_true',
                        default=False,
                        help='enable nesterov momentum')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.0,
                        help='momentum for sgd')
    parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
    parser.add_argument('--clip',
                        type=float,
                        default=0.1,
                        help='gradient clip')
    parser.add_argument('--rl_lr',
                        type=float,
                        default=0.1,
                        help='RL learning rate')
    parser.add_argument('--rl_clip',
                        type=float,
                        default=0.1,
                        help='RL gradient clip')
    parser.add_argument('--ref_text',
                        type=str,
                        help='file with the reference text')
    parser.add_argument('--bsz', type=int, default=8, help='batch size')
    parser.add_argument('--sv_train_freq',
                        type=int,
                        default=-1,
                        help='supervision train frequency')
    parser.add_argument('--nepoch',
                        type=int,
                        default=1,
                        help='number of epochs')
    parser.add_argument('--visual',
                        action='store_true',
                        default=False,
                        help='plot graphs')
    parser.add_argument('--domain',
                        type=str,
                        default='object_division',
                        help='domain for the dialogue')
    args = parser.parse_args()

    device_id = utils.use_cuda(args.cuda)
    utils.set_seed(args.seed)

    alice_model = utils.load_model(args.alice_model_file)
    # We don't want to use Dropout during RL
    alice_model.eval()
    alice = RlAgent(alice_model, args, name='Alice')

    bob_ty = LstmRolloutAgent if args.smart_bob else LstmAgent
    bob_model = utils.load_model(args.bob_model_file)
    bob_model.eval()
    bob = bob_ty(bob_model, args, name='Bob')

    dialog = Dialog([alice, bob], args)
    logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    corpus = data.WordCorpus(args.data, freq_cutoff=args.unk_threshold)
    engine = Engine(alice_model, args, device_id, verbose=False)

    reinforce = Reinforce(dialog, ctx_gen, args, engine, corpus, logger)
    reinforce.run()

    utils.save_model(alice.model, args.output_model_file)
def main():
    """Fine-tune Alice with REINFORCE self-play against Bob, then save her.

    Agent classes come from ``get_agent_type``. Alice is created with
    ``train=True`` and Bob with ``train=False``. The corpus and engine types
    are taken from Alice's model (``corpus_ty`` / ``engine_ty``) and passed,
    with the dialog and context generator, to ``Reinforce.run()``. Alice's
    model is then written to ``--output_model_file``.
    """
    parser = argparse.ArgumentParser(description='Reinforce')
    parser.add_argument('--alice_model_file',
                        type=str,
                        help='Alice model file')
    parser.add_argument('--bob_model_file', type=str, help='Bob model file')
    parser.add_argument('--output_model_file',
                        type=str,
                        help='output model file')
    parser.add_argument('--context_file', type=str, help='context file')
    parser.add_argument('--temperature',
                        type=float,
                        default=1.0,
                        help='temperature')
    parser.add_argument('--pred_temperature',
                        type=float,
                        default=1.0,
                        help='prediction temperature')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=False,
                        help='use CUDA')
    parser.add_argument('--verbose',
                        action='store_true',
                        default=False,
                        help='print out conversations')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument(
        '--score_threshold',
        type=int,
        default=6,
        help='successful dialog should have more than score_threshold in score'
    )
    parser.add_argument('--log_file',
                        type=str,
                        default='',
                        help='log successful dialogs to file for training')
    parser.add_argument('--smart_bob',
                        action='store_true',
                        default=False,
                        help='make Bob smart again')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.99,
                        help='discount factor')
    parser.add_argument('--eps', type=float, default=0.5, help='eps greedy')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.1,
                        help='momentum for sgd')
    parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
    parser.add_argument('--clip',
                        type=float,
                        default=0.1,
                        help='gradient clip')
    parser.add_argument('--rl_lr',
                        type=float,
                        default=0.002,
                        help='RL learning rate')
    parser.add_argument('--rl_clip',
                        type=float,
                        default=2.0,
                        help='RL gradient clip')
    parser.add_argument('--ref_text',
                        type=str,
                        help='file with the reference text')
    parser.add_argument('--sv_train_freq',
                        type=int,
                        default=-1,
                        help='supervision train frequency')
    parser.add_argument('--nepoch',
                        type=int,
                        default=1,
                        help='number of epochs')
    parser.add_argument('--hierarchical',
                        action='store_true',
                        default=False,
                        help='use hierarchical training')
    parser.add_argument('--visual',
                        action='store_true',
                        default=False,
                        help='plot graphs')
    parser.add_argument('--domain',
                        type=str,
                        default='object_division',
                        help='domain for the dialogue')
    parser.add_argument('--selection_model_file',
                        type=str,
                        default='',
                        help='path to save the final model')
    parser.add_argument('--data',
                        type=str,
                        default='data/negotiate',
                        help='location of the data corpus')
    parser.add_argument('--unk_threshold',
                        type=int,
                        default=20,
                        help='minimum word frequency to be in dictionary')
    parser.add_argument('--bsz', type=int, default=16, help='batch size')
    parser.add_argument('--validate',
                        action='store_true',
                        default=False,
                        help='enable validation')
    parser.add_argument('--scratch',
                        action='store_true',
                        default=False,
                        help='erase prediction weights')
    parser.add_argument('--sep_sel',
                        action='store_true',
                        default=False,
                        help='use separate classifiers for selection')

    args = parser.parse_args()

    utils.use_cuda(args.cuda)
    utils.set_seed(args.seed)

    alice_model = utils.load_model(args.alice_model_file)  # RnnModel
    alice_ty = get_agent_type(alice_model)  # RnnRolloutAgent
    alice = alice_ty(alice_model, args, name='Alice', train=True)
    alice.vis = args.visual

    bob_model = utils.load_model(args.bob_model_file)  # RnnModel
    bob_ty = get_agent_type(bob_model)  # RnnAgent
    bob = bob_ty(bob_model, args, name='Bob', train=False)

    dialog = Dialog([alice, bob], args)
    logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    domain = get_domain(args.domain)
    corpus = alice_model.corpus_ty(domain,
                                   args.data,
                                   freq_cutoff=args.unk_threshold,
                                   verbose=True,
                                   sep_sel=args.sep_sel)
    engine = alice_model.engine_ty(alice_model, args)

    reinforce = Reinforce(dialog, ctx_gen, args, engine, corpus, logger)
    reinforce.run()

    utils.save_model(alice.model, args.output_model_file)
Beispiel #5
0
# Script fragment: restore an RL model and a supervised ("SV") model into one
# TensorFlow session, then start playing them against each other on contexts
# read from the self-play context file.
# NOTE(review): `corpus`, `rl_model`, `total_games` and `USE_ROLLOUTS` are not
# defined in this fragment; they must come from earlier in the script.
sv_model = SVModel(corpus, 'SV')

with tf.Session() as sess:
    # Initialise every global variable first; the two restores below then
    # overwrite the "RL"- and "SV"-collected variables with checkpoint values.
    sess.run(tf.global_variables_initializer())

    # Restore only the global variables collected under the "RL" scope.
    var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "RL")
    saver = tf.train.Saver(var_list=var_list)
    saver.restore(sess, 'model/rl-saved4-99-server/model-3')
    print('rl restored')

    # Restore only the global variables collected under the "SV" scope
    # (presumably the scope name passed to SVModel above — confirm).
    var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "SV")
    saver = tf.train.Saver(var_list=var_list)
    saver.restore(sess, 'model/sv-named99/sv-named-99')
    print('sv restored')

    ctx_gen = ContextGenerator(
        'end-to-end-negotiator/src/data/negotiate/selfplay.txt')
    ctxs_all = ctx_gen.get_ctxs()

    rl_agent = RlAgent(sess, rl_model, use_rollouts=USE_ROLLOUTS)
    sv_agent = SVAgent(sess, sv_model)

    # Presumably per-game rewards; they are filled in beyond this fragment.
    sv_rewards, rl_rewards = [], []
    for counter in range(total_games):
        # Pick one context entry at random (with replacement) for this game.
        ctxs = random.choice(ctxs_all)
        print_ctxs(ctxs)
        # Split the context into item counts and private values for each agent.
        count_sv, count_rl, val_sv, val_rl = get_counts_vals(corpus, ctxs)

        rl_agent.feed_context(count_rl, val_rl)
        sv_agent.feed_context(count_sv, val_sv)

        # Fair coin flip; presumably decides whether the SV agent speaks first
        # (used beyond this fragment).
        sv_first = np.random.randint(2) == 0
Beispiel #6
0
def main():
    """Fine-tune Alice with REINFORCE self-play against a frozen Bob.

    Option defaults come from the ``config`` module. Alice is an ``RlAgent``
    that learns while self-playing. Bob is an ``LstmAgent`` (or
    ``LstmRolloutAgent`` with ``--smart_bob``) whose parameters are not
    updated. ``Reinforce.run()`` drives training, and the updated Alice model
    is saved to ``--output_model_file``.
    """
    parser = argparse.ArgumentParser(description='Reinforce')
    parser.add_argument('--data',
                        type=str,
                        default=config.data_dir,
                        help='location of the data corpus')
    parser.add_argument('--unk_threshold',
                        type=int,
                        default=config.unk_threshold,
                        help='minimum word frequency to be in dictionary')
    parser.add_argument('--alice_model_file',
                        type=str,
                        help='Alice model file')
    parser.add_argument('--bob_model_file', type=str, help='Bob model file')
    parser.add_argument('--output_model_file',
                        type=str,
                        help='output model file')
    parser.add_argument('--context_file', type=str, help='context file')
    parser.add_argument('--temperature',
                        type=float,
                        default=config.rl_temperature,
                        help='temperature')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=config.cuda,
                        help='use CUDA')
    parser.add_argument('--verbose',
                        action='store_true',
                        default=config.verbose,
                        help='print out conversations')
    parser.add_argument('--seed',
                        type=int,
                        default=config.seed,
                        help='random seed')
    parser.add_argument(
        '--score_threshold',
        type=int,
        default=config.rl_score_threshold,
        help='successful dialog should have more than score_threshold in score'
    )
    parser.add_argument('--log_file',
                        type=str,
                        default='',
                        help='log successful dialogs to file for training')
    parser.add_argument('--smart_bob',
                        action='store_true',
                        default=False,
                        help='make Bob smart again')
    parser.add_argument('--gamma',
                        type=float,
                        default=config.rl_gamma,
                        help='discount factor')
    parser.add_argument('--eps',
                        type=float,
                        default=config.rl_eps,
                        help='eps greedy')
    parser.add_argument('--nesterov',
                        action='store_true',
                        default=config.nesterov,
                        help='enable nesterov momentum')
    parser.add_argument('--momentum',
                        type=float,
                        default=config.rl_momentum,
                        help='momentum for sgd')
    parser.add_argument('--lr',
                        type=float,
                        default=config.rl_lr,
                        help='learning rate')
    parser.add_argument('--clip',
                        type=float,
                        default=config.rl_clip,
                        help='gradient clip')
    parser.add_argument('--rl_lr',
                        type=float,
                        default=config.rl_reinforcement_lr,
                        help='RL learning rate')
    parser.add_argument('--rl_clip',
                        type=float,
                        default=config.rl_reinforcement_clip,
                        help='RL gradient clip')
    parser.add_argument('--ref_text',
                        type=str,
                        help='file with the reference text')
    parser.add_argument('--bsz',
                        type=int,
                        default=config.rl_bsz,
                        help='batch size')
    parser.add_argument('--sv_train_freq',
                        type=int,
                        default=config.rl_sv_train_freq,
                        help='supervision train frequency')
    parser.add_argument('--nepoch',
                        type=int,
                        default=config.rl_nepoch,
                        help='number of epochs')
    parser.add_argument('--visual',
                        action='store_true',
                        default=config.plot_graphs,
                        help='plot graphs')
    parser.add_argument('--domain',
                        type=str,
                        default=config.domain,
                        help='domain for the dialogue')
    parser.add_argument('--reward',
                        type=str,
                        choices=['margin', 'fair', 'length'],
                        default='margin',
                        help='reward function')
    args = parser.parse_args()

    device_id = utils.use_cuda(args.cuda)
    # Seed the RNGs so that --seed actually takes effect and runs are
    # reproducible (the sibling REINFORCE scripts do the same).
    utils.set_seed(args.seed)
    # Lazy logging arguments: the message is only formatted if emitted.
    logging.info("Starting training using pytorch version:%s",
                 torch.__version__)
    if args.cuda:
        logging.info("CUDA is enabled. Using device_id:%s version:%s on gpu:%s",
                     device_id, torch.version.cuda,
                     torch.cuda.get_device_name(0))
    else:
        logging.info("CUDA is disabled")

    alice_model = utils.load_model(args.alice_model_file)
    # we don't want to use Dropout during RL
    alice_model.eval()
    # Alice is a RL based agent, meaning that she will be learning while selfplaying
    logging.info("Creating RlAgent from alice_model: %s",
                 args.alice_model_file)
    alice = RlAgent(alice_model, args, name='Alice')

    # we keep Bob frozen, i.e. we don't update his parameters
    logging.info("Creating Bob's (--smart_bob) LstmRolloutAgent" if args.smart_bob
                 else "Creating Bob's (not --smart_bob) LstmAgent")
    bob_ty = LstmRolloutAgent if args.smart_bob else LstmAgent
    bob_model = utils.load_model(args.bob_model_file)
    bob_model.eval()
    bob = bob_ty(bob_model, args, name='Bob')

    logging.info("Initializing communication dialogue between Alice and Bob")
    dialog = Dialog([alice, bob], args)
    logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    logging.info(
        "Building word corpus, requiring minimum word frequency of %d for dictionary",
        args.unk_threshold)
    corpus = data.WordCorpus(args.data, freq_cutoff=args.unk_threshold)
    engine = Engine(alice_model, args, device_id, verbose=False)

    logging.info("Starting Reinforcement Learning")
    reinforce = Reinforce(dialog, ctx_gen, args, engine, corpus, logger)
    reinforce.run()

    logging.info("Saving updated Alice model to %s", args.output_model_file)
    utils.save_model(alice.model, args.output_model_file)
Beispiel #7
0
def main():
    """Run self-play between Alice and Bob with optional "diverse" agents.

    Both models are loaded and wrapped in the agent class chosen by
    ``get_agent_type``, constructed with ``train=False`` and with
    ``diverse`` taken from ``--diverse_alice`` / ``--diverse_bob``.
    ``--visual`` controls Alice's ``vis`` flag; Bob's is always off.
    ``SelfPlay.run()`` then plays over contexts from ``--context_file``.
    """
    parser = argparse.ArgumentParser(description='selfplaying script')
    parser.add_argument('--alice_model_file',
                        type=str,
                        help='Alice model file')
    parser.add_argument('--alice_forward_model_file',
                        type=str,
                        help='Alice forward model file')
    parser.add_argument('--bob_model_file', type=str, help='Bob model file')
    parser.add_argument('--context_file', type=str, help='context file')
    parser.add_argument('--temperature',
                        type=float,
                        default=1.0,
                        help='temperature')
    parser.add_argument('--pred_temperature',
                        type=float,
                        default=1.0,
                        help='prediction temperature')
    parser.add_argument('--verbose',
                        action='store_true',
                        default=False,
                        help='print out conversations')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument(
        '--score_threshold',
        type=int,
        default=6,
        help='successful dialog should have more than score_threshold in score'
    )
    parser.add_argument('--max_turns',
                        type=int,
                        default=20,
                        help='maximum number of turns in a dialog')
    parser.add_argument('--log_file',
                        type=str,
                        default='',
                        help='log successful dialogs to file for training')
    parser.add_argument('--smart_alice',
                        action='store_true',
                        default=False,
                        help='make Alice smart again')
    parser.add_argument('--diverse_alice',
                        action='store_true',
                        default=False,
                        help='make Alice diverse')
    parser.add_argument('--rollout_bsz',
                        type=int,
                        default=3,
                        help='rollout batch size')
    parser.add_argument('--rollout_count_threshold',
                        type=int,
                        default=3,
                        help='rollout count threshold')
    parser.add_argument('--smart_bob',
                        action='store_true',
                        default=False,
                        help='make Bob smart again')
    parser.add_argument('--selection_model_file',
                        type=str,
                        default='',
                        help='path to save the final model')
    parser.add_argument('--rollout_model_file',
                        type=str,
                        default='',
                        help='path to save the final model')
    parser.add_argument('--diverse_bob',
                        action='store_true',
                        default=False,
                        help='make Bob diverse')
    parser.add_argument('--ref_text',
                        type=str,
                        help='file with the reference text')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=False,
                        help='use CUDA')
    parser.add_argument('--domain',
                        type=str,
                        default='object_division',
                        help='domain for the dialogue')
    parser.add_argument('--visual',
                        action='store_true',
                        default=False,
                        help='plot graphs')
    parser.add_argument('--eps', type=float, default=0.0, help='eps greedy')
    parser.add_argument('--data',
                        type=str,
                        default='data/negotiate',
                        help='location of the data corpus')
    parser.add_argument('--unk_threshold',
                        type=int,
                        default=20,
                        help='minimum word frequency to be in dictionary')
    parser.add_argument('--bsz', type=int, default=16, help='batch size')
    parser.add_argument('--validate',
                        action='store_true',
                        default=False,
                        help='enable validation')

    args = parser.parse_args()

    utils.use_cuda(args.cuda)
    utils.set_seed(args.seed)

    alice_model = utils.load_model(args.alice_model_file)
    alice_ty = get_agent_type(alice_model, args.smart_alice)
    alice = alice_ty(alice_model,
                     args,
                     name='Alice',
                     train=False,
                     diverse=args.diverse_alice)
    alice.vis = args.visual

    bob_model = utils.load_model(args.bob_model_file)
    bob_ty = get_agent_type(bob_model, args.smart_bob)
    bob = bob_ty(bob_model,
                 args,
                 name='Bob',
                 train=False,
                 diverse=args.diverse_bob)

    bob.vis = False

    dialog = Dialog([alice, bob], args)
    logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    selfplay = SelfPlay(dialog, ctx_gen, args, logger)
    selfplay.run()
Beispiel #8
0
def main():
    """Run self-play, optionally repeated over ten model seeds.

    With ``--repeat_selfplay`` the loop runs for seeds 0-9, otherwise once
    for ``--seed``. Each iteration loads ``<alice_model_file>_<seed>.th`` and
    ``<bob_model_file>_<seed>.th``. With ``--record_markables`` it also loads
    ``<markable_detector_file>_<seed>.th`` and its per-seed markable corpus.
    It then reads ``<context_file>.txt`` / ``<context_file>.json`` from
    ``--data``, runs ``SelfPlay.run()`` and collects the returned result.
    Afterwards the markables/referents of the last ``Dialog`` are dumped to
    JSON, and the mean and standard deviation of the results are printed.

    Raises:
        FileNotFoundError: if ``--record_markables`` is set and the markable
            detector file for a seed does not exist.
    """
    parser = argparse.ArgumentParser(description='selfplaying script')
    parser.add_argument('--alice_model_file',
                        type=str,
                        help='Alice model file')
    parser.add_argument('--alice_forward_model_file',
                        type=str,
                        help='Alice forward model file')
    parser.add_argument('--bob_model_file', type=str, help='Bob model file')
    parser.add_argument('--context_file', type=str, help='context file')
    parser.add_argument('--temperature',
                        type=float,
                        default=1.0,
                        help='temperature')
    parser.add_argument('--pred_temperature',
                        type=float,
                        default=1.0,
                        help='prediction temperature')
    parser.add_argument('--log_attention',
                        action='store_true',
                        default=False,
                        help='log attention')
    parser.add_argument('--verbose',
                        action='store_true',
                        default=False,
                        help='print out conversations')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--max_turns',
                        type=int,
                        default=20,
                        help='maximum number of turns in a dialog')
    parser.add_argument('--log_file',
                        type=str,
                        default='selfplay.log',
                        help='log dialogs to file')
    parser.add_argument('--smart_alice',
                        action='store_true',
                        default=False,
                        help='make Alice smart again')
    parser.add_argument('--rollout_bsz',
                        type=int,
                        default=3,
                        help='rollout batch size')
    parser.add_argument('--rollout_count_threshold',
                        type=int,
                        default=3,
                        help='rollout count threshold')
    parser.add_argument('--smart_bob',
                        action='store_true',
                        default=False,
                        help='make Bob smart again')
    parser.add_argument('--selection_model_file',
                        type=str,
                        default='',
                        help='path to save the final model')
    parser.add_argument('--rollout_model_file',
                        type=str,
                        default='',
                        help='path to save the final model')
    parser.add_argument('--ref_text',
                        type=str,
                        help='file with the reference text')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=False,
                        help='use CUDA')
    parser.add_argument('--domain',
                        type=str,
                        default='one_common',
                        help='domain for the dialogue')
    parser.add_argument('--visual',
                        action='store_true',
                        default=False,
                        help='plot graphs')
    parser.add_argument('--eps', type=float, default=0.0, help='eps greedy')
    parser.add_argument('--data',
                        type=str,
                        default='data/onecommon',
                        help='location of the data corpus')
    parser.add_argument('--unk_threshold',
                        type=int,
                        default=10,
                        help='minimum word frequency to be in dictionary')
    parser.add_argument('--bsz', type=int, default=16, help='batch size')
    parser.add_argument('--plot_metrics',
                        action='store_true',
                        default=False,
                        help='plot metrics')
    parser.add_argument('--markable_detector_file',
                        type=str,
                        default="markable_detector",
                        help='prefix of the markable detector model files '
                        '(loaded as <prefix>_<seed>.th)')
    parser.add_argument('--record_markables',
                        action='store_true',
                        default=False,
                        help='record markables and referents')
    parser.add_argument('--repeat_selfplay',
                        action='store_true',
                        default=False,
                        help='repeat selfplay')

    args = parser.parse_args()

    if args.repeat_selfplay:
        seeds = list(range(10))
    else:
        seeds = [args.seed]

    repeat_results = []

    for seed in seeds:
        utils.use_cuda(args.cuda)
        # NOTE(review): this reseeds with args.seed rather than the loop's
        # `seed`, so every repetition uses the same RNG seed and only the
        # loaded per-seed models differ. Confirm that is intended.
        utils.set_seed(args.seed)

        if args.record_markables:
            detector_file = (args.markable_detector_file + '_' + str(seed) +
                             '.th')
            # Raise instead of `assert False`: asserts are stripped under -O
            # and would not say which file is missing.
            if not os.path.exists(detector_file):
                raise FileNotFoundError(
                    'markable detector file not found: {}'.format(
                        detector_file))
            markable_detector = utils.load_model(detector_file)
            if args.cuda:
                markable_detector.cuda()
            else:
                device = torch.device("cpu")
                markable_detector.to(device)
            markable_detector.eval()
            # NOTE(review): `domain` is never assigned in this function, so it
            # resolves to a module-level name (if one exists). The REINFORCE
            # script passes get_domain(args.domain) to corpus_ty; verify.
            markable_detector_corpus = markable_detector.corpus_ty(
                domain,
                args.data,
                train='train_markable_{}.txt'.format(seed),
                valid='valid_markable_{}.txt'.format(seed),
                # alternative: test='selfplay_reference_{}.txt'.format(seed)
                test='test_markable_{}.txt'.format(seed),
                freq_cutoff=args.unk_threshold,
                verbose=True)
        else:
            markable_detector = None
            markable_detector_corpus = None

        alice_model = utils.load_model(args.alice_model_file + '_' +
                                       str(seed) + '.th')
        alice_ty = get_agent_type(alice_model, args.smart_alice)
        alice = alice_ty(alice_model, args, name='Alice', train=False)

        bob_model = utils.load_model(args.bob_model_file + '_' + str(seed) +
                                     '.th')
        bob_ty = get_agent_type(bob_model, args.smart_bob)
        bob = bob_ty(bob_model, args, name='Bob', train=False)

        dialog = Dialog([alice, bob], args, markable_detector,
                        markable_detector_corpus)
        ctx_gen = ContextGenerator(
            os.path.join(args.data, args.context_file + '.txt'))
        with open(os.path.join(args.data, args.context_file + '.json'),
                  "r") as f:
            scenario_list = json.load(f)
        scenarios = {scenario['uuid']: scenario for scenario in scenario_list}
        logger = DialogLogger(verbose=args.verbose,
                              log_file=args.log_file,
                              scenarios=scenarios)

        selfplay = SelfPlay(dialog, ctx_gen, args, logger)
        result = selfplay.run()
        repeat_results.append(result)

    # NOTE(review): `dialog` is rebuilt every iteration, so only the last
    # seed's Dialog is dumped here (unless Dialog shares state across
    # instances). Confirm this is intended for --repeat_selfplay.
    print("dump selfplay_markables.json")
    dump_json(dialog.selfplay_markables, "selfplay_markables.json")
    print("dump selfplay_referents.json")
    dump_json(dialog.selfplay_referents, "selfplay_referents.json")

    print("repeat selfplay results %.8f ( %.8f )" %
          (np.mean(repeat_results), np.std(repeat_results)))