コード例 #1
0
ファイル: self_chat.py プロジェクト: sawravchy/ParlAI
def self_chat(opt):
    """
    Run self-chat episodes between two agents and save the conversations.

    The first agent is built from ``opt``. The second one is a clone of the
    first unless ``opt['partner_model_file']`` is not ``None``; in that case it
    is created from that model file, with option overrides read from the JSON
    file named by ``opt['partner_opt_file']`` when one is given.

    :param opt:
        options mapping. This function reads ``seed``, ``partner_model_file``,
        ``partner_opt_file``, ``interactive_mode``, ``num_self_chats``,
        ``outfile`` and ``save_format``.

    :return:
        the value of ``logger.get_logs()`` after all episodes have run.
    """
    random.seed(opt['seed'])
    partner = opt['partner_model_file']
    partner_opt_file = opt.get('partner_opt_file')

    # Build the two participants.
    agent1 = create_agent(opt, requireModelExists=True)
    if partner is not None:
        # A distinct partner model was requested; start from empty overrides
        # and replace them with the JSON file's contents when one is supplied.
        partner_opt = {}
        if partner_opt_file:
            print(f"WARNING: Loading override opts from: {partner_opt_file}")
            with open(partner_opt_file) as f:
                partner_opt = json.load(f)
        partner_opt['interactive_mode'] = opt.get('interactive_mode', True)
        print(
            f"WARNING: Setting partner interactive mode to: {partner_opt['interactive_mode']}"
        )
        agent2 = create_agent_from_model_file(partner, partner_opt)
    else:
        # No partner given: the model talks to a clone of itself.
        agent2 = agent1.clone()

    # Suffix the ids so the two speakers can be told apart, then derive the
    # combined id used for the default output file name.
    agent1.id += "_1"
    agent2.id += "_2"
    model_id = agent1.id + "_" + agent2.id

    world = create_task(opt, user_agents=[agent1, agent2])

    # World logging and progress reporting.
    logger = WorldLogger(opt)
    log_time = TimeLogger()

    # Run the requested number of self-chat episodes (counter is 1-based so it
    # can be handed to the time logger as "episodes done so far").
    for done in range(1, opt['num_self_chats'] + 1):
        _run_self_chat_episode(opt, world, logger)
        episode_report = world.report()
        text, _ = log_time.log(done, opt['num_self_chats'], episode_report)
        logging.info(text)

    # Pick the output location, falling back to a /tmp path built from the ids.
    outfile = opt['outfile']
    if outfile is None:
        outfile = '/tmp/{}_selfchat'.format(model_id)

    world_writes = opt['save_format'] == 'conversations' and hasattr(world, 'write')
    if not world_writes:
        # Default path: let the logger serialise the conversations.
        logger.write(outfile, world, opt['save_format'])
    else:
        # The world provides its own writer; use it so it can log extra
        # contextual information (like personas).
        world.write(logger, outfile)

    return logger.get_logs()
コード例 #2
0
def self_chat(opt, print_parser=None):
    """
    Run self-chat between a model and a clone of itself and write the logs.

    :param opt:
        options mapping, or (deprecated) a ``ParlaiParser`` whose
        ``parse_args()`` result is used instead. Keys read here: ``seed``,
        ``task``, ``log_every_n_secs``, ``num_examples``, ``batchsize``,
        ``display_examples``, ``outfile`` and ``format``.
    :param print_parser:
        ``None``/``False`` to print nothing; ``True`` to print arguments using
        the parser passed as ``opt`` (ignored when ``opt`` is not a parser); or
        an object exposing ``opt`` and ``print_args()`` to print with.

    :return:
        the value of ``logger.get_logs()`` after the run.
    """
    opt_is_parser = isinstance(opt, ParlaiParser)
    if print_parser is True:
        # ``True`` means "use the parser given as opt". If opt is not a parser
        # there is nothing to print with; fall back to None instead of leaving
        # the bool in place, which would fail later on ``print_parser.opt``.
        print_parser = opt if opt_is_parser else None
    elif print_parser is False:
        print_parser = None
    if opt_is_parser:
        print('[ Deprecated Warning: self_chat should be passed opt not Parser ]')
        opt = opt.parse_args()

    random.seed(opt['seed'])
    # Create models: the second speaker is a clone of the first.
    agent1 = create_agent(opt, requireModelExists=True)
    agent2 = agent1.clone()
    if hasattr(agent2, 'id'):
        agent2.id = agent2.id + "2"

    # Check for `selfchat` in the task name
    if 'selfchat' not in opt['task']:
        warn_once(
            'You are using self chat with task {}. '.format(opt['task'])
            + 'If your task has an existing self chat world, then run with '
            '-t {}:selfchat'.format(opt['task'])
        )

    world = create_task(opt, [agent1, agent2])

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent1.opt
        print_parser.print_args()

    # Set up logging; a non-positive interval disables periodic progress output.
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()
    logger = WorldLogger(opt)

    # Run some self chats; each parley is counted as ``batchsize`` examples.
    max_cnt = opt['num_examples']
    cnt = 0
    while cnt < max_cnt:
        cnt += opt.get('batchsize', 1)
        world.parley()
        logger.log(world)

        if opt.get('display_examples'):
            print(world.display())
        if log_time.time() > log_every_n_secs:
            text = log_time.log(cnt, max_cnt)
            print(text)

    if opt.get('display_examples'):
        print('-- end of episode --')

    logger.reset_world()  # flush last episode
    logger.write(opt['outfile'], opt['format'])
    return logger.get_logs()
コード例 #3
0
ファイル: self_chat.py プロジェクト: shatu/ParlAI
def self_chat(opt):
    """
    Run self-chat episodes between a model and a clone of itself.

    :param opt:
        options mapping. This function reads ``seed``, ``num_self_chats``,
        ``outfile`` and ``save_format``.

    :return:
        the value of ``logger.get_logs()`` after all episodes have run.
    """
    random.seed(opt['seed'])

    # Build the model and its clone.
    agent1 = create_agent(opt, requireModelExists=True)
    agent2 = agent1.clone()

    # Both speakers share the original id as a base; keep the base around
    # because it also names the default output file.
    model_id = agent1.id
    agent1.id, agent2.id = model_id + "_1", model_id + "_2"

    world = create_task(opt, user_agents=[agent1, agent2])

    # World logging and progress reporting.
    logger = WorldLogger(opt)
    log_time = TimeLogger()

    # Run the requested number of self-chat episodes (counter is 1-based so it
    # can be handed to the time logger as "episodes done so far").
    for done in range(1, opt['num_self_chats'] + 1):
        _run_self_chat_episode(opt, world, logger)
        episode_report = world.report()
        text, _ = log_time.log(done, opt['num_self_chats'], episode_report)
        logging.info(text)

    # Pick the output location, falling back to a /tmp path built from the id.
    outfile = opt['outfile']
    if outfile is None:
        outfile = '/tmp/{}_selfchat'.format(model_id)

    world_writes = opt['save_format'] == 'conversations' and hasattr(world, 'write')
    if not world_writes:
        # Default path: let the logger serialise the conversations.
        logger.write(outfile, world, opt['save_format'])
    else:
        # The world provides its own writer; use it so it can log extra
        # contextual information (like personas).
        world.write(logger, outfile)

    return logger.get_logs()
コード例 #4
0
def self_chat(opt):
    """
    Run chat episodes between a first agent and a required partner model.

    The first agent is a ``MyLocalHumanAgent`` when ``opt['model_file']``
    starts with ``human`` (before any ``:``), otherwise it is created from
    ``opt``. The partner is created from ``opt['partner_model_file']`` using
    overrides loaded from ``<partner_model_file>.opt``.

    Progress is written through ``world.write`` after every episode, to the
    same output path that is used for the final save.

    :param opt:
        options mapping. This function reads ``seed``, ``partner_model_file``,
        ``partner_opt_file``, ``model_file``, ``interactive_mode``,
        ``num_self_chats``, ``outfile`` and ``save_format``. A negative
        ``num_self_chats`` is replaced in place by ``len(world.messages)``.

    :return:
        the value of ``logger.get_logs()`` after all episodes have run.

    :raises AssertionError:
        if no partner model file is given, or if ``partner_opt_file`` is given
        but is not ``<partner_model_file>.opt``.
    """
    random.seed(opt['seed'])
    partner = opt['partner_model_file']
    assert partner is not None
    partner_opt_file = opt.get('partner_opt_file')
    if partner_opt_file:
        assert partner_opt_file == partner + '.opt', (
            'Unless you think it is save,'
            ' you can remove assert')
    else:
        partner_opt_file = partner + '.opt'

    # Create agents
    if opt['model_file'].split(':')[0] == 'human':
        agent1 = MyLocalHumanAgent(opt)
    else:
        agent1 = create_agent(opt, requireModelExists=True)
    if partner is None:
        # Only reachable when assertions are disabled (``python -O``), since
        # ``partner`` is asserted non-None above: chat with a clone.
        agent2 = agent1.clone()
    else:
        # Chat with a different model. ``partner_opt_file`` is always set at
        # this point (explicitly or via the ``.opt`` default above).
        print(f"WARNING: Loading override opts from: {partner_opt_file}")
        with open(partner_opt_file) as f:
            partner_opt = json.load(f)
        partner_opt['interactive_mode'] = opt.get('interactive_mode', True)
        print(
            f"WARNING: Setting partner interactive mode to: {partner_opt['interactive_mode']}"
        )
        agent2 = create_agent_from_model_file(partner, partner_opt)

    # Set IDs
    agent1.id = agent1.id + '_1'
    agent2.id = agent2.id + '_2'

    model_id = agent1.id + '_' + agent2.id

    # Resolve the output path up front so the per-episode write in the loop
    # gets the same (never-None) destination as the final save.
    outfile = opt['outfile']
    if outfile is None:
        outfile = '/tmp/{}_selfchat'.format(model_id)

    world = create_task(opt, user_agents=[agent1, agent2])

    # Set up world logging
    logger = WorldLogger(opt)
    log_time = TimeLogger()

    # Run some self chats.
    all_report = []
    if opt['num_self_chats'] < 0:
        # Negative count: run one episode per entry in ``world.messages``.
        opt['num_self_chats'] = len(world.messages)

    for i in range(opt['num_self_chats']):
        _run_self_chat_episode(opt, world, logger)
        report = world.report()
        text, report = log_time.log(i + 1, opt['num_self_chats'], report)
        logging.info(text)
        all_report.append(report)

        # Write after every episode so partial results are on disk.
        world.write(logger, all_report, outfile)

    # Save chats
    if opt['save_format'] == 'conversations' and hasattr(world, 'write'):
        # use self chat specific world to write conversation
        # this might be useful for logging extra contextual
        # information (like personas)
        world.write(logger, all_report, outfile)
    else:
        # use default logger write function
        logger.write(outfile, world, opt['save_format'])

    return logger.get_logs()