Example #1
 def __init__(self, parser):
     opt = parser.parse_args()
     # Possibly build a dictionary (not all models do this).
     if opt['dict_build_first'] and 'dict_file' in opt:
         if opt['dict_file'] is None and opt.get('model_file'):
             opt['dict_file'] = opt['model_file'] + '.dict'
         print("[ building dictionary first... ]")
         build_dict(opt)
     # Create model and assign it to the specified task
     self.agent = create_agent(opt)
     self.world = create_task(opt, self.agent)
     self.train_time = Timer()
     self.validate_time = Timer()
     self.log_time = Timer()
     self.save_time = Timer()
     print('[ training... ]')
     self.parleys = 0
     self.max_num_epochs = opt['num_epochs'] if opt['num_epochs'] > 0 else float('inf')
     self.max_train_time = opt['max_train_time'] if opt['max_train_time'] > 0 else float('inf')
     self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf')
     self.val_every_n_secs = opt['validation_every_n_secs'] if opt['validation_every_n_secs'] > 0 else float('inf')
     self.save_every_n_secs = opt['save_every_n_secs'] if opt['save_every_n_secs'] > 0 else float('inf')
     self.best_valid = 0
     self.impatience = 0
     self.saved = False
     self.valid_world = None
     self.opt = opt
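The ternary pattern above (a non-positive option means the limit is disabled) recurs throughout these examples; a hedged sketch of the equivalent logic as a tiny helper, purely for illustration:

def _or_inf(value):
    # Hypothetical helper: non-positive settings mean "disabled", so the
    # corresponding time or epoch limit is treated as infinite.
    return value if value > 0 else float('inf')

# e.g. self.max_train_time = _or_inf(opt['max_train_time'])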
Example #2
def eval_model(opt, parser, printargs=True):
    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)
    # Show arguments after loading model
    parser.opt = agent.opt
    if (printargs):
        parser.print_args()
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = Timer()
    tot_time = 0

    # Show some example dialogs:
    for _ in range(int(opt['num_examples'])):
        world.parley()
        if opt['display_examples']:
            print("---")
            print(world.display() + "\n~~")
        if log_time.time() > log_every_n_secs:
            tot_time += log_time.time()
            print(str(int(tot_time)) + "s elapsed: " + str(world.report()))
            log_time.reset()
        if world.epoch_done():
            print("EPOCH DONE")
            break
    print(world.report())
    world.shutdown()
Example #3
def main():
    # Get command line arguments
    parser = ParlaiParser(True, True)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-e', '--num-epochs', default=1)
    parser.add_argument('-mtt',
                        '--max-train-time',
                        type=float,
                        default=float('inf'))
    opt = parser.parse_args()
    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)

    train_time = Timer()
    print("[training...]")
    for i in range(opt['num_epochs'] * len(world)):
        world.parley()
        if opt['display_examples']:
            print(world.display() + "\n~~")
        if train_time.time() > opt['max_train_time']:
            print("[max_train_time elapsed: " + str(train_time.time()) + "]")
            break
    world.shutdown()

    if opt['model_file']:
        agent.save(opt['model_file'])
    run_eval(agent, opt, 'valid')
    run_eval(agent, opt, 'test')
Example #4
def main(opt):
    # Check options
    assert ('pretrained_model' in opt)
    assert (opt['datatype'] in {'test', 'valid'})

    # Load document reader
    doc_reader = DocReaderAgent(opt)

    # Log params
    logger.info(
        '[ Created with options: ] %s' %
        ''.join(['\n{}\t{}'.format(k, v) for k, v in doc_reader.opt.items()]))

    logger.info('[ Running validation... ]')
    valid_world = create_task(opt, doc_reader)
    valid_time = Timer()
    for _ in valid_world:
        valid_world.parley()

    metrics = valid_world.report()
    if 'tasks' in metrics:
        for task, t_metrics in metrics['tasks'].items():
            logger.info('task = %s | EM = %.4f | F1 = %.4f | exs = %d | ' %
                        (task, t_metrics['accuracy'], t_metrics['f1'],
                         t_metrics['total']))
        logger.info('Overall EM = %.4f | exs = %d' %
                    (metrics['accuracy'], metrics['total']))
    else:
        logger.info('EM = %.4f | F1 = %.4f | exs = %d' %
                    (metrics['accuracy'], metrics['f1'], metrics['total']))
    logger.info('[ Done. Time = %.2f (s) ]' % valid_time.time())
Example #5
def validate(opt, agent, n_iter):
    opt = copy.deepcopy(opt)
    opt['datatype'] = 'valid'
    valid_world = create_task(opt, agent)

    logger.info('[ Running validation... ]')
    valid_time = Timer()
    for _ in valid_world:
        valid_world.parley()

    metrics = valid_world.report()
    if 'tasks' in metrics:
        for task, t_metrics in metrics['tasks'].items():
            logger.info('[valid] task = %s | iter = %d | exs = %d | ' %
                        (task, n_iter, t_metrics['total']) +
                        'EM = %.4f | F1 = %.4f' %
                        (t_metrics['accuracy'], t_metrics['f1']))
        logger.info('[valid] iter = %d | overall EM = %.4f | exs = %d' %
                    (n_iter, metrics['accuracy'], metrics['total']))
    else:
        logger.info(
            '[valid] iter = %d | EM = %.4f | F1 = %.4f | exs = %d' %
            (n_iter, metrics['accuracy'], metrics['f1'], metrics['total'])
        )
    logger.info('[ Done. Time = %.2f (s) ]' % valid_time.time())

    return metrics[opt['valid_metric']]
Example #6
    def __init__(self, opt):
        if isinstance(opt, ParlaiParser):
            opt = opt.parse_args()
        # Possibly build a dictionary (not all models do this).
        if opt['dict_build_first'] and 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file_transmitter') and opt.get('model_file_receiver'):
                opt['dict_file'] = opt['model_file_transmitter'] + '_' + opt['model_file_receiver']  + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=False)

        # Create model and assign it to the specified task
        print("[ create meta-agent ... ]")
        self.agent = create_agent(opt)
        print("[ create agent A ... ]")
        shared = self.agent.share()
        self.agent_a = create_agent_from_shared(shared)
        self.agent_a.set_id(suffix=' A')
        print("[ create agent B ... ]")
        self.agent_b = create_agent_from_shared(shared)
        # self.agent_b = create_agent(opt)
        self.agent_b.set_id(suffix=' B')
        # self.agent_a.copy(self.agent, 'transmitter')
        # self.agent_b.copy(self.agent, 'transmitter')
        self.world = create_selfplay_world(opt, [self.agent_a, self.agent_b])

        # TODO: if batch, it is also not parallel
        # self.world = BatchSelfPlayWorld(opt, self_play_world)

        self.train_time = Timer()
        self.train_dis_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys_episode = 0
        self.max_num_epochs = opt['num_epochs'] if opt['num_epochs'] > 0 else float('inf')
        self.max_train_time = opt['max_train_time'] if opt['max_train_time'] > 0 else float('inf')
        self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf')
        self.train_dis_every_n_secs = opt['train_display_every_n_secs'] if opt['train_display_every_n_secs'] > 0 else float('inf')
        self.val_every_n_secs = opt['validation_every_n_secs'] if opt['validation_every_n_secs'] > 0 else float('inf')
        self.save_every_n_secs = opt['save_every_n_secs'] if opt['save_every_n_secs'] > 0 else float('inf')
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.best_valid = None
        if opt.get('model_file_transmitter') and os.path.isfile(opt['model_file_transmitter'] + '.best_valid'):
            with open(opt['model_file_transmitter'] + ".best_valid", 'r') as f:
                x = f.readline()
                self.best_valid = float(x)
                f.close()
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt
        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)
Example #7
def main(opt):
    # Build dictionary from task data
    if 'pretrained_model' in opt:
        dictionary = None
    else:
        dictionary = build_dict(opt)

    # Build document reader
    doc_reader = DocReaderAgent(opt, word_dict=dictionary)

    # Log params
    logger.info('[ Created with options: ] %s' %
                ''.join(['\n{}\t{}'.format(k, v)
                         for k, v in doc_reader.opt.items()]))

    # Build training world once
    opt['datatype'] = 'train'
    train_world = create_task(opt, doc_reader)
    train_time = Timer()

    # Keep track of best model + how long since the last improvement
    best_valid = 0
    impatience = 0

    logger.info("[ Ok, let's go... ]")
    iteration = 0
    while impatience < opt['patience']:
        # Train...
        logger.info('[ Training for %d iters... ]' % opt['train_interval'])
        train_time.reset()
        for _ in range(opt['train_interval']):
            train_world.parley()
        logger.info('[ Done. Time = %.2f (s) ]' % train_time.time())

        # ...validate!
        valid_metric = validate(opt, doc_reader, iteration)
        if valid_metric > best_valid:
            logger.info(
                '[ Best eval %d: %s = %.4f (old = %.4f) ]' %
                (iteration, opt['valid_metric'], valid_metric, best_valid)
            )
            best_valid = valid_metric
            impatience = 0
            if 'model_file' in opt:
                doc_reader.save(opt['model_file'])

            if valid_metric == 1:
                logger.info('[ Task solved! Stopping. ]')
                break
        else:
            impatience += 1

        iteration += 1
Example #8
 def __init__(self, parser):
     opt = parser.parse_args()
     # Possibly build a dictionary (not all models do this).
     if opt['dict_build_first'] and 'dict_file' in opt:
         if opt['dict_file'] is None and opt.get('model_file'):
             opt['dict_file'] = opt['model_file'] + '.dict'
         print("[ building dictionary first... ]")
         build_dict.build_dict(opt)
     # Create model and assign it to the specified task
     self.agent = create_agent(opt)
     self.world = create_task(opt, self.agent)
     self.train_time = Timer()
     self.validate_time = Timer()
     self.log_time = Timer()
     self.save_time = Timer()
     print('[ training... ]')
     self.parleys = 0
     self.total_episodes = 0
     self.total_epochs = 0
     self.max_num_epochs = opt[
         'num_epochs'] if opt['num_epochs'] > 0 else float('inf')
     self.max_train_time = opt[
         'max_train_time'] if opt['max_train_time'] > 0 else float('inf')
     self.log_every_n_secs = opt['log_every_n_secs'] if opt[
         'log_every_n_secs'] > 0 else float('inf')
     self.val_every_n_secs = opt['validation_every_n_secs'] if opt[
         'validation_every_n_secs'] > 0 else float('inf')
     self.save_every_n_secs = opt['save_every_n_secs'] if opt[
         'save_every_n_secs'] > 0 else float('inf')
     self.best_valid = 0
     self.impatience = 0
     self.saved = False
     self.valid_world = None
     self.opt = opt
Example #9
 def __init__(self, opt, agents=None, shared=None):
     self.id = opt['task']
     self.opt = copy.deepcopy(opt)
     if shared:
         # Create agents based on shared data.
         self.agents = create_agents_from_shared(shared['agents'])
     else:
         # Add passed in agents to world directly.
         self.agents = agents
     self.max_exs = None
     self.total_exs = 0
     self.total_epochs = 0
     self.total_parleys = 0
     self.time = Timer()
Example #10
 def __init__(self, opt):
     if isinstance(opt, ParlaiParser):
         print(
             '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]'
         )
         opt = opt.parse_args()
     # Possibly load from checkpoint
     if opt['load_from_checkpoint'] and opt.get(
             'model_file') and os.path.isfile(opt['model_file'] +
                                              '.checkpoint'):
         opt['init_model'] = opt['model_file'] + '.checkpoint'
     # Possibly build a dictionary (not all models do this).
     if opt['dict_build_first'] and 'dict_file' in opt:
         # If data built via pytorch data teacher, we need to load prebuilt dict
         if opt.get('pytorch_teacher_task'):
             opt['dict_file'] = get_pyt_dict_file(opt)
         elif opt['dict_file'] is None and opt.get('model_file'):
             opt['dict_file'] = opt['model_file'] + '.dict'
         print("[ building dictionary first... ]")
         build_dict(opt, skip_if_built=True)
     # Create model and assign it to the specified task
     self.agent = create_agent(opt)
     self.world = create_task(opt, self.agent)
     self.train_time = Timer()
     self.validate_time = Timer()
     self.log_time = Timer()
     self.save_time = Timer()
     print('[ training... ]')
     self.parleys = 0
     self.max_num_epochs = opt[
         'num_epochs'] if opt['num_epochs'] > 0 else float('inf')
     self.max_train_time = opt['max_train_time'] if opt['max_train_time'] > 0 \
         else float('inf')
     self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 \
         else float('inf')
     self.val_every_n_secs = \
         opt['validation_every_n_secs'] if opt['validation_every_n_secs'] > 0 \
         else float('inf')
     self.save_every_n_secs = opt['save_every_n_secs'] if opt['save_every_n_secs'] \
         > 0 else float('inf')
     self.val_every_n_epochs = \
         opt['validation_every_n_epochs'] if opt['validation_every_n_epochs'] > 0 \
         else float('inf')
     self.last_valid_epoch = 0
     self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
     self.best_valid = None
     if opt.get('model_file') and os.path.isfile(opt['model_file'] +
                                                 '.best_valid'):
         with open(opt['model_file'] + ".best_valid", 'r') as f:
             x = f.readline()
             self.best_valid = float(x)
             f.close()
     self.impatience = 0
     self.saved = False
     self.valid_world = None
     self.opt = opt
     if opt['tensorboard_log'] is True:
         self.writer = TensorboardLogger(opt)
Example #11
def eval_ppl(opt):
    """Evaluates the the perplexity and f1 of a model (and hits@1 if model has
    ranking enabled.
    """
    dict_agent = build_dict()

    # create agents
    agent = create_agent(opt)
    world = create_task(opt, [agent, dict_agent],
                        default_world=PerplexityWorld)
    world.dict = dict_agent

    # set up logging
    log_time = Timer()
    tot_time = 0

    while not world.epoch_done():
        world.parley()  # process an example

        if log_time.time() > 1:  # log every 1 sec
            tot_time += log_time.time()
            report = world.report()
            print('{}s elapsed, {}% complete, {}'.format(
                int(tot_time),
                round_sigfigs(report['total'] / world.num_examples() * 100, 2),
                report))
            log_time.reset()
    if world.epoch_done():
        print('EPOCH DONE')
    tot_time += log_time.time()
    final_report = world.report()
    print('{}s elapsed: {}'.format(int(tot_time), final_report))
Example #12
def detect(opt, printargs=None, print_parser=None):
    """Checks a task for offensive language.
    """
    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    random.seed(42)

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)
    bad = OffensiveLanguageDetector()

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = Timer()
    tot_time = 0

    # Show some example dialogs:
    cnt = 0
    while not world.epoch_done():
        world.parley()
        offensive = False
        for a in world.acts:
            if bad.contains_offensive_language(a.get('text', '')):
                offensive = True
            labels = a.get('labels', a.get('eval_labels', ''))
            for l in labels:
                if bad.contains_offensive_language(l):
                    offensive = True

        if offensive:
            if opt['display_examples']:
                print(world.display() + "\n~~")
            cnt += 1
        if log_time.time() > log_every_n_secs:
            tot_time += log_time.time()
            report = world.report()
            log = {'total': report['total']}
            log['done'] = report['total'] / world.num_examples()
            if log['done'] > 0:
                log['eta'] = int(tot_time / log['done'] - tot_time)
            z = '%.2f' % (100 * log['done'])
            log['done'] = str(z) + '%'
            log['offenses'] = cnt
            print(str(int(tot_time)) + "s elapsed: " + str(log))
            log_time.reset()
    if world.epoch_done():
        print("EPOCH DONE")
    print(
        str(cnt) + " offensive messages found out of " +
        str(world.num_examples()) + " messages.")
    return world.report()
Example #13
def eval_model(opt, parser, printargs=True):
    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)
    # Show arguments after loading model
    parser.opt = agent.opt
    if (printargs):
        parser.print_args()
    log_every_n_secs = opt[
        'log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf')
    log_time = Timer()
    tot_time = 0

    # Show some example dialogs:
    for _ in range(int(opt['num_examples'])):
        world.parley()
        if opt['display_examples']:
            print("---")
            print(world.display() + "\n~~")
        if log_time.time() > log_every_n_secs:
            tot_time += log_time.time()
            print(str(int(tot_time)) + "s elapsed: " + str(world.report()))
            log_time.reset()
        if world.epoch_done():
            print("EPOCH DONE")
            break
    print(world.report())
    world.shutdown()
Example #14
 def __init__(self, parser):
     opt = parser.parse_args()
     # Possibly build a dictionary (not all models do this).
     if opt['dict_build_first'] and 'dict_file' in opt:
         if opt['dict_file'] is None and opt.get('model_file'):
             opt['dict_file'] = opt['model_file'] + '.dict'
         print("[ building dictionary first... ]")
         build_dict.build_dict(opt)
     # Create model and assign it to the specified task
     self.agent = create_agent(opt)
     self.world = create_task(opt, self.agent)
     self.train_time = Timer()
     self.validate_time = Timer()
     self.log_time = Timer()
     print('[ training... ]')
     self.parleys = 0
     self.total_exs = 0
     self.total_episodes = 0
     self.total_epochs = 0
     self.max_exs = None
     self.max_parleys = None
     self.world_num_exs = self.world.num_examples()
     if self.world_num_exs is not None:
         self.max_exs = opt['num_epochs'] * self.world_num_exs
         self.max_parleys = math.ceil(self.max_exs / opt['batchsize'])
     self.best_valid = 0
     self.impatience = 0
     self.saved = False
     self.valid_world = None
     self.opt = opt
Example #15
    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)
        if not shared:
            # don't enter this loop for shared instantiations
            self.path = opt['model_file']

            # loss
            self.loss_time = Timer()
            self.losses = []
            self.save_loss_every_n_secs = opt['save_loss_every_n_secs']

            # attention
            if opt['task'] == "babi:All1k":
                self.tasks = {'babi:Task1k:' + str(i): 0 for i in range(1, 21)}
            elif opt['task'] == "babi:All10k":
                self.tasks = {
                    'babi:Task10k:' + str(i): 0
                    for i in range(1, 21)
                }
            else:
                self.tasks = {task: 0 for task in opt['task'].split(',')}
            self.attention_weights = {task: [] for task in self.tasks.keys()}
            self.save_attention_exs = opt['save_attention_exs']
Example #16
 def __init__(self, opt, agents=None, shared=None):
     self.id = opt['task']
     self.opt = copy.deepcopy(opt)
     if shared:
         # Create agents based on shared data.
         self.agents = create_agents_from_shared(shared['agents'])
     else:
         # Add passed in agents to world directly.
         self.agents = agents
     self.max_exs = None
     self.total_exs = 0
     self.total_epochs = 0
     self.total_parleys = 0
     self.time = Timer()
Example #17
def eval_model(opt, printargs=None, print_parser=None):
    """Evaluates a model.

    Arguments:
    opt -- tells the evaluation function how to run
    print_parser -- if provided, prints the options that are set within the
        model after loading the model
    """
    if printargs is not None:
        print('[ Deprecated Warning: eval_model no longer uses `printargs` ]')
        print_parser = printargs
    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        print(
            '[ Deprecated Warning: eval_model should be passed opt not Parser ]'
        )
        opt = opt.parse_args()

    random.seed(42)

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = Timer()
    tot_time = 0

    # Show some example dialogs:
    cnt = 0
    while not world.epoch_done():
        cnt += 1
        world.parley()
        if opt['display_examples']:
            print("---")
            print(world.display() + "\n~~")
        if log_time.time() > log_every_n_secs:
            tot_time += log_time.time()
            print(str(int(tot_time)) + "s elapsed: " + str(world.report()))
            log_time.reset()
        if opt['num_examples'] > 0 and cnt >= opt['num_examples']:
            break
    if world.epoch_done():
        print("EPOCH DONE")
    report = world.report()
    print(report)
    return report
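As the docstring above explains, `print_parser` controls whether the options are re-printed after the model loads; a minimal driver sketch (the parser setup and extra flags are assumptions mirroring the other examples on this page, and a task/model file are assumed to be supplied on the command line):

from parlai.core.params import ParlaiParser

# Hypothetical driver for the eval_model defined above.
parser = ParlaiParser(True, True)
parser.add_argument('-ne', '--num-examples', type=int, default=-1)
parser.add_argument('-d', '--display-examples', type='bool', default=False)
opt = parser.parse_args()
report = eval_model(opt, print_parser=parser)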
Example #18
 def __init__(self, parser):
     opt = parser.parse_args()
     # Possibly load from checkpoint
     if opt['load_from_checkpoint'] and opt.get(
             'model_file') and os.path.isfile(opt['model_file'] +
                                              '.checkpoint'):
         opt['init_model'] = opt['model_file'] + '.checkpoint'
     # Possibly build a dictionary (not all models do this).
     if opt['dict_build_first'] and 'dict_file' in opt:
         if opt['dict_file'] is None and opt.get('model_file'):
             opt['dict_file'] = opt['model_file'] + '.dict'
         print("[ building dictionary first... ]")
         build_dict(opt)
     # Create model and assign it to the specified task
     self.agent = create_agent(opt)
     self.world = create_task(opt, self.agent)
     self.train_time = Timer()
     self.validate_time = Timer()
     self.log_time = Timer()
     self.save_time = Timer()
     print('[ training... ]')
     self.parleys = 0
     self.max_num_epochs = opt[
         'num_epochs'] if opt['num_epochs'] > 0 else float('inf')
     self.max_train_time = opt[
         'max_train_time'] if opt['max_train_time'] > 0 else float('inf')
     self.log_every_n_secs = opt['log_every_n_secs'] if opt[
         'log_every_n_secs'] > 0 else float('inf')
     self.val_every_n_secs = opt['validation_every_n_secs'] if opt[
         'validation_every_n_secs'] > 0 else float('inf')
     self.save_every_n_secs = opt['save_every_n_secs'] if opt[
         'save_every_n_secs'] > 0 else float('inf')
     self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
     self.best_valid = None
     if opt.get('model_file') and os.path.isfile(opt['model_file'] +
                                                 '.best_valid'):
         with open(opt['model_file'] + ".best_valid", 'r') as f:
             x = f.readline()
             self.best_valid = float(x)
             f.close()
     self.impatience = 0
     self.saved = False
     self.valid_world = None
     self.opt = opt
Example #19
def run_eval(valid_worlds, opt, datatype, max_exs=-1, write_log=False):
    """
    Eval on validation/test data.

    :param valid_worlds:
        list of the pre-created validation worlds.
    :param opt:
        the options that specify the task, eval_task, etc.
    :param datatype:
        the datatype to use, such as "valid" or "test"
    :param bool write_log:
        specifies to write metrics to file if the model_file is set
    :param int max_exs:
        limits the number of examples if max_exs > 0
    """
    if valid_worlds is None:
        # This isn't the primary worker, so we can just skip evaluation
        return None

    print('[ running eval: ' + datatype + ' ]')
    timer = Timer()
    reports = []
    for v_world in valid_worlds:
        task_report = _run_single_eval(opt, v_world,
                                       max_exs / len(valid_worlds))
        reports.append(task_report)

    tasks = [world.opt['task'] for world in valid_worlds]
    report = aggregate_task_reports(reports,
                                    tasks,
                                    micro=opt.get('aggregate_micro', True))

    metrics = f'{datatype}:{report}'
    print(f'[ eval completed in {timer.time():.2f}s ]')
    print(metrics)

    # write to file
    if write_log and opt.get('model_file'):
        # Write out metrics
        f = open(opt['model_file'] + '.' + datatype, 'a+')
        f.write(metrics + '\n')
        f.close()

    return report
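A hedged sketch of calling the `run_eval` variant above; `build_validation_worlds` is a hypothetical helper standing in for however the per-task worlds are constructed (e.g. the `_maybe_load_eval_world` seen in later examples):

from parlai.core.params import ParlaiParser

# Hypothetical usage: evaluate on the validation split, cap the number of
# examples, and append the metrics line to <model_file>.valid.
opt = ParlaiParser(True, True).parse_args()
valid_worlds = build_validation_worlds(opt)  # assumed helper, not shown here
report = run_eval(valid_worlds, opt, 'valid',
                  max_exs=opt.get('validation_max_exs', -1),
                  write_log=True)
if report is not None:
    print('validation report:', report)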
Example #20
def eval_ppl(opt, build_dict):
    """Evaluates the the perplexity of a model.

    See the documentation for this file for more info.

    :param opt: option dict
    :param build_dict: function for building official dictionary.
        note that this function does not use the opt passed into eval_ppl,
        but rather should have hardcoded settings for its dictionary.

    """
    dict_agent = build_dict()

    # create agents
    agent = create_agent(opt)
    world = create_task(opt, [agent, dict_agent], default_world=PerplexityWorld)

    # set up logging
    log_time = Timer()
    tot_time = 0

    while not world.epoch_done():
        world.parley()  # process an example

        if log_time.time() > 1:  # log every 1 sec
            tot_time += log_time.time()
            report = world.report()
            print('{}s elapsed, {}% complete, {}'.format(
                int(tot_time),
                round_sigfigs(report['total'] / world.num_examples() * 100, 3),
                report))
            log_time.reset()
    print('EPOCH DONE')
    tot_time += log_time.time()
    final_report = world.report()
    print('{}s elapsed: {}'.format(int(tot_time), final_report))
    print("============================")
    print("FINAL PPL: " +str(final_report['ppl']))
    if final_report.get('ppl', 0) == float('inf'):
        print('Note: you got inf perplexity. Consider adding (or raising) the '
              'minimum probability you assign to each possible word. If you '
              'assign zero probability to the correct token in the evaluation '
              'vocabulary, you get inf probability immediately.')
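The `build_dict` argument described in the docstring above is a zero-argument callable with its dictionary settings hardcoded; a minimal sketch (the file path and options are illustrative assumptions, and additional DictionaryAgent defaults may be needed in practice):

from parlai.core.dict import DictionaryAgent

def build_official_dict():
    # Hypothetical: fix the dictionary options here instead of reading them
    # from the opt passed to eval_ppl, as the docstring requires.
    dict_opt = {
        'dict_file': 'models/mytask/dict.tsv',  # assumed path
        'dict_tokenizer': 'split',
        'dict_lower': True,
    }
    return DictionaryAgent(dict_opt)

# eval_ppl(opt, build_dict=build_official_dict)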
Example #21
def eval_model(parser, printargs=True):
    random.seed(42)
    opt = parser.parse_args(print_args=False)

    nomodel = False
    # check to make sure the model file exists
    if opt.get('model_file') is None:
        nomodel = True
    elif not os.path.isfile(opt['model_file']):
        raise RuntimeError('WARNING: Model file does not exist, check to make '
                           'sure it is correct: {}'.format(opt['model_file']))

    # Create model and assign it to the specified task
    agent = create_agent(opt)
    if nomodel and hasattr(agent, 'load'):
        # double check that we didn't forget to set model_file on loadable model
        raise RuntimeError('Stopping evaluation because model_file unset but '
                           'model has a `load` function.')
    world = create_task(opt, agent)
    # Show arguments after loading model
    parser.opt = agent.opt
    if (printargs):
        parser.print_args()
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = Timer()
    tot_time = 0

    # Show some example dialogs:
    cnt = 0
    while not world.epoch_done():
        cnt += 1
        world.parley()
        if opt['display_examples']:
            print("---")
            print(world.display() + "\n~~")
        if log_time.time() > log_every_n_secs:
            tot_time += log_time.time()
            print(str(int(tot_time)) + "s elapsed: " + str(world.report()))
            log_time.reset()
        if opt['num_examples'] > 0 and cnt >= opt['num_examples']:
            break
    if world.epoch_done():
        print("EPOCH DONE")
    report = world.report()
    print(report)
    return report
Example #22
def eval_ppl(opt):
    """Evaluates the the perplexity and f1 of a model (and hits@1 if model has
    ranking enabled.
    """
    dict_agent = build_dict()

    # create agents
    agent = create_agent(opt)
    world = create_task(opt, [agent, dict_agent],
                        default_world=PerplexityWorld)
    world.dict = dict_agent

    # set up logging
    log_time = Timer()
    tot_time = 0

    while not world.epoch_done():
        world.parley()  # process an example

        if log_time.time() > 1:  # log every 1 sec
            tot_time += log_time.time()
            report = world.report()
            print('{}s elapsed, {}% complete, {}'.format(
                int(tot_time),
                round_sigfigs(report['total'] / world.num_examples() * 100, 3),
                report))
            log_time.reset()
    if world.epoch_done():
        print('EPOCH DONE')
    tot_time += log_time.time()
    final_report = world.report()
    print('{}s elapsed: {}'.format(int(tot_time), final_report))
    print("============================")
    print("FINAL PPL: " + str(final_report['ppl']))
    if final_report.get('ppl', 0) == float('inf'):
        print('Note: you got inf perplexity. Consider adding (or raising) the '
              'minimum probability you assign to each possible word. If you '
              'assign zero probability to the correct token in the evaluation '
              'vocabulary, you get inf probability immediately.')
Example #23
    def __init__(self, opt):
        # If Python is called from a non-interactive shell, like a bash script,
        # it will by default ignore SIGINTs, and KeyboardInterrupt exceptions
        # will not be produced. This line brings them back.
        signal.signal(signal.SIGINT, signal.default_int_handler)

        if isinstance(opt, ParlaiParser):
            print(
                '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]'
            )
            opt = opt.parse_args()
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if opt['load_from_checkpoint'] and opt.get(
                'model_file') and os.path.isfile(opt['model_file'] +
                                                 '.checkpoint'):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # Possibly build a dictionary (not all models do this).
        if opt['dict_build_first'] and 'dict_file' in opt:
            # If data built via pytorch data teacher, we need to load prebuilt dict
            if opt.get('pytorch_teacher_task'):
                opt['dict_file'] = get_pyt_dict_file(opt)
            elif opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=True)
        # Create model and assign it to the specified task
        self.agent = create_agent(opt)  # specify the model, e.g. seq2seq
        self.world = create_task(opt, self.agent)  # BatchWorld or another world
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys = 0
        self.max_num_epochs = opt[
            'num_epochs'] if opt['num_epochs'] > 0 else float('inf')
        self.max_train_time = opt['max_train_time'] if opt['max_train_time'] > 0 \
            else float('inf')
        self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 \
            else float('inf')
        self.val_every_n_secs = \
            opt['validation_every_n_secs'] if opt['validation_every_n_secs'] > 0 \
            else float('inf')
        self.save_every_n_secs = opt['save_every_n_secs'] if opt['save_every_n_secs'] \
            > 0 else float('inf')
        self.val_every_n_epochs = \
            opt['validation_every_n_epochs'] if opt['validation_every_n_epochs'] > 0 \
            else float('inf')

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {
                'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'
        }:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.best_valid = None
        if opt.get('model_file') and os.path.isfile(opt['model_file'] +
                                                    '.best_valid'):
            with open(opt['model_file'] + ".best_valid", 'r') as f:
                x = f.readline()
                self.best_valid = float(x)
                f.close()
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if (opt.get('model_file')
                and os.path.isfile(opt['model_file'] + trainstats_suffix)):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self.impatience = obj.get('impatience', 0)

        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)
Example #24
class TrainLoop():
    def __init__(self, opt):
        # If Python is called from a non-interactive shell, like a bash script,
        # it will by default ignore SIGINTs, and KeyboardInterrupt exceptions
        # will not be produced. This line brings them back.
        signal.signal(signal.SIGINT, signal.default_int_handler)

        if isinstance(opt, ParlaiParser):
            print(
                '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]'
            )
            opt = opt.parse_args()
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if opt['load_from_checkpoint'] and opt.get(
                'model_file') and os.path.isfile(opt['model_file'] +
                                                 '.checkpoint'):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # Possibly build a dictionary (not all models do this).
        if opt['dict_build_first'] and 'dict_file' in opt:
            # If data built via pytorch data teacher, we need to load prebuilt dict
            if opt.get('pytorch_teacher_task'):
                opt['dict_file'] = get_pyt_dict_file(opt)
            elif opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=True)
        # Create model and assign it to the specified task
        self.agent = create_agent(opt)  # specify the model, e.g. seq2seq
        self.world = create_task(opt, self.agent)  # BatchWorld or another world
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys = 0
        self.max_num_epochs = opt[
            'num_epochs'] if opt['num_epochs'] > 0 else float('inf')
        self.max_train_time = opt['max_train_time'] if opt['max_train_time'] > 0 \
            else float('inf')
        self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 \
            else float('inf')
        self.val_every_n_secs = \
            opt['validation_every_n_secs'] if opt['validation_every_n_secs'] > 0 \
            else float('inf')
        self.save_every_n_secs = opt['save_every_n_secs'] if opt['save_every_n_secs'] \
            > 0 else float('inf')
        self.val_every_n_epochs = \
            opt['validation_every_n_epochs'] if opt['validation_every_n_epochs'] > 0 \
            else float('inf')

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {
                'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'
        }:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.best_valid = None
        if opt.get('model_file') and os.path.isfile(opt['model_file'] +
                                                    '.best_valid'):
            with open(opt['model_file'] + ".best_valid", 'r') as f:
                x = f.readline()
                self.best_valid = float(x)
                f.close()
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if (opt.get('model_file')
                and os.path.isfile(opt['model_file'] + trainstats_suffix)):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self.impatience = obj.get('impatience', 0)

        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)

    def save_model(self, suffix=None):
        if not is_primary_worker():
            # never do IO as a non-primary worker
            return
        if not self.opt.get('model_file'):
            # nothing to save to, just exit
            return

        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        while True:
            # don't ever let a ctrl-c interrupt saving
            try:
                self.agent.save(fn)
                self._save_train_stats(suffix)
                break
            except KeyboardInterrupt:
                pass

    def _save_train_stats(self, suffix=None):
        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        fn += '.trainstats'
        with open(fn, 'w') as f:
            json.dump(
                {
                    'train_time':
                    self.train_time.time(),
                    'total_epochs':
                    (self._preempted_epochs +
                     num_workers() * self.world.get_total_epochs()),
                    'impatience':
                    self.impatience,
                }, f)

    def validate(self):
        opt = self.opt

        if self.valid_world is None:
            # we need to load the world now
            self.valid_world = _maybe_load_eval_world(self.agent, opt, 'valid')

        # run evaluation on valid set
        valid_report = sync_object(
            run_eval(self.valid_world, opt, 'valid', opt['validation_max_exs'],
                     True))

        # logging
        if opt['tensorboard_log'] is True and is_primary_worker():
            self.writer.add_metrics('valid', int(self.train_time.time()),
                                    valid_report)
        # saving
        if (opt.get('model_file') and opt.get('save_after_valid')
                and is_primary_worker()):
            print("[ saving model checkpoint: " + opt['model_file'] +
                  ".checkpoint ]")
            self.save_model('.checkpoint')

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # check which metric to look at
        if '/' in opt['validation_metric']:
            # if you are multitasking and want your validation metric to be
            # a metric specific to a subtask, specify your validation metric
            # as -vmt subtask/metric
            subtask = opt['validation_metric'].split('/')[0]
            validation_metric = opt['validation_metric'].split('/')[1]
            new_valid = valid_report['tasks'][subtask][validation_metric]
        else:
            new_valid = valid_report[opt['validation_metric']]

        # check if this is the best validation so far
        if (self.best_valid is None or self.valid_optim * new_valid >
                self.valid_optim * self.best_valid):
            print('[ new best {}: {}{} ]'.format(
                opt['validation_metric'], new_valid,
                ' (previous best was {})'.format(self.best_valid)
                if self.best_valid is not None else ''))
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file') and is_primary_worker():
                print("[ saving best valid model: " + opt['model_file'] + " ]")
                self.save_model()
                print("[ saving best valid metric: " + opt['model_file'] +
                      ".best_valid ]")
                save_best_valid(opt['model_file'], self.best_valid)
                self.saved = True
            if (opt['validation_metric'] == 'accuracy'
                    and self.best_valid >= opt['validation_cutoff']):
                print('[ task solved! stopping. ]')
                return True
        else:
            self.impatience += 1
            print('[ did not beat best {}: {} impatience: {} ]'.format(
                opt['validation_metric'], round(self.best_valid, 4),
                self.impatience))
        self.validate_time.reset()

        # check if we are out of patience
        if (opt['validation_patience'] > 0
                and self.impatience >= opt['validation_patience']):
            print('[ ran out of patience! stopping training. ]')
            return True
        return False

    def _average_dicts(self, all_versions):
        # instead of a list-of-dicts with like keys, make a dict-of-lists with
        # keys to reduce
        to_reduce = {}
        for d in all_versions:
            for k, v in d.items():
                to_reduce.setdefault(k, []).append(v)
        # now perform the reduction
        finalized = {}
        for k, values in to_reduce.items():
            if k == 'exs' or k == 'total_skipped_batches':
                # sum across workers
                finalized[k] = np.sum(values)
            elif isinstance(values[0], dict):
                # do the same procedure recursively
                finalized[k] = self._average_dicts(values)
            else:
                # all other cases, take the mean across the workers
                finalized[k] = np.mean(values)
        return finalized

    def _sync_training_metrics(self, metrics):
        """
        Sync training metrics across workers. A handful of special cases are handled
        as exceptions, and the remaining metrics are simply averaged across workers.
        """
        if not is_distributed():
            # nothing special needed
            return metrics
        all_versions = all_gather_list(metrics)
        return self._average_dicts(all_versions)

    def _nice_format(self, dictionary):
        rounded = {}
        for k, v in dictionary.items():
            if isinstance(v, dict):
                rounded[k] = self._nice_format(v)
            elif isinstance(v, float):
                rounded[k] = round_sigfigs(v, 4)
            else:
                rounded[k] = v
        return rounded

    def _compute_eta(self, epochs_completed, time_elapsed):
        """
        Computes the estimated seconds remaining in training.

        :param float epochs_completed: number of epochs already completed.
        :param float time_elapsed: total time spent already, in seconds.
        :return: ETA in seconds, or None if not computable
        """
        # start off with no estimate
        eta = None

        # Determine time_left and num_epochs
        max_epochs = self.opt.get('num_epochs', 0)
        if max_epochs > 0 and epochs_completed > 0:
            epoch_progress = epochs_completed / max_epochs
            eta = (1 - epoch_progress) * time_elapsed / epoch_progress

        max_training_time = self.opt.get('max_training_time', -1)
        if max_training_time > 0:
            time_left = max_training_time - time_elapsed
            if eta is None or time_left < eta:
                eta = time_left

        return eta

    def log(self):
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []
        # get report
        train_report = self._sync_training_metrics(self.world.report())
        self.world.reset_metrics()

        # time elapsed
        logs.append('time:{}s'.format(np.floor(self.train_time.time())))
        logs.append('total_exs:{}'.format(self._total_exs))

        if self._total_epochs >= 0:
            # only if it's unbounded
            logs.append('epochs:{}'.format(round(self._total_epochs, 2)))

        time_left = self._compute_eta(self._total_epochs,
                                      self.train_time.time())
        if time_left is not None:
            logs.append('time_left:{}s'.format(max(0, np.ceil(time_left))))

        log = '[ {} ] {}'.format(' '.join(logs),
                                 self._nice_format(train_report))
        print(log)
        self.log_time.reset()

        if opt['tensorboard_log'] is True and is_primary_worker():
            self.writer.add_metrics('train', self._total_exs, train_report)

    def train(self):
        if is_distributed():
            warn_once(
                "Distributed training outputs average-per-worker metrics during "
                "training, and may be slightly distorted. Validation/test are "
                "unadulterated.")
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                world.parley()
                self.parleys += 1
                # print(world.display())

                # get the total training examples done, compute epochs
                self._total_epochs = (
                    self._preempted_epochs +
                    num_workers() * self.world.get_total_epochs())
                exs_per_epoch = self.world.num_examples()
                self._total_exs = int(
                    np.round(self._total_epochs * exs_per_epoch))

                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object(
                    (self.train_time.time(), self.log_time.time(),
                     self.validate_time.time()))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, train_time))
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (validate_time > self.val_every_n_secs
                        or self._total_epochs - self.last_valid_epoch >=
                        self.val_every_n_epochs):
                    stop_training = self.validate()
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break
                if (self.save_time.time() > self.save_every_n_secs
                        and opt.get('model_file') and is_primary_worker()):
                    print("[ saving model checkpoint: {}.checkpoint".format(
                        opt['model_file']))
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        valid_world = _maybe_load_eval_world(self.agent, opt, 'valid')
        v_report = run_eval(valid_world, opt, 'valid', write_log=True)
        test_world = _maybe_load_eval_world(self.agent, opt, 'test')
        t_report = run_eval(test_world, opt, 'test', write_log=True)
        if valid_world:
            valid_world.shutdown()
        if test_world:
            test_world.shutdown()

        return v_report, t_report
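The `_compute_eta` method above scales the time already spent by the fraction of epochs still to go and caps the result with any wall-clock budget; the same arithmetic as a standalone sketch:

def compute_eta(epochs_completed, time_elapsed, max_epochs=0, max_train_time=-1):
    # Same reasoning as TrainLoop._compute_eta above, as a free function.
    eta = None
    if max_epochs > 0 and epochs_completed > 0:
        progress = epochs_completed / max_epochs
        eta = (1 - progress) * time_elapsed / progress
    if max_train_time > 0:
        time_left = max_train_time - time_elapsed
        if eta is None or time_left < eta:
            eta = time_left
    return eta

# 2 of 10 epochs finished after 600s -> roughly 2400s remaining
# compute_eta(2, 600, max_epochs=10)  # 2400.0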
Example #25
def eval_ppl(opt, build_dict=None, dict_file=None):
    """Evaluates the the perplexity of a model.

    This uses a dictionary which implements the following functions:
    - tokenize(text): splits string up into list of tokens
    - __in__(text): checks whether dictionary contains a token
    - keys(): returns an iterator over all tokens in the dictionary

    :param opt: option dict
    :param build_dict: function which returns a dictionary class implementing
        the functions above.
    :param dict_file: file used when loading the dictionary class set via the
        "dictionary_class" argument (defaults to
        parlai.core.dict:DictionaryAgent).

    Either build_dict or dict_file must be set (both default to None) to
    determine the dictionary used for the evaluation.
    """
    if not build_dict and not dict_file:
        raise RuntimeError('eval_ppl script either needs a dictionary build '
                           'function or a dictionary file.')

    if build_dict:
        dict_agent = build_dict()
    else:
        dict_opt = copy.deepcopy(opt)
        dict_opt['model'] = dict_opt.get(
            'dictionary_class', 'parlai.core.dict:DictionaryAgent'
        )
        dict_opt['model_file'] = dict_file
        if 'override' in dict_opt:
            del dict_opt['override']
        dict_agent = create_agent(dict_opt, requireModelExists=True)

    # create agents
    agent = create_agent(opt)
    world = create_task(opt, [agent, dict_agent], default_world=PerplexityWorld)

    # set up logging
    log_time = Timer()
    tot_time = 0

    while not world.epoch_done():
        world.parley()  # process an example

        if log_time.time() > 1:  # log every 1 sec
            tot_time += log_time.time()
            report = world.report()
            print('{}s elapsed, {}% complete, {}'.format(
                int(tot_time),
                round_sigfigs(report['exs'] / world.num_examples() * 100, 3),
                report))
            log_time.reset()
    print('EPOCH DONE')
    tot_time += log_time.time()
    final_report = world.report()
    print('{}s elapsed: {}'.format(int(tot_time), final_report))
    print("============================")
    print("FINAL PPL: " + str(final_report['ppl']))
    if final_report.get('ppl', 0) == float('inf'):
        print('Note: you got inf perplexity. Consider adding (or raising) the '
              'minimum probability you assign to each possible word. If you '
              'assign zero probability to the correct token in the evaluation '
              'vocabulary, you get inf probability immediately.')
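As a toy illustration of the dictionary interface the docstring above lists (tokenize, membership via `in`, and keys), distinct from ParlAI's real DictionaryAgent:

class ToyDict(object):
    """Hypothetical minimal dictionary satisfying eval_ppl's requirements."""

    def __init__(self, vocab):
        self._vocab = set(vocab)

    def tokenize(self, text):
        # naive whitespace tokenization, for illustration only
        return text.split()

    def __contains__(self, token):
        # backs the `in` check the docstring refers to as __in__
        return token in self._vocab

    def keys(self):
        return iter(self._vocab)

# eval_ppl(opt, build_dict=lambda: ToyDict(['hello', 'world']))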
Example #26
def main():
    # Get command line arguments
    parser = ParlaiParser(True, True)
    train = parser.add_argument_group('Training Loop Arguments')
    train.add_argument('-et',
                       '--evaltask',
                       help=('task to use for valid/test (defaults to the ' +
                             'one used for training if not set)'))
    train.add_argument('-d', '--display-examples', type='bool', default=False)
    train.add_argument('-e', '--num-epochs', type=float, default=-1)
    train.add_argument('-ttim', '--max-train-time', type=float, default=-1)
    train.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    train.add_argument('-lparl',
                       '--log-every-n-parleys',
                       type=int,
                       default=100)
    train.add_argument('-vtim',
                       '--validation-every-n-secs',
                       type=float,
                       default=-1)
    train.add_argument('-vparl',
                       '--validation-every-n-parleys',
                       type=int,
                       default=-1)
    train.add_argument('-vme',
                       '--validation-max-exs',
                       type=int,
                       default=-1,
                       help='max examples to use during validation (default ' +
                       '-1 uses all)')
    train.add_argument(
        '-vp',
        '--validation-patience',
        type=int,
        default=5,
        help=('number of iterations of validation where result ' +
              'does not improve before we stop training'))
    train.add_argument(
        '-vmt',
        '--validation-metric',
        default='accuracy',
        help='key into report table for selecting best validation')
    train.add_argument('-dbf',
                       '--dict-build-first',
                       type='bool',
                       default=True,
                       help='build dictionary first before training agent')
    opt = parser.parse_args()

    # Set logging
    logger = logging.getLogger('DrQA')
    logger.setLevel(logging.INFO)
    fmt = logging.Formatter('%(asctime)s: %(message)s', '%m/%d/%Y %I:%M:%S %p')
    console = logging.StreamHandler()
    console.setFormatter(fmt)
    logger.addHandler(console)
    if 'log_file' in opt:
        logfile = logging.FileHandler(opt['log_file'], 'w')
        logfile.setFormatter(fmt)
        logger.addHandler(logfile)
    logger.info('[ COMMAND: %s ]' % ' '.join(sys.argv))

    # Possibly build a dictionary (not all models do this).
    if opt['dict_build_first'] and 'dict_file' in opt:
        if opt['dict_file'] is None and opt.get('model_file'):
            opt['dict_file'] = opt['model_file'] + '.dict'
        logger.info("[ building dictionary first... ]")
        build_dict.build_dict(opt)

    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)

    train_time = Timer()
    validate_time = Timer()
    log_time = Timer()
    logger.info('[ training... ]')
    parleys = 0
    total_exs = 0
    max_exs = opt['num_epochs'] * len(world)
    max_parleys = math.ceil(max_exs / opt['batchsize'])
    best_valid = 0
    impatience = 0
    saved = False
    valid_world = None
    best_accuracy = 0

    while True:
        world.parley()
        parleys += 1

        if opt['num_epochs'] > 0 and parleys >= max_parleys:
            logger.info('[ num_epochs completed: {} ]'.format(
                opt['num_epochs']))
            break
        if opt['max_train_time'] > 0 and train_time.time(
        ) > opt['max_train_time']:
            logger.info('[ max_train_time elapsed: {} ]'.format(
                train_time.time()))
            break

        # log every n parleys rather than every n seconds
        if (opt['log_every_n_parleys'] > 0
                and parleys % opt['log_every_n_parleys'] == 0):
            if opt['display_examples']:
                logger.info(world.display() + '\n~~')

            logs = []
            # time elapsed
            logs.append('time:{}s'.format(math.floor(train_time.time())))
            logs.append('parleys:{}'.format(parleys))

            # get report and update total examples seen so far
            if hasattr(agent, 'report'):
                train_report = agent.report()
                agent.reset_metrics()
            else:
                train_report = world.report()
                world.reset_metrics()

            if hasattr(train_report, 'get') and train_report.get('total'):
                total_exs += train_report['total']
                logs.append('total_exs:{}'.format(total_exs))

            # check if we should log amount of time remaining
            time_left = None
            if opt['num_epochs'] > 0 and total_exs > 0:
                secs_per_ex = train_time.time() / total_exs
                time_left = (max_exs - total_exs) * secs_per_ex
            if opt['max_train_time'] > 0:
                other_time_left = opt['max_train_time'] - train_time.time()
                if time_left is not None:
                    time_left = min(time_left, other_time_left)
                else:
                    time_left = other_time_left
            if time_left is not None:
                logs.append('time_left:{}s'.format(math.floor(time_left)))

            # join log string and add full metrics report to end of log
            log = '[ {} ] {}'.format(' '.join(logs), train_report)

            logger.info(log)
            log_time.reset()


        # validate every n parleys rather than every n seconds
        if (opt['validation_every_n_parleys'] > 0
                and parleys % opt['validation_every_n_parleys'] == 0):
            valid_report, valid_world = run_eval(agent,
                                                 opt,
                                                 'valid',
                                                 opt['validation_max_exs'],
                                                 logger=logger)
            if valid_report[opt['validation_metric']] > best_accuracy:
                best_accuracy = valid_report[opt['validation_metric']]
                impatience = 0
                logger.info('[ new best accuracy: ' + str(best_accuracy) +
                            ' ]')
                world.save_agents()
                saved = True
                if best_accuracy == 1:
                    logger.info('[ task solved! stopping. ]')
                    break
            else:
                opt['learning_rate'] *= 0.5
                agent.model.set_lrate(opt['learning_rate'])
                logger.info('[ Decrease learning_rate %.2e]' %
                            opt['learning_rate'])
                impatience += 1
                logger.info(
                    '[ did not beat best accuracy: {} impatience: {} ]'.format(
                        round(best_accuracy, 4), impatience))

            validate_time.reset()
            if opt['validation_patience'] > 0 and impatience >= opt[
                    'validation_patience']:
                logger.info('[ ran out of patience! stopping training. ]')
                break
            if opt['learning_rate'] < pow(10, -6):
                logger.info(
                    '[ learning_rate < pow(10,-6) ! stopping training. ]')
                break

    world.shutdown()
    if not saved:
        world.save_agents()
    else:
        # reload best validation model
        agent = create_agent(opt)

    run_eval(agent, opt, 'valid', write_log=True, logger=logger)
    run_eval(agent, opt, 'test', write_log=True, logger=logger)
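The validation branch above pairs patience-based early stopping with halving the learning rate on every non-improving check. A self-contained sketch of that policy, driven by a toy list of validation scores (all values made up):

def early_stop(valid_scores, patience=5, lr=1e-3, min_lr=1e-6):
    best, impatience = 0.0, 0
    for score in valid_scores:
        if score > best:
            best, impatience = score, 0   # new best: checkpoint and reset patience
        else:
            lr *= 0.5                     # decay the learning rate on a plateau
            impatience += 1
        if impatience >= patience or lr < min_lr:
            break                         # stop training, keep the best checkpoint
    return best, lr

print(early_stop([0.41, 0.47, 0.46, 0.48, 0.48, 0.47, 0.47, 0.46, 0.45, 0.45]))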
Example No. 27
def main(parser):
    opt = parser.parse_args()
    # Possibly build a dictionary (not all models do this).
    if opt['dict_build_first'] and 'dict_file' in opt:
        if opt['dict_file'] is None and opt.get('model_file'):
            opt['dict_file'] = opt['model_file'] + '.dict'
        print("[ building dictionary first... ]")
        build_dict.build_dict(opt)
    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)

    train_time = Timer()
    validate_time = Timer()
    log_time = Timer()
    print('[ training... ]')
    parleys = 0
    total_exs = 0
    max_exs = opt['num_epochs'] * len(world)
    max_parleys = math.ceil(max_exs / opt['batchsize'])
    best_valid = 0
    impatience = 0
    saved = False
    valid_world = None
    with world:
        while True:
            world.parley()
            parleys += 1

            if opt['num_epochs'] > 0 and parleys >= max_parleys:
                print('[ num_epochs completed: {} ]'.format(opt['num_epochs']))
                break
            if opt['max_train_time'] > 0 and train_time.time() > opt['max_train_time']:
                print('[ max_train_time elapsed: {} ]'.format(train_time.time()))
                break
            if opt['log_every_n_secs'] > 0 and log_time.time() > opt['log_every_n_secs']:
                if opt['display_examples']:
                    print(world.display() + '\n~~')

                logs = []
                # time elapsed
                logs.append('time:{}s'.format(math.floor(train_time.time())))
                logs.append('parleys:{}'.format(parleys))

                # get report and update total examples seen so far
                if hasattr(agent, 'report'):
                    train_report = agent.report()
                    agent.reset_metrics()
                else:
                    train_report = world.report()
                    world.reset_metrics()

                if hasattr(train_report, 'get') and train_report.get('total'):
                    total_exs += train_report['total']
                    logs.append('total_exs:{}'.format(total_exs))

                # check if we should log amount of time remaining
                time_left = None
                if opt['num_epochs'] > 0 and total_exs > 0:
                    secs_per_ex = train_time.time() / total_exs
                    time_left = (max_exs - total_exs) * secs_per_ex
                if opt['max_train_time'] > 0:
                    other_time_left = opt['max_train_time'] - train_time.time()
                    if time_left is not None:
                        time_left = min(time_left, other_time_left)
                    else:
                        time_left = other_time_left
                if time_left is not None:
                    logs.append('time_left:{}s'.format(math.floor(time_left)))

                # join log string and add full metrics report to end of log
                log = '[ {} ] {}'.format(' '.join(logs), train_report)

                print(log)
                log_time.reset()

            if (opt['validation_every_n_secs'] > 0 and
                    validate_time.time() > opt['validation_every_n_secs']):
                valid_report, valid_world = run_eval(
                    agent, opt, 'valid', opt['validation_max_exs'],
                    valid_world=valid_world)
                if valid_report[opt['validation_metric']] > best_valid:
                    best_valid = valid_report[opt['validation_metric']]
                    impatience = 0
                    print('[ new best {}: {} ]'.format(
                        opt['validation_metric'], best_valid))
                    world.save_agents()
                    saved = True
                    if opt['validation_metric'] == 'accuracy' and best_valid > 99.5:
                        print('[ task solved! stopping. ]')
                        break
                else:
                    impatience += 1
                    print('[ did not beat best {}: {} impatience: {} ]'.format(
                            opt['validation_metric'], round(best_valid, 4),
                            impatience))
                validate_time.reset()
                if opt['validation_patience'] > 0 and impatience >= opt['validation_patience']:
                    print('[ ran out of patience! stopping training. ]')
                    break
    if not saved:
        # save agent
        world.save_agents()
    elif opt.get('model_file'):
        # reload best validation model
        agent = create_agent(opt)

    run_eval(agent, opt, 'valid', write_log=True)
    run_eval(agent, opt, 'test', write_log=True)
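For reference, the time_left estimate in the logging block above just scales the observed seconds-per-example to the remaining examples and takes the tighter of that and the wall-clock budget. A worked example with made-up numbers:

import math

elapsed = 120.0         # seconds of training so far
total_exs = 4000        # examples seen so far
max_exs = 20000         # num_epochs * len(world)
max_train_time = 900.0  # wall-clock budget in seconds

secs_per_ex = elapsed / total_exs                        # 0.03 s per example
epoch_time_left = (max_exs - total_exs) * secs_per_ex    # 480 s at this rate
wallclock_left = max_train_time - elapsed                # 780 s of budget left
print('time_left:{}s'.format(math.floor(min(epoch_time_left, wallclock_left))))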
Example No. 28
class World(object):
    """Empty parent providing null definitions of API functions for Worlds.
    All children can override these to provide more detailed functionality."""
    def __init__(self, opt, agents=None, shared=None):
        self.id = opt['task']
        self.opt = copy.deepcopy(opt)
        if shared:
            # Create agents based on shared data.
            self.agents = create_agents_from_shared(shared['agents'])
        else:
            # Add passed in agents to world directly.
            self.agents = agents
        self.max_exs = None
        self.total_exs = 0
        self.total_epochs = 0
        self.total_parleys = 0
        self.time = Timer()

    def parley(self):
        """The main method, that does one step of actions for the agents
        in the world. This is empty in the base class.
        """
        pass

    def getID(self):
        """Return the name of the world, typically the task the world encodes."""
        return self.id

    def display(self):
        """Returns a string describing the current state of the world.

        Useful for monitoring and debugging.
        By default, display the messages between the agents."""
        if not hasattr(self, 'acts'):
            return ''
        return display_messages(self.acts)

    def episode_done(self):
        """Whether the episode is done or not."""
        return False

    def epoch_done(self):
        """Whether the epoch is done or not.

        Not all worlds have the notion of an epoch, but this is useful
        for fixed training, validation or test sets.
        """
        return False

    def share(self):
        shared_data = {}
        shared_data['world_class'] = type(self)
        shared_data['opt'] = self.opt
        shared_data['agents'] = self._share_agents()
        return shared_data

    def _share_agents(self):
        """Create shared data for agents so other classes can create the same
        agents without duplicating the data (i.e. sharing parameters).
        """
        if not hasattr(self, 'agents'):
            return None
        shared_agents = [a.share() for a in self.agents]
        return shared_agents

    def get_agents(self):
        """Return the list of agents."""
        return self.agents

    def get_acts(self):
        """Return the last act of each agent."""
        return self.acts

    def get_time(self):
        """Return total training time"""
        return self.time.time()

    def get_total_exs(self):
        """Return total amount of examples seen by world."""
        return self.total_exs

    def get_total_epochs(self):
        """Return total amount of epochs on which the world has trained."""
        return self.total_epochs

    def __enter__(self):
        """Empty enter provided for use with ``with`` statement.

        e.g:

        .. code-block:: python

            with World() as world:
                for _ in range(10):
                    world.parley()
        """
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """After ``with`` statement, call shutdown."""
        silent_exit = isinstance(exc_value, KeyboardInterrupt)
        self.shutdown()
        return silent_exit

    def num_examples(self):
        return 0

    def num_episodes(self):
        return 0

    def reset(self):
        for a in self.agents:
            a.reset()
        self.max_exs = None
        self.total_exs = 0
        self.total_epochs = 0
        self.total_parleys = 0
        self.time.reset()

    def reset_metrics(self):
        for a in self.agents:
            a.reset_metrics()

    def shutdown(self):
        """Perform any cleanup, if appropriate."""
        pass

    def update_counters(self):
        """Update how many epochs have completed"""
        self.total_parleys += 1
        if self.max_exs is None:
            if ('num_epochs' in self.opt and self.opt['num_epochs'] > 0):
                self.max_exs = (self.num_examples() * self.opt['num_epochs']
                                if self.num_examples() else -1)
            else:
                self.max_exs = -1
        # when we know the size of the data
        if self.max_exs > 0:
            self.total_epochs = self.total_parleys * self.opt.get(
                'batchsize', 1) / self.num_examples()
        # when we do not know the size of the data
        else:
            if self.epoch_done():
                self.total_epochs += 1
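A toy subclass (not from the original module, and assuming it lives alongside the World class above so that class's imports are available): parley() does the work, update_counters() tracks progress, and the with statement guarantees shutdown().

class CountingWorld(World):
    """Hypothetical world: each parley just increments a counter."""

    def __init__(self, opt, agents=None, shared=None):
        super().__init__(opt, agents, shared)
        self.steps = 0

    def parley(self):
        self.steps += 1          # a real world would exchange agent messages here
        self.update_counters()

    def epoch_done(self):
        return self.steps >= 3   # pretend an epoch is three parleys


with CountingWorld({'task': 'counting'}) as world:
    while not world.epoch_done():
        world.parley()
    print(world.get_total_epochs())   # -> 1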
Example No. 29
def main():
    # Get command line arguments
    parser = ParlaiParser(True, True)
    train = parser.add_argument_group('Training Loop Arguments')
    train.add_argument('-et', '--evaltask',
                       help=('task to use for valid/test (defaults to the '
                             'one used for training if not set)'))
    train.add_argument('-d', '--display-examples',
                       type='bool', default=False)
    train.add_argument('-e', '--num-epochs', type=float, default=-1)
    train.add_argument('-ttim', '--max-train-time',
                       type=float, default=-1)
    train.add_argument('-ltim', '--log-every-n-secs',
                       type=float, default=2)
    train.add_argument('-vtim', '--validation-every-n-secs',
                       type=float, default=-1)
    train.add_argument('-vme', '--validation-max-exs',
                       type=int, default=-1,
                       help='max examples to use during validation (default '
                            '-1 uses all)')
    train.add_argument('-vp', '--validation-patience',
                       type=int, default=5,
                       help=('number of iterations of validation where result'
                             ' does not improve before we stop training'))
    train.add_argument('-vmt', '--validation-metric', default='accuracy',
                       help='key into report table for selecting best '
                            'validation')
    train.add_argument('-dbf', '--dict-build-first',
                       type='bool', default=True,
                       help='build dictionary first before training agent')
    opt = parser.parse_args()
    # Possibly build a dictionary (not all models do this).
    if opt['dict_build_first'] and 'dict_file' in opt:
        if opt['dict_file'] is None and opt.get('model_file'):
            opt['dict_file'] = opt['model_file'] + '.dict'
        print("[ building dictionary first... ]")
        build_dict.build_dict(opt)
    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)

    train_time = Timer()
    validate_time = Timer()
    log_time = Timer()
    print('[ training... ]')
    parleys = 0
    total_exs = 0
    max_exs = opt['num_epochs'] * len(world)
    max_parleys = math.ceil(max_exs / opt['batchsize'])
    best_valid = 0
    impatience = 0
    saved = False
    valid_world = None
    while True:
        world.parley()
        parleys += 1

        if opt['num_epochs'] > 0 and parleys >= max_parleys:
            print('[ num_epochs completed: {} ]'.format(opt['num_epochs']))
            break
        if opt['max_train_time'] > 0 and train_time.time() > opt['max_train_time']:
            print('[ max_train_time elapsed: {} ]'.format(train_time.time()))
            break
        if opt['log_every_n_secs'] > 0 and log_time.time() > opt['log_every_n_secs']:
            if opt['display_examples']:
                print(world.display() + '\n~~')

            logs = []
            # time elapsed
            logs.append('time:{}s'.format(math.floor(train_time.time())))
            logs.append('parleys:{}'.format(parleys))

            # get report and update total examples seen so far
            if hasattr(agent, 'report'):
                train_report = agent.report()
                agent.reset_metrics()
            else:
                train_report = world.report()
                world.reset_metrics()

            if hasattr(train_report, 'get') and train_report.get('total'):
                total_exs += train_report['total']
                logs.append('total_exs:{}'.format(total_exs))

            # check if we should log amount of time remaining
            time_left = None
            if opt['num_epochs'] > 0 and total_exs > 0:
                secs_per_ex = train_time.time() / total_exs
                time_left = (max_exs - total_exs) * secs_per_ex
            if opt['max_train_time'] > 0:
                other_time_left = opt['max_train_time'] - train_time.time()
                if time_left is not None:
                    time_left = min(time_left, other_time_left)
                else:
                    time_left = other_time_left
            if time_left is not None:
                logs.append('time_left:{}s'.format(math.floor(time_left)))

            # join log string and add full metrics report to end of log
            log = '[ {} ] {}'.format(' '.join(logs), train_report)

            print(log)
            log_time.reset()

        if (opt['validation_every_n_secs'] > 0 and
                validate_time.time() > opt['validation_every_n_secs']):
            valid_report, valid_world = run_eval(
                agent, opt, 'valid', opt['validation_max_exs'],
                valid_world=valid_world)
            if valid_report[opt['validation_metric']] > best_valid:
                best_valid = valid_report[opt['validation_metric']]
                impatience = 0
                print('[ new best {}: {} ]'.format(
                    opt['validation_metric'], best_valid))
                world.save_agents()
                saved = True
                if opt['validation_metric'] == 'accuracy' and best_valid == 1:
                    print('[ task solved! stopping. ]')
                    break
            else:
                impatience += 1
                print('[ did not beat best {}: {} impatience: {} ]'.format(
                        opt['validation_metric'], round(best_valid, 4),
                        impatience))
            validate_time.reset()
            if opt['validation_patience'] > 0 and impatience >= opt['validation_patience']:
                print('[ ran out of patience! stopping training. ]')
                break
    world.shutdown()
    if not saved:
        world.save_agents()
    else:
        # reload best validation model
        agent = create_agent(opt)

    run_eval(agent, opt, 'valid', write_log=True)
    run_eval(agent, opt, 'test', write_log=True)
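For reference, the stopping bookkeeping at the top of these loops is plain arithmetic; with made-up sizes:

import math

num_epochs, world_len, batchsize = 2.0, 1000, 32
max_exs = num_epochs * world_len               # 2000.0 examples to see in total
max_parleys = math.ceil(max_exs / batchsize)   # 63 parleys, one batch per parley
print(max_exs, max_parleys)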
Example No. 30
    def test_timer(self):
        t = Timer()
        elapsed = t.stop().time()
        assert elapsed > 0

        same = t.time()
        assert elapsed == same

        t.resume()
        time.sleep(0.1)
        more = t.time()
        assert more > elapsed

        other = Timer()
        less = other.reset().time()
        assert less > 0
        assert less < t.time()
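For context, a minimal Timer that satisfies the behaviour exercised above; this is a hedged stand-in, and the actual ParlAI utility may differ in detail.

import time

class Timer(object):
    def __init__(self):
        self.running = True
        self.total = 0
        self.start = time.time()

    def reset(self):
        # restart from zero and keep running
        self.running = True
        self.total = 0
        self.start = time.time()
        return self

    def resume(self):
        if not self.running:
            self.running = True
            self.start = time.time()
        return self

    def stop(self):
        # freeze the accumulated time
        if self.running:
            self.running = False
            self.total += time.time() - self.start
        return self

    def time(self):
        if self.running:
            return self.total + time.time() - self.start
        return self.total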
Example No. 31
    def __init__(self, opt, shared=None):
        opt = copy.deepcopy(opt)
        super().__init__(opt, shared)
        self.use_cuda = not opt['no_cuda'] and torch.cuda.is_available()
        self.is_combine_attr = (hasattr(self, 'other_task_datafiles')
                                and self.other_task_datafiles)
        self.random_policy = opt.get('random_policy', False)
        self.count_sample = opt.get('count_sample', False)
        self.anti = opt.get('anti', False)

        if self.random_policy:
            random.seed(17)

        if not shared:
            if not self.stream and opt.get('pace_by', 'sample') == 'bucket':
                score_list = [episode[0][2] for episode in self.data.data]
                assert score_list == sorted(score_list)
                num_buckets = opt.get('num_buckets',
                                      int(self.num_episodes() / 10))
                lb_indices = [
                    int(len(score_list) * i / num_buckets)
                    for i in range(num_buckets)
                ]
                lbs = [score_list[idx] for idx in lb_indices]
                bucket_ids = [
                    self.sort_into_bucket(ctrl_val, lbs)
                    for ctrl_val in score_list
                ]
                bucket_cnt = [0 for _ in range(num_buckets)]
                for i in range(num_buckets):
                    bucket_cnt[i] = bucket_ids.count(i)
                self.bucket_cnt = bucket_cnt
            self.lastYs = [None] * self.bsz
            # build multiple task data
            self.tasks = [self.data]

            if self.is_combine_attr:
                print('[ build multiple task data ... ]')
                for datafile in self.other_task_datafiles:
                    task_opt = copy.deepcopy(opt)
                    task_opt['datafile'] = datafile
                    self.tasks.append(
                        DialogData(task_opt,
                                   data_loader=self.setup_data,
                                   cands=self.label_candidates()))
                print('[ build multiple task data done! ]')

                # record the selections of each subtasks
                self.subtasks = opt['subtasks'].split(':')
                self.subtask_counter = OrderedDict()
                self.p_selections = OrderedDict()
                self.c_selections = OrderedDict()
                for t in self.subtasks:
                    self.subtask_counter[t] = 0
                    self.p_selections[t] = []
                    self.c_selections[t] = []

                if self.count_sample and not self.stream:
                    self.sample_counter = OrderedDict()
                    for idx, t in enumerate(self.subtasks):
                        self.sample_counter[t] = [
                            0 for _ in self.tasks[idx].data
                        ]

            # setup the tensorboard log
            if opt['tensorboard_log_teacher'] is True:
                opt['tensorboard_tag'] = 'task'
                teacher_metrics = 'reward,policy_loss,critic_loss,mean_advantage_reward,action_ent'.split(
                    ',')
                opt['tensorboard_metrics'] = ','.join(
                    opt['tensorboard_metrics'].split(',') + teacher_metrics)
                self.writer = TensorboardLogger(opt)

        else:
            self.lastYs = shared['lastYs']
            self.tasks = shared['tasks']
            if not self.stream and opt.get('pace_by', 'sample') == 'bucket':
                self.bucket_cnt = shared['bucket_cnt']
            if 'writer' in shared:
                self.writer = shared['writer']
            if 'subtask_counter' in shared:
                self.subtask_counter = shared['subtask_counter']
            if 'p_selections' in shared:
                self.p_selections = shared['p_selections']
            if 'c_selections' in shared:
                self.c_selections = shared['c_selections']

        # build the policy net, criterion and optimizer here
        self.state_dim = 32 + len(self.tasks)  # hand-craft features
        self.action_dim = len(self.tasks)

        if not shared:
            self.policy = PolicyNet(self.state_dim, self.action_dim)
            self.critic = CriticNet(self.state_dim, self.action_dim)

            init_teacher = get_init_teacher(opt, shared)
            if init_teacher is not None:
                # load teacher parameters if available
                print('[ Loading existing teacher params from {} ]'
                      ''.format(init_teacher))
                states = self.load(init_teacher)
            else:
                states = {}
        else:
            self.policy = shared['policy']
            self.critic = shared['critic']
            states = shared['states']

        if (
                # only build an optimizer if we're training
                'train' in opt.get('datatype', '') and
                # and this is the main model
                shared is None):
            # for policy net
            self.optimizer = self.init_optim(
                [p for p in self.policy.parameters() if p.requires_grad],
                lr=opt['learningrate_teacher'],
                optim_states=states.get('optimizer'),
                saved_optim_type=states.get('optimizer_type'))
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                'min',
                factor=0.8,  # 0.5 --> 0.8
                patience=5,  # 3 --> 5
                verbose=True)
            if 'lr_scheduler' in states:
                self.scheduler.load_state_dict(states['lr_scheduler'])

            # for critic net
            self.optimizer_critic = self.init_optim(
                [p for p in self.critic.parameters() if p.requires_grad],
                lr=opt['learningrate_teacher_critic'],
                optim_states=states.get('optimizer_critic'),
                saved_optim_type=states.get('optimizer_type'))
            self.scheduler_critic = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer_critic,
                'min',
                factor=0.8,  # 0.5 --> 0.8
                patience=5,  # 3 --> 5
                verbose=True)
            if 'lr_scheduler_critic' in states:
                self.scheduler_critic.load_state_dict(
                    states['lr_scheduler_critic'])

            self.critic_criterion = torch.nn.SmoothL1Loss()

        self.reward_metric = opt.get('reward_metric', 'total_metric')
        self.reward_metric_mode = opt.get('reward_metric_mode', 'max')

        self.prev_prev_valid_report = states.get('prev_prev_valid_report')
        self.prev_valid_report = states.get('prev_valid_report')
        self.current_valid_report = states.get('current_valid_report')
        self.saved_actions = states.get('saved_actions', OrderedDict())
        self.saved_state_actions = states.get('saved_state_actions',
                                              OrderedDict())
        if self.use_cuda:
            for k, v in self.saved_actions.items():
                self.saved_actions[k] = v.cuda()
            for k, v in self.saved_state_actions.items():
                self.saved_state_actions[k] = v.cuda()
        self._number_teacher_updates = states.get('_number_teacher_updates', 0)

        # enable the batch_act
        self.use_batch_act = self.bsz > 1

        self.T = self.opt.get('T', 1000)
        self.c0 = self.opt.get('c0', 0.01)
        self.p = self.opt.get('p', 2)

        # setup the timer
        self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 \
            else float('inf')
        self.action_log_time = Timer()

        self.move_to_cuda()
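At a high level, the teacher above picks the next sub-task by sampling from a policy network over tasks. A hedged sketch of that sampling step with a stand-in linear policy (the sizes and names below are made up, not the original PolicyNet):

import torch
import torch.nn as nn
from torch.distributions import Categorical

state_dim, action_dim = 34, 2                       # e.g. 32 hand-crafted features + one per task
policy = nn.Sequential(nn.Linear(state_dim, action_dim), nn.Softmax(dim=-1))

state = torch.randn(1, state_dim)                   # stand-in for _build_states(observations)
action_probs = policy(state)                        # one probability per sub-task
dist = Categorical(action_probs[0])
action = dist.sample()                              # index of the sub-task to draw a batch from
log_prob = dist.log_prob(action)                    # kept for the policy-gradient update later
print(action.item(), float(log_prob))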
Example No. 32
class DefaultTeacher(FbDialogTeacher):
    def __init__(self, opt, shared=None):
        opt = copy.deepcopy(opt)
        super().__init__(opt, shared)
        self.use_cuda = not opt['no_cuda'] and torch.cuda.is_available()
        self.is_combine_attr = (hasattr(self, 'other_task_datafiles')
                                and self.other_task_datafiles)
        self.random_policy = opt.get('random_policy', False)
        self.count_sample = opt.get('count_sample', False)
        self.anti = opt.get('anti', False)

        if self.random_policy:
            random.seed(17)

        if not shared:
            if not self.stream and opt.get('pace_by', 'sample') == 'bucket':
                score_list = [episode[0][2] for episode in self.data.data]
                assert score_list == sorted(score_list)
                num_buckets = opt.get('num_buckets',
                                      int(self.num_episodes() / 10))
                lb_indices = [
                    int(len(score_list) * i / num_buckets)
                    for i in range(num_buckets)
                ]
                lbs = [score_list[idx] for idx in lb_indices]
                bucket_ids = [
                    self.sort_into_bucket(ctrl_val, lbs)
                    for ctrl_val in score_list
                ]
                bucket_cnt = [0 for _ in range(num_buckets)]
                for i in range(num_buckets):
                    bucket_cnt[i] = bucket_ids.count(i)
                self.bucket_cnt = bucket_cnt
            self.lastYs = [None] * self.bsz
            # build multiple task data
            self.tasks = [self.data]

            if self.is_combine_attr:
                print('[ build multiple task data ... ]')
                for datafile in self.other_task_datafiles:
                    task_opt = copy.deepcopy(opt)
                    task_opt['datafile'] = datafile
                    self.tasks.append(
                        DialogData(task_opt,
                                   data_loader=self.setup_data,
                                   cands=self.label_candidates()))
                print('[ build multiple task data done! ]')

                # record the selections of each subtasks
                self.subtasks = opt['subtasks'].split(':')
                self.subtask_counter = OrderedDict()
                self.p_selections = OrderedDict()
                self.c_selections = OrderedDict()
                for t in self.subtasks:
                    self.subtask_counter[t] = 0
                    self.p_selections[t] = []
                    self.c_selections[t] = []

                if self.count_sample and not self.stream:
                    self.sample_counter = OrderedDict()
                    for idx, t in enumerate(self.subtasks):
                        self.sample_counter[t] = [
                            0 for _ in self.tasks[idx].data
                        ]

            # setup the tensorboard log
            if opt['tensorboard_log_teacher'] is True:
                opt['tensorboard_tag'] = 'task'
                teacher_metrics = 'reward,policy_loss,critic_loss,mean_advantage_reward,action_ent'.split(
                    ',')
                opt['tensorboard_metrics'] = ','.join(
                    opt['tensorboard_metrics'].split(',') + teacher_metrics)
                self.writer = TensorboardLogger(opt)

        else:
            self.lastYs = shared['lastYs']
            self.tasks = shared['tasks']
            if not self.stream and opt.get('pace_by', 'sample') == 'bucket':
                self.bucket_cnt = shared['bucket_cnt']
            if 'writer' in shared:
                self.writer = shared['writer']
            if 'subtask_counter' in shared:
                self.subtask_counter = shared['subtask_counter']
            if 'p_selections' in shared:
                self.p_selections = shared['p_selections']
            if 'c_selections' in shared:
                self.c_selections = shared['c_selections']

        # build the policy net, criterion and optimizer here
        self.state_dim = 32 + len(self.tasks)  # hand-craft features
        self.action_dim = len(self.tasks)

        if not shared:
            self.policy = PolicyNet(self.state_dim, self.action_dim)
            self.critic = CriticNet(self.state_dim, self.action_dim)

            init_teacher = get_init_teacher(opt, shared)
            if init_teacher is not None:
                # load teacher parameters if available
                print('[ Loading existing teacher params from {} ]'
                      ''.format(init_teacher))
                states = self.load(init_teacher)
            else:
                states = {}
        else:
            self.policy = shared['policy']
            self.critic = shared['critic']
            states = shared['states']

        if (
                # only build an optimizer if we're training
                'train' in opt.get('datatype', '') and
                # and this is the main model
                shared is None):
            # for policy net
            self.optimizer = self.init_optim(
                [p for p in self.policy.parameters() if p.requires_grad],
                lr=opt['learningrate_teacher'],
                optim_states=states.get('optimizer'),
                saved_optim_type=states.get('optimizer_type'))
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                'min',
                factor=0.8,  # 0.5 --> 0.8
                patience=5,  # 3 --> 5
                verbose=True)
            if 'lr_scheduler' in states:
                self.scheduler.load_state_dict(states['lr_scheduler'])

            # for critic net
            self.optimizer_critic = self.init_optim(
                [p for p in self.critic.parameters() if p.requires_grad],
                lr=opt['learningrate_teacher_critic'],
                optim_states=states.get('optimizer_critic'),
                saved_optim_type=states.get('optimizer_type'))
            self.scheduler_critic = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer_critic,
                'min',
                factor=0.8,  # 0.5 --> 0.8
                patience=5,  # 3 --> 5
                verbose=True)
            if 'lr_scheduler_critic' in states:
                self.scheduler_critic.load_state_dict(
                    states['lr_scheduler_critic'])

            self.critic_criterion = torch.nn.SmoothL1Loss()

        self.reward_metric = opt.get('reward_metric', 'total_metric')
        self.reward_metric_mode = opt.get('reward_metric_mode', 'max')

        self.prev_prev_valid_report = states.get('prev_prev_valid_report')
        self.prev_valid_report = states.get('prev_valid_report')
        self.current_valid_report = states.get('current_valid_report')
        self.saved_actions = states.get('saved_actions', OrderedDict())
        self.saved_state_actions = states.get('saved_state_actions',
                                              OrderedDict())
        if self.use_cuda:
            for k, v in self.saved_actions.items():
                self.saved_actions[k] = v.cuda()
            for k, v in self.saved_state_actions.items():
                self.saved_state_actions[k] = v.cuda()
        self._number_teacher_updates = states.get('_number_teacher_updates', 0)

        # enable the batch_act
        self.use_batch_act = self.bsz > 1

        self.T = self.opt.get('T', 1000)
        self.c0 = self.opt.get('c0', 0.01)
        self.p = self.opt.get('p', 2)

        # setup the timer
        self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 \
            else float('inf')
        self.action_log_time = Timer()

        self.move_to_cuda()

    def move_to_cuda(self):
        if self.use_cuda:
            self.policy.cuda()
            self.critic.cuda()

    @classmethod
    def optim_opts(self):
        """
        Fetch optimizer selection.

        By default, collects everything in torch.optim, as well as importing:
        - qhm / qhmadam if installed from github.com/facebookresearch/qhoptim

        Override this (and probably call super()) to add your own optimizers.
        """
        # first pull torch.optim in
        optims = {
            k.lower(): v
            for k, v in optim.__dict__.items()
            if not k.startswith('__') and k[0].isupper()
        }
        try:
            import apex.optimizers.fused_adam as fused_adam
            optims['fused_adam'] = fused_adam.FusedAdam
        except ImportError:
            pass

        try:
            # https://openreview.net/pdf?id=S1fUpoR5FQ
            from qhoptim.pyt import QHM, QHAdam
            optims['qhm'] = QHM
            optims['qhadam'] = QHAdam
        except ImportError:
            # no QHM installed
            pass

        return optims

    def init_optim(self, params, lr, optim_states=None, saved_optim_type=None):
        """
        Initialize optimizer with teacher parameters.

        :param params:
            parameters from the teacher

        :param optim_states:
            optional argument providing states of optimizer to load

        :param saved_optim_type:
            type of optimizer being loaded, if changed will skip loading
            optimizer states
        """

        opt = self.opt

        # set up optimizer args
        kwargs = {'lr': lr}
        if opt.get('momentum_teacher', 0) > 0 and opt['optimizer_teacher'] in [
                'sgd', 'rmsprop', 'qhm'
        ]:
            # turn on momentum for optimizers that use it
            kwargs['momentum'] = opt['momentum_teacher']
            if opt['optimizer_teacher'] == 'sgd' and opt.get(
                    'nesterov_teacher', True):
                # for sgd, maybe nesterov
                kwargs['nesterov'] = opt.get('nesterov_teacher', True)
            elif opt['optimizer_teacher'] == 'qhm':
                # qhm needs a nu
                kwargs['nu'] = opt.get('nus_teacher', (0.7, ))[0]
        elif opt['optimizer_teacher'] == 'adam':
            # turn on amsgrad for adam
            # amsgrad paper: https://openreview.net/forum?id=ryQu7f-RZ
            kwargs['amsgrad'] = True
        elif opt['optimizer_teacher'] == 'qhadam':
            # set nus for qhadam
            kwargs['nus'] = opt.get('nus_teacher', (0.7, 1.0))
        if opt['optimizer_teacher'] in [
                'adam', 'sparseadam', 'adamax', 'qhadam'
        ]:
            # set betas for optims that use it
            kwargs['betas'] = opt.get('betas_teacher', (0.9, 0.999))

        optim_class = self.optim_opts()[opt['optimizer_teacher']]
        optimizer = optim_class(params, **kwargs)

        if optim_states:
            if saved_optim_type != opt['optimizer_teacher']:
                print('WARNING: not loading optim state since optim class '
                      'changed.')
            else:
                try:
                    optimizer.load_state_dict(optim_states)
                except ValueError:
                    print('WARNING: not loading optim state since model '
                          'params changed.')
                if self.use_cuda:
                    for state in optimizer.state.values():
                        for k, v in state.items():
                            if isinstance(v, torch.Tensor):
                                state[k] = v.cuda()
        return optimizer

    def load(self, path):
        """
        Return opt and teacher states.

        TODO: load behaviors should be consistent with function state_dict().
        """
        states = torch.load(path, map_location=lambda storage, loc: storage)
        if 'policy' in states:
            self.policy.load_state_dict(states['policy'])
        if 'critic' in states:
            self.critic.load_state_dict(states['critic'])
        if 'optimizer' in states and hasattr(self, 'optimizer'):
            self.optimizer.load_state_dict(states['optimizer'])
        if 'optimizer_critic' in states and hasattr(self, 'optimizer_critic'):
            self.optimizer_critic.load_state_dict(states['optimizer_critic'])
        return states

    def share(self):
        shared = super().share()
        if hasattr(self, 'bucket_cnt'):
            shared['bucket_cnt'] = self.bucket_cnt

        shared['tasks'] = self.tasks
        shared['policy'] = self.policy
        shared['critic'] = self.critic

        shared['states'] = {
            'optimizer_type': self.opt['optimizer_teacher'],
            'prev_prev_valid_report': self.prev_prev_valid_report,
            'prev_valid_report': self.prev_valid_report,
            'current_valid_report': self.current_valid_report,
            'saved_actions': self.saved_actions,
            'saved_state_actions': self.saved_state_actions,
        }
        if hasattr(self, 'writer'):
            shared['writer'] = self.writer
        if hasattr(self, 'subtask_counter'):
            shared['subtask_counter'] = self.subtask_counter
        if hasattr(self, 'p_selections'):
            shared['p_selections'] = self.p_selections
        if hasattr(self, 'c_selections'):
            shared['c_selections'] = self.c_selections
        return shared

    @staticmethod
    def sort_into_bucket(val, bucket_lbs):
        """
        Returns the highest bucket such that val >= lower bound for that bucket.

        Inputs:
          val: float. The value to be sorted into a bucket.
          bucket_lbs: list of floats, sorted ascending.

        Returns:
          bucket_id: int in range(num_buckets); the bucket that val belongs to.
        """
        num_buckets = len(bucket_lbs)
        for bucket_id in range(num_buckets - 1, -1, -1):  # iterate descending
            lb = bucket_lbs[bucket_id]
            if val >= lb:
                return bucket_id
        raise ValueError('val %f is not >= any of the lower bounds: %s' %
                         (val, bucket_lbs))

    def pace_function(self, states, sum_num, T=1000, c0=0.01, p=2):
        train_step = states['train_step']
        progress = self.root_p_pace(train_step, T, c0, p)
        return int(sum_num * progress)

    @staticmethod
    def root_p_pace(timestep, T=1000, c0=0.01, p=2):
        """Root-p pacing: rises from c0 at timestep 0 to 1.0 at timestep T
        (a square-root curve when p=2), then stays at 1.0 afterwards."""
        root_p = math.pow(
            timestep * (1 - math.pow(c0, p)) / T + math.pow(c0, p), 1.0 / p)
        return min(1.0, root_p)

    def act(self, observation=None, task_idx=0):
        """Send new dialog message."""
        if not hasattr(self, 'epochDone'):
            # reset if haven't yet
            self.reset()

        # get next example, action is episode_done dict if already out of exs
        action, self.epochDone = self.next_example(observation=observation,
                                                   task_idx=task_idx)
        action['id'] = self.getID()

        # remember correct answer if available
        self.lastY = action.get('labels', action.get('eval_labels', None))
        if ((not self.datatype.startswith('train')
             or 'evalmode' in self.datatype) and 'labels' in action):
            # move labels to eval field so not used for training
            # but this way the model can use the labels for perplexity or loss
            action = action.copy()
            labels = action.pop('labels')
            if not self.opt.get('hide_labels', False):
                action['eval_labels'] = labels

        return action

    def _cry_for_missing_in_obs(self, something):
        raise RuntimeError(
            "{} is needed to include in observations to build states!".format(
                something))

    def _build_states(self, observations):
        for key in ['train_step', 'train_report', 'loss_desc', 'prob_desc']:
            if key not in observations[0]:
                self._cry_for_missing_in_obs(key)

        train_step = observations[0]['train_step']  # scalar
        train_step = min(train_step / self.T, 1)
        train_report = observations[0]['train_report']
        nll_loss = train_report.get('nll_loss', 0) / 10  # scalar
        loss_desc = observations[0]['loss_desc']
        loss_desc = F.normalize(loss_desc, p=2, dim=-1)

        prob_desc = observations[0]['prob_desc']
        prob_desc = F.normalize(prob_desc, p=2, dim=-1)

        if hasattr(self, 'subtask_counter'):
            subtask_progress = self.subtask_counter.values()
            max_min = max(subtask_progress) - min(subtask_progress)
            subtask_progress = [
                (item - min(subtask_progress)) / max_min if max_min > 0 else 0
                for item in subtask_progress
            ]
        else:
            subtask_progress = [0]
        subtask_progress = torch.FloatTensor(subtask_progress)
        if self.use_cuda:
            subtask_progress = subtask_progress.cuda()

        prev_valid_report = self.prev_valid_report
        if prev_valid_report is None:
            prev_valid_report = {}

        bleu = prev_valid_report.get('bleu', 0)
        valid_nll_loss = prev_valid_report.get('nll_loss', 0) / 10
        dist_1_ratio = prev_valid_report.get('dist_1_ratio', 0)
        dist_2_ratio = prev_valid_report.get('dist_2_ratio', 0)
        dist_3_ratio = prev_valid_report.get('dist_3_ratio', 0)
        embed_avg = prev_valid_report.get('embed_avg', 0)
        embed_greedy = prev_valid_report.get('embed_greedy', 0)
        embed_extrema = prev_valid_report.get('embed_extrema', 0)
        embed_coh = prev_valid_report.get('embed_coh', 0)
        intra_dist_1 = prev_valid_report.get('intra_dist_1', 0) / 10
        intra_dist_2 = prev_valid_report.get('intra_dist_2', 0) / 10
        intra_dist_3 = prev_valid_report.get('intra_dist_3', 0) / 10
        response_length = prev_valid_report.get(
            'response_length', 0) / self.opt.get('label_truncate', 100)
        # sent_entropy_uni = prev_valid_report.get('sent_entropy_uni', 0) / 100
        # sent_entropy_bi = prev_valid_report.get('sent_entropy_bi', 0) / 100
        # sent_entropy_tri = prev_valid_report.get('sent_entropy_tri', 0) / 100
        word_entropy_uni = prev_valid_report.get('word_entropy_uni', 0) / 100
        word_entropy_bi = prev_valid_report.get('word_entropy_bi', 0) / 100
        word_entropy_tri = prev_valid_report.get('word_entropy_tri', 0) / 100
        states = torch.FloatTensor([
            train_step,
            nll_loss,
            bleu,
            valid_nll_loss,
            dist_1_ratio,
            dist_2_ratio,
            dist_3_ratio,
            embed_avg,
            embed_greedy,
            embed_extrema,
            embed_coh,
            intra_dist_1,
            intra_dist_2,
            intra_dist_3,
            response_length,
            # sent_entropy_uni, sent_entropy_bi, sent_entropy_tri,
            word_entropy_uni,
            word_entropy_bi,
            word_entropy_tri
        ])
        if self.use_cuda:
            states = states.cuda()
        states = torch.cat([states, loss_desc, prob_desc, subtask_progress],
                           dim=-1).unsqueeze(dim=0)
        return states

    def __uniform_weights(self):
        w = 1 / len(self.tasks)
        weights = torch.FloatTensor([w] * len(self.tasks))
        if self.use_cuda:
            weights = weights.cuda()
        return weights.unsqueeze(dim=0)

    def __load_training_batch(self, observations):
        if observations and len(
                observations) > 0 and observations[0] and self.is_combine_attr:
            if not self.random_policy:
                with torch.no_grad():
                    current_states = self._build_states(observations)
                action_probs = self.policy(current_states)
                if self.action_log_time.time() > self.log_every_n_secs and len(
                        self.tasks) > 1:
                    with torch.no_grad():
                        # log the action distributions
                        action_p = ','.join([
                            str(round_sigfigs(x, 4))
                            for x in action_probs[0].data.tolist()
                        ])
                        log = '[ {} {} ]'.format('Action probs:', action_p)
                        print(log)
                        self.action_log_time.reset()
                sample_from = Categorical(action_probs[0])
                action = sample_from.sample()
                train_step = observations[0]['train_step']
                self.saved_actions[train_step] = sample_from.log_prob(action)
                self.saved_state_actions[train_step] = torch.cat(
                    [current_states, action_probs], dim=1)
                selected_task = action.item()
                self.subtask_counter[self.subtasks[selected_task]] += 1

                probs = action_probs[0].tolist()
                selection_report = {}
                for idx, t in enumerate(self.subtasks):
                    selection_report['p_{}'.format(t)] = probs[idx]
                    self.p_selections[t].append(probs[idx])
                    selection_report['c_{}'.format(
                        t)] = self.subtask_counter[t]
                    self.c_selections[t].append(self.subtask_counter[t])
                self.writer.add_metrics(setting='Teacher/task_selection',
                                        step=train_step,
                                        report=selection_report)
            else:
                selected_task = random.choice(range(len(self.tasks)))
                self.subtask_counter[self.subtasks[selected_task]] += 1
        else:
            selected_task = 0

        return self.__load_batch(observations, task_idx=selected_task)

    def __load_batch(self, observations, task_idx=0):
        if observations is None:
            observations = [None] * self.bsz
        bsz = len(observations)

        batch = []
        # Sample from multiple tasks using the policy net
        for idx in range(bsz):
            batch.append(self.act(observations[idx], task_idx=task_idx))
        return batch

    def batch_act(self, observations):
        """
        Returns an entire batch of examples instead of just one.
        """
        if not hasattr(self, 'epochDone'):
            # reset if haven't yet
            self.reset()
        if self.opt['datatype'] == 'train':
            batch = self.__load_training_batch(observations)
        else:
            batch = self.__load_batch(observations)

        # pad batch
        if len(batch) < self.bsz:
            batch += [{
                'episode_done': True,
                'id': self.getID()
            }] * (self.bsz - len(batch))

        # remember correct answer if available (for padding, None)
        for i, ex in enumerate(batch):
            if 'labels' in ex:
                labels = ex['labels']
                self.lastYs[i] = labels
                if not self.datatype.startswith(
                        'train') or 'evalmode' in self.datatype:
                    del ex['labels']
                    if not self.opt.get('hide_labels', False):
                        ex['eval_labels'] = labels
            else:
                self.lastYs[i] = ex.get('eval_labels', None)

        return batch

    def next_example(self, observation=None, task_idx=0):
        """
        Returns the next example.

        If there are multiple examples in the same episode, returns the next
        one in that episode. If that episode is over, gets a new episode index
        and returns the first example of that episode.
        """
        if self.stream:
            action, epoch_done = self.tasks[task_idx].get()
        else:
            if self.episode_done:
                self.episode_idx = self.next_episode_idx()
                self.entry_idx = 0
            else:
                self.entry_idx += 1

            if self.episode_idx >= self.num_episodes():
                return {'episode_done': True}, True

            if observation is None or self.opt['datatype'] != 'train':
                # The first step of the training or validation mode
                sampled_episode_idx = self.episode_idx
                sampled_entry_idx = self.entry_idx
            else:
                # --------------- pick the sample according to the pace function -----------
                pace_by = self.opt.get('pace_by', 'sample')

                if pace_by == 'sample':
                    sum_num = self.num_episodes()
                elif pace_by == 'bucket':
                    sum_num = len(self.bucket_cnt)
                else:
                    raise ValueError('pace_by must be {} or {}!'.format(
                        'sample', 'bucket'))

                states4pace_func = observation
                if hasattr(self, 'subtask_counter'):
                    states4pace_func = {
                        'train_step':
                        self.subtask_counter[self.subtasks[task_idx]]
                    }

                threshold = self.pace_function(states4pace_func, sum_num,
                                               self.T, self.c0, self.p)
                if pace_by == 'sample':
                    stop_step = threshold
                elif pace_by == 'bucket':
                    stop_step = sum(self.bucket_cnt[:threshold])
                else:
                    raise ValueError('pace_by must be {} or {}!'.format(
                        'sample', 'bucket'))

                stop_step = min(stop_step, self.num_episodes())
                # sampled_episode_idx = random.choice(list(range(self.num_episodes()))[:stop_step])
                sampled_episode_idx = np.random.choice(stop_step)
                sampled_entry_idx = 0  # make sure the episode only contains one entry

                if self.anti:
                    sampled_episode_idx = (self.num_episodes() - 1 -
                                           sampled_episode_idx)

            if self.count_sample:
                self.sample_counter[
                    self.subtasks[task_idx]][sampled_episode_idx] += 1

            ex = self.get(sampled_episode_idx,
                          sampled_entry_idx,
                          task_idx=task_idx)

            if observation is None or self.opt['datatype'] != 'train':
                self.episode_done = ex.get('episode_done', False)
                if (not self.random and self.episode_done
                        and self.episode_idx + self.opt.get("batchsize", 1) >=
                        self.num_episodes()):
                    epoch_done = True
                else:
                    epoch_done = False
            else:
                # in the curriculum learning setting, samples are not drawn
                # uniformly from the training set, so epoch bookkeeping is not
                # meaningful here
                epoch_done = False

            action = ex

        return action, epoch_done

    def get(self, episode_idx, entry_idx=0, task_idx=0):
        return self.tasks[task_idx].get(episode_idx, entry_idx)[0]

    def update_params(self):
        self._number_teacher_updates += 1
        if self.opt.get('gradient_clip_teacher', -1) > 0:
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(),
                                           self.opt['gradient_clip_teacher'])

        self.optimizer.step()

    def update_critic_params(self):
        if self.opt.get('gradient_clip_teacher', -1) > 0:
            torch.nn.utils.clip_grad_norm_(self.critic.parameters(),
                                           self.opt['gradient_clip_teacher'])

        self.optimizer_critic.step()

    def receive_metrics(self, metrics_dict):
        if self.is_combine_attr and not self.random_policy:
            assert self.reward_metric in metrics_dict, '{} is not in the metrics_dict!'.format(
                self.reward_metric)
            self.prev_prev_valid_report = self.prev_valid_report
            self.prev_valid_report = self.current_valid_report
            self.current_valid_report = metrics_dict
            delta_reward = None
            if self.prev_prev_valid_report and self.prev_valid_report and self.current_valid_report:
                # ratio of the latest improvement to the previous improvement
                delta_reward1 = (self.current_valid_report[self.reward_metric] -
                                 self.prev_valid_report[self.reward_metric])
                delta_reward0 = (self.prev_valid_report[self.reward_metric] -
                                 self.prev_prev_valid_report[self.reward_metric])
                if self.reward_metric_mode == 'min':
                    delta_reward1 = -delta_reward1
                    delta_reward0 = -delta_reward0
                delta_reward = delta_reward1 / (delta_reward0 + 1e-6) - 1
            if delta_reward and len(self.saved_actions) > 0 and len(
                    self.saved_state_actions) > 0:
                reward = torch.clamp(torch.FloatTensor([delta_reward]), -10, 10)
                if self.use_cuda:
                    reward = reward.cuda()

                with torch.no_grad():
                    batch_state_actions = torch.cat(list(
                        self.saved_state_actions.values()),
                                                    dim=0)
                    if self.use_cuda:
                        batch_state_actions = batch_state_actions.cuda()
                    estimate_rewards = self.critic(
                        batch_state_actions).squeeze()
                    advantages = reward - estimate_rewards

                    # rescale the rewards by ranking
                    episode_len = len(advantages)
                    ranks = torch.FloatTensor(
                        list(
                            reversed(
                                ss.rankdata(advantages.cpu(),
                                            method='dense')))).unsqueeze(dim=1)
                    rescaled_rewards = torch.sigmoid(
                        12 * (0.5 - ranks / episode_len))

                rescaled_rewards = [r.item() for r in rescaled_rewards]
                policy_loss = []
                for idx, log_prob in enumerate(self.saved_actions.values()):
                    policy_loss.append(-log_prob.unsqueeze(dim=0) *
                                       rescaled_rewards[idx])
                policy_loss = torch.cat(policy_loss).sum()

                # entropy regularization term on the action distribution
                bsz = batch_state_actions.size(0)
                action_probs = torch.cat(list(
                    self.saved_state_actions.values()),
                                         dim=0).narrow(1, self.state_dim,
                                                       self.action_dim)
                action_ent = torch.sum(
                    -action_probs * torch.log(action_probs)) / bsz

                self.policy.train()
                self.optimizer.zero_grad()
                policy_loss = policy_loss + self.opt.get('reg_action',
                                                         0.001) * (-action_ent)
                policy_loss.backward()
                self.update_params()

                # lr_scheduler step on teacher loss
                policy_loss_item = policy_loss.item()
                if self.opt.get('optimizer_teacher', '') == 'sgd':
                    self.scheduler.step(policy_loss_item)

                # training on the critic
                self.critic.train()
                self.optimizer_critic.zero_grad()

                batch_values = self.critic(batch_state_actions)
                critic_target = torch.FloatTensor(bsz, 1)
                critic_target = critic_target.fill_(reward.item())
                if self.use_cuda:
                    critic_target = critic_target.cuda()
                critic_loss = self.critic_criterion(batch_values,
                                                    critic_target)
                critic_loss.backward()
                self.update_critic_params()
                critic_loss_item = critic_loss.item()
                if self.opt.get('optimizer_teacher', '') == 'sgd':
                    self.scheduler_critic.step(critic_loss_item)

                # log teacher statistics for this policy update
                print(
                    '[ reward: {}; mean_advantage_reward: {}; policy loss: {};'
                    ' critic loss: {}; action ent: {}; episode length: {} ]'.
                    format(reward.item(), np.mean(advantages.tolist()),
                           policy_loss_item, critic_loss_item,
                           action_ent.item(), len(self.saved_actions)))

                report = {
                    'reward': reward.item(),
                    'mean_advantage_reward': np.mean(advantages.tolist()),
                    'policy_loss': policy_loss_item,
                    'critic_loss': critic_loss_item,
                    'action_ent': action_ent.item(),
                }
                self.writer.add_metrics(setting='Teacher/receive_metrics',
                                        step=self._number_teacher_updates,
                                        report=report)
                # clear history actions
                self.saved_actions.clear()
                self.saved_state_actions.clear()

    def state_dict(self):
        """
        Get the state dict for saving

        TODO: save more teacher-related states for reloading
        """
        states = {}
        if hasattr(self, 'policy'):  # save model params
            if hasattr(self.policy, 'module'):
                # unwrap DistributedDataParallel before saving
                states['policy'] = self.policy.module.state_dict()
            else:
                states['policy'] = self.policy.state_dict()

        if hasattr(self, 'critic'):  # save model params
            if hasattr(self.critic, 'module'):
                # unwrap DistributedDataParallel before saving
                states['critic'] = self.critic.module.state_dict()
            else:
                states['critic'] = self.critic.state_dict()

        if hasattr(self, 'optimizer'):  # save optimizer params
            states['optimizer'] = self.optimizer.state_dict()
            states['optimizer_type'] = self.opt['optimizer_teacher']
        if hasattr(self, 'optimizer_critic'):
            states['optimizer_critic'] = self.optimizer_critic.state_dict()

        if getattr(self, 'scheduler', None):
            states['lr_scheduler'] = self.scheduler.state_dict()
        if getattr(self, 'scheduler_critic', None):
            states['lr_scheduler_critic'] = self.scheduler_critic.state_dict()

        states['prev_prev_valid_report'] = self.prev_prev_valid_report
        states['prev_valid_report'] = self.prev_valid_report
        states['current_valid_report'] = self.current_valid_report
        states['saved_actions'] = self.saved_actions
        states['saved_state_actions'] = self.saved_state_actions

        states['_number_teacher_updates'] = self._number_teacher_updates

        return states

    def save(self, path=None):
        if path:
            teacher_path = path
        else:
            model_file = self.opt.get('model_file', None)
            if model_file:
                teacher_path = model_file + '.teacher'
            else:
                teacher_path = None

        if teacher_path:
            states = self.state_dict()
            if states:
                with open(teacher_path, 'wb') as write:
                    torch.save(states, write)
                # save opt file
                with open(teacher_path + '.opt', 'w',
                          encoding='utf-8') as handle:
                    json.dump(self.opt, handle)
                    # for convenience of working with jq, make sure there's a newline
                    handle.write('\n')

            if self.count_sample:
                # save sample count info
                for task_name, task_val in self.sample_counter.items():
                    with open(teacher_path +
                              '.sample_count.{}'.format(task_name),
                              'w',
                              encoding='utf-8') as f:
                        f.write('\n'.join([str(item) for item in task_val]))

            self.write_selections('p_selections', teacher_path)
            self.write_selections('c_selections', teacher_path)

    def write_selections(self, selections, teacher_path):
        if hasattr(self, selections):
            with open(teacher_path + '.{}'.format(selections),
                      'w',
                      encoding='utf-8') as f:
                f.write('\t'.join(self.subtasks))
                f.write('\n')
                for idx in range(
                        len(getattr(self, selections)[self.subtasks[0]])):
                    p_line = []
                    for t in self.subtasks:
                        p_line.append(str(getattr(self, selections)[t][idx]))
                    f.write('\t'.join(p_line))
                    f.write('\n')
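# Illustrative aside (not part of the snippet above): the rank-based reward
# rescaling used in receive_metrics() can be checked in isolation. The values
# below are made up; only torch and scipy.stats are assumed to be available.
import scipy.stats as ss
import torch

advantages = torch.tensor([0.3, -0.1, 0.7, 0.2])  # toy per-step advantages
episode_len = len(advantages)

# dense-rank the advantages, then reverse the sequence, mirroring the teacher
ranks = torch.FloatTensor(
    list(reversed(ss.rankdata(advantages.numpy(), method='dense')))
).unsqueeze(dim=1)

# a steep sigmoid maps small ranks to rewards near 1 and large ranks near 0
rescaled_rewards = torch.sigmoid(12 * (0.5 - ranks / episode_len))
print([round(r.item(), 3) for r in rescaled_rewards])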
Ejemplo n.º 33
0
class TrainLoop():
    def __init__(self, opt):
        if isinstance(opt, ParlaiParser):
            print(
                '[ Deprecation warning: TrainLoop should be passed an opt dict, not a parser ]'
            )
            opt = opt.parse_args()
        # Possibly load from checkpoint
        if opt['load_from_checkpoint'] and opt.get(
                'model_file') and os.path.isfile(opt['model_file'] +
                                                 '.checkpoint'):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
        # Possibly build a dictionary (not all models do this).
        if opt['dict_build_first'] and 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=True)
        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.world = create_task(opt, self.agent)
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys = 0
        self.max_num_epochs = opt['num_epochs'] if opt['num_epochs'] > 0 else float('inf')
        self.max_train_time = opt['max_train_time'] if opt['max_train_time'] > 0 else float('inf')
        self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf')
        self.val_every_n_secs = opt['validation_every_n_secs'] if opt['validation_every_n_secs'] > 0 else float('inf')
        self.save_every_n_secs = opt['save_every_n_secs'] if opt['save_every_n_secs'] > 0 else float('inf')
        self.val_every_n_epochs = opt['validation_every_n_epochs'] if opt['validation_every_n_epochs'] > 0 else float('inf')
        self.last_valid_epoch = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.best_valid = None
        if opt.get('model_file') and os.path.isfile(opt['model_file'] +
                                                    '.best_valid'):
            with open(opt['model_file'] + '.best_valid', 'r') as f:
                # restore the best validation metric from a previous run
                self.best_valid = float(f.readline())
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt
        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)

    def validate(self):
        opt = self.opt
        # run evaluation on valid set
        valid_report, self.valid_world = run_eval(self.agent,
                                                  opt,
                                                  'valid',
                                                  opt['validation_max_exs'],
                                                  valid_world=self.valid_world)

        # logging
        if opt['tensorboard_log'] is True:
            self.writer.add_metrics('valid',
                                    int(math.floor(self.train_time.time())),
                                    valid_report)
        # saving
        if opt.get('model_file') and opt.get('save_after_valid'):
            print("[ saving model checkpoint: " + opt['model_file'] +
                  ".checkpoint ]")
            self.agent.save(opt['model_file'] + '.checkpoint')

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # check which metric to look at
        if '/' in opt['validation_metric']:
            # if you are multitasking and want your validation metric to be
            # a metric specific to a subtask, specify your validation metric
            # as -vmt subtask/metric
            subtask = opt['validation_metric'].split('/')[0]
            validation_metric = opt['validation_metric'].split('/')[1]
            new_valid = valid_report['tasks'][subtask][validation_metric]
        else:
            new_valid = valid_report[opt['validation_metric']]

        # check if this is the best validation so far
        if self.best_valid is None or self.valid_optim * new_valid > self.valid_optim * self.best_valid:
            print('[ new best {}: {}{} ]'.format(
                opt['validation_metric'], new_valid,
                ' (previous best was {})'.format(self.best_valid)
                if self.best_valid is not None else ''))
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file'):
                print("[ saving best valid model: " + opt['model_file'] + " ]")
                self.agent.save(opt['model_file'])
                print("[ saving best valid metric: " + opt['model_file'] +
                      ".best_valid ]")
                save_best_valid(opt['model_file'], self.best_valid)
                self.saved = True
            if opt['validation_metric'] == 'accuracy' and self.best_valid >= opt[
                    'validation_cutoff']:
                print('[ task solved! stopping. ]')
                return True
        else:
            self.impatience += 1
            print('[ did not beat best {}: {} impatience: {} ]'.format(
                opt['validation_metric'], round(self.best_valid, 4),
                self.impatience))
        self.validate_time.reset()

        # check if we are out of patience
        if opt['validation_patience'] > 0 and self.impatience >= opt[
                'validation_patience']:
            print('[ ran out of patience! stopping training. ]')
            return True
        return False

    def log(self):
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []
        # get report
        train_report = self.world.report(compute_time=True)
        self.world.reset_metrics()

        # time elapsed
        logs.append('time:{}s'.format(math.floor(self.train_time.time())))
        total_exs = self.world.get_total_exs()
        logs.append('total_exs:{}'.format(total_exs))

        exs_per_ep = self.world.num_examples()
        if exs_per_ep:
            logs.append('epochs:{}'.format(round(total_exs / exs_per_ep, 2)))

        if 'time_left' in train_report:
            logs.append('time_left:{}s'.format(
                math.floor(train_report.pop('time_left', ""))))

        log = '[ {} ] {}'.format(' '.join(logs), train_report)
        print(log)
        self.log_time.reset()

        if opt['tensorboard_log'] is True:
            # use total_exs (the second log field) as the x-axis step
            self.writer.add_metrics('train', int(logs[1].split(":")[1]),
                                    train_report)

    def train(self):
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                world.parley()
                self.parleys += 1

                # check counters and timers
                if world.get_total_epochs() >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, self.train_time.time()))
                    break
                if self.train_time.time() > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(
                        self.train_time.time()))
                    break
                if self.log_time.time() > self.log_every_n_secs:
                    self.log()
                if self.validate_time.time() > self.val_every_n_secs:
                    stop_training = self.validate()
                    if stop_training:
                        break
                if world.get_total_epochs(
                ) - self.last_valid_epoch >= self.val_every_n_epochs:
                    stop_training = self.validate()
                    self.last_valid_epoch = world.get_total_epochs()
                    if stop_training:
                        break
                if self.save_time.time() > self.save_every_n_secs and opt.get(
                        'model_file'):
                    print("[ saving model checkpoint: " + opt['model_file'] +
                          ".checkpoint ]")
                    self.agent.save(opt['model_file'] + '.checkpoint')
                    self.save_time.reset()

        if not self.saved:
            # save agent
            self.agent.save(opt['model_file'])
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        v_report, v_world = run_eval(self.agent, opt, 'valid', write_log=True)
        t_report, t_world = run_eval(self.agent, opt, 'test', write_log=True)
        v_world.shutdown()
        t_world.shutdown()
        return v_report, t_report
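# Illustrative aside: this TrainLoop is normally driven by a thin entry point.
# The sketch below is hypothetical; setup_args() stands in for whatever parser
# builder registers the training flags the loop reads (num_epochs,
# max_train_time, validation_every_n_secs, tensorboard_log, ...).
if __name__ == '__main__':
    parser = setup_args()  # hypothetical: a fully configured ParlaiParser
    opt = parser.parse_args()
    valid_report, test_report = TrainLoop(opt).train()
    print(valid_report)
    print(test_report)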
Ejemplo n.º 34
0
def main():
    # Get command line arguments
    parser = ParlaiParser(True, True)
    train = parser.add_argument_group('Training Loop Arguments')
    train.add_argument('-et', '--evaltask',
                        help=('task to use for valid/test (defaults to the ' +
                              'one used for training if not set)'))
    train.add_argument('-d', '--display-examples',
                        type='bool', default=False)
    train.add_argument('-e', '--num-epochs', type=int, default=-1)
    train.add_argument('-ttim', '--max-train-time',
                        type=float, default=-1)
    train.add_argument('-ltim', '--log-every-n-secs',
                        type=float, default=2)
    train.add_argument('-vtim', '--validation-every-n-secs',
                        type=float, default=-1)
    train.add_argument('-vme', '--validation-max-exs',
                        type=int, default=-1,
                        help='max examples to use during validation (default ' +
                             '-1 uses all)')
    train.add_argument('-vp', '--validation-patience',
                        type=int, default=5,
                        help=('number of iterations of validation where result '
                              + 'does not improve before we stop training'))
    train.add_argument('-dbf', '--dict-build-first',
                        type='bool', default=True,
                        help='build dictionary first before training agent')
    opt = parser.parse_args()
    # Possibly build a dictionary (not all models do this).
    if opt['dict_build_first'] and 'dict_file' in opt:
        if opt['dict_file'] is None and opt.get('model_file'):
            opt['dict_file'] = opt['model_file'] + '.dict'
        build_dict.build_dict(opt)
    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)

    train_time = Timer()
    validate_time = Timer()
    log_time = Timer()
    print('[ training... ]')
    total_exs = 0
    max_exs = opt['num_epochs'] * len(world)
    best_accuracy = 0
    impatience = 0
    saved = False
    while True:
        world.parley()
        if opt['num_epochs'] > 0 and total_exs >= max_exs:
            print('[ num_epochs completed: {} ]'.format(opt['num_epochs']))
            break
        if opt['max_train_time'] > 0 and train_time.time() > opt['max_train_time']:
            print('[ max_train_time elapsed: {} ]'.format(train_time.time()))
            break
        if opt['log_every_n_secs'] > 0 and log_time.time() > opt['log_every_n_secs']:
            if opt['display_examples']:
                print(world.display() + '\n~~')

            logs = []
            # time elapsed
            logs.append('time:{}s'.format(math.floor(train_time.time())))

            # get report and update total examples seen so far
            if hasattr(agent, 'report'):
                train_report = agent.report()
                agent.reset_metrics()
            else:
                train_report = world.report()
                world.reset_metrics()
            total_exs += train_report['total']
            logs.append('total_exs:{}'.format(total_exs))

            # check if we should log amount of time remaining
            time_left = None
            if opt['num_epochs'] > 0:
                secs_per_ex = train_time.time() / total_exs
                time_left = (max_exs - total_exs) * secs_per_ex
            if opt['max_train_time'] > 0:
                other_time_left = opt['max_train_time'] - train_time.time()
                if time_left is not None:
                    time_left = min(time_left, other_time_left)
                else:
                    time_left = other_time_left
            if time_left is not None:
                logs.append('time_left:{}s'.format(math.floor(time_left)))

            # join log string and add full metrics report to end of log
            log = '[ {} ] {}'.format(' '.join(logs), train_report)

            print(log)
            log_time.reset()

        if (opt['validation_every_n_secs'] > 0 and
                validate_time.time() > opt['validation_every_n_secs']):
            valid_report = run_eval(agent, opt, 'valid', True, opt['validation_max_exs'])
            if valid_report['accuracy'] > best_accuracy:
                best_accuracy = valid_report['accuracy']
                impatience = 0
                print('[ new best accuracy: ' + str(best_accuracy) +  ' ]')
                if opt['model_file']:
                    agent.save(opt['model_file'])
                    saved = True
                if best_accuracy == 1:
                    print('[ task solved! stopping. ]')
                    break
            else:
                impatience += 1
                print('[ did not beat best accuracy: {} impatience: {} ]'.format(
                        round(best_accuracy, 4), impatience))
            validate_time.reset()
            if opt['validation_patience'] > 0 and impatience >= opt['validation_patience']:
                print('[ ran out of patience! stopping training. ]')
                break
    world.shutdown()
    if not saved:
        if opt['model_file']:
            agent.save(opt['model_file'])
    else:
        # reload best validation model
        agent = create_agent(opt)

    run_eval(agent, opt, 'valid')
    run_eval(agent, opt, 'test')
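# Illustrative aside: the time-left estimate above is just seconds-per-example
# multiplied by the examples remaining. With made-up numbers:
train_secs, total_exs, max_exs = 120.0, 4000, 10000
secs_per_ex = train_secs / total_exs            # 0.03 s per example so far
time_left = (max_exs - total_exs) * secs_per_ex
print(int(time_left))                           # -> 180 seconds remaining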
Ejemplo n.º 35
0
class TrainLoop():
    def __init__(self, opt):
        if isinstance(opt, ParlaiParser):
            opt = opt.parse_args()
        # Possibly build a dictionary (not all models do this).
        if opt['dict_build_first'] and 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get(
                    'model_file_transmitter') and opt.get(
                        'model_file_receiver'):
                opt['dict_file'] = opt['model_file_transmitter'] + '_' + opt[
                    'model_file_receiver'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=False)

        # Create model and assign it to the specified task
        print("[ create meta-agent ... ]")
        self.agent = create_agent(opt)
        print("[ create agent A ... ]")
        shared = self.agent.share()
        self.agent_a = create_agent_from_shared(shared)
        self.agent_a.set_id(suffix=' A')
        print("[ create agent B ... ]")
        self.agent_b = create_agent_from_shared(shared)
        # self.agent_b = create_agent(opt)
        self.agent_b.set_id(suffix=' B')
        # self.agent_a.copy(self.agent, 'transmitter')
        # self.agent_b.copy(self.agent, 'transmitter')
        self.world = create_selfplay_world(opt, [self.agent_a, self.agent_b])

        # TODO: even in batch mode, self-play is not run in parallel
        # self.world = BatchSelfPlayWorld(opt, self_play_world)

        self.train_time = Timer()
        self.train_dis_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys_episode = 0
        self.max_num_epochs = opt['num_epochs'] if opt['num_epochs'] > 0 else float('inf')
        self.max_train_time = opt['max_train_time'] if opt['max_train_time'] > 0 else float('inf')
        self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf')
        self.train_dis_every_n_secs = opt['train_display_every_n_secs'] if opt['train_display_every_n_secs'] > 0 else float('inf')
        self.val_every_n_secs = opt['validation_every_n_secs'] if opt['validation_every_n_secs'] > 0 else float('inf')
        self.save_every_n_secs = opt['save_every_n_secs'] if opt['save_every_n_secs'] > 0 else float('inf')
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.best_valid = None
        if opt.get('model_file_transmitter') and os.path.isfile(
                opt['model_file_transmitter'] + '.best_valid'):
            with open(opt['model_file_transmitter'] + '.best_valid', 'r') as f:
                # restore the best validation metric from a previous run
                self.best_valid = float(f.readline())
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt
        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)

    def validate(self):
        opt = self.opt
        valid_report, self.valid_world = run_eval(self.agent,
                                                  opt,
                                                  'valid',
                                                  opt['validation_max_exs'],
                                                  valid_world=self.valid_world)
        if opt['tensorboard_log'] is True:
            self.writer.add_metrics('valid', self.parleys_episode,
                                    valid_report)
        if opt.get('model_file_transmitter') and opt.get('save_after_valid'):
            print("[ saving transmitter checkpoint: " +
                  opt['model_file_transmitter'] + ".checkpoint ]")
            self.agent.save(component='transmitter')
        # if opt.get('model_file_receiver') and opt.get('save_after_valid'):
        #     print("[ saving receiver checkpoint: " + opt['model_file_receiver'] + ".checkpoint ]")
        #     self.agent.save(component='receiver')
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)
        if '/' in opt['validation_metric']:
            # if you are multitasking and want your validation metric to be
            # a metric specific to a subtask, specify your validation metric
            # as -vmt subtask/metric
            subtask = opt['validation_metric'].split('/')[0]
            validation_metric = opt['validation_metric'].split('/')[1]
            new_valid = valid_report['tasks'][subtask][validation_metric]
        else:
            new_valid = valid_report[opt['validation_metric']]
        if self.best_valid is None or self.valid_optim * new_valid > self.valid_optim * self.best_valid:
            print('[ new best {}: {}{} ]'.format(
                opt['validation_metric'], new_valid,
                ' (previous best was {})'.format(self.best_valid)
                if self.best_valid is not None else ''))
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file'):
                print("[ saving best valid model: " + opt['model_file'] + " ]")
                # the fine-tuned transmitter part is actually what we want for PSquare bot
                self.agent.save()
                print("[ saving best valid metric: " + opt['model_file'] +
                      ".best_valid ]")
                save_best_valid(opt['model_file'], self.best_valid)
                self.saved = True

            if opt['validation_metric'] == 'accuracy' and self.best_valid >= opt[
                    'validation_cutoff']:
                print('[ task solved! stopping. ]')
                return True
        else:
            self.impatience += 1
            print('[ did not beat best {}: {} impatience: {} ]'.format(
                opt['validation_metric'], round(self.best_valid, 4),
                self.impatience))
        self.validate_time.reset()
        if 0 < opt['validation_patience'] <= self.impatience:
            print('[ ran out of patience! stopping training. ]')
            return True
        return False

    def log(self):
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []
        # get report
        train_report = self.world.report()
        self.world.reset_metrics()

        # time elapsed
        logs.append('time:{}s'.format(math.floor(self.train_time.time())))
        logs.append('parleys:{}'.format(self.parleys_episode))

        if 'time_left' in train_report:
            logs.append('time_left:{}s'.format(
                math.floor(train_report.pop('time_left', ""))))
        if 'num_epochs' in train_report:
            logs.append('num_epochs:{}'.format(
                train_report.pop('num_epochs', '')))
        log = '[ {} ] {}'.format(' '.join(logs), train_report)
        print(log)
        self.log_time.reset()

        if opt['tensorboard_log'] is True:
            self.writer.add_metrics('train', self.parleys_episode,
                                    train_report)

    def train(self):
        # print('#### Validating at {} training episode '.format(self.parleys_episode))
        # self.validate()
        opt = self.opt
        world = self.world
        with world:
            while True:
                self.parleys_episode += 1
                if self.parleys_episode % 100 == 0:
                    print('#### Training episode {}'.format(
                        self.parleys_episode))

                if self.train_dis_time.time() > self.train_dis_every_n_secs:
                    is_display = True
                    # reset the display timer
                    self.train_dis_time.reset()
                else:
                    is_display = False

                world.parley_episode(is_training=True, is_display=is_display)

                if world.get_total_epochs() >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, self.train_time.time()))
                    break

                if self.train_time.time() > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(
                        self.train_time.time()))
                    break

                if self.log_time.time() > self.log_every_n_secs:
                    self.log()

                if self.validate_time.time() > self.val_every_n_secs:
                    print('#### Validating at training episode {}'.format(
                        self.parleys_episode))
                    stop_training = self.validate()
                    if stop_training:
                        break

                if self.save_time.time() > self.save_every_n_secs:
                    if opt.get('model_file_transmitter'):
                        print("[ saving transmitter checkpoint: " +
                              opt['model_file_transmitter'] + ".checkpoint ]")
                        self.agent.save(opt['model_file_transmitter'] +
                                        '.checkpoint',
                                        component='transmitter')
                    if opt.get('model_file_receiver'):
                        print("[ saving receiver checkpoint: " +
                              opt['model_file_receiver'] + ".checkpoint ]")
                        self.agent.save(opt['model_file_receiver'] +
                                        '.checkpoint',
                                        component='receiver')
                    self.save_time.reset()

        if not self.saved:
            # save agent
            # self.agent.save(component='transmitter')
            self.agent.save()
            # self.agent.save(component='receiver') # TODO: API for save all components
        elif opt.get('model_file_transmitter') and opt.get(
                'model_file_receiver'
        ):  # TODO: check if both components are necessary
            # reload best validation model
            self.agent = create_agent(opt)

        v_report, v_world = run_eval(self.agent, opt, 'valid', write_log=True)
        t_report, t_world = run_eval(self.agent, opt, 'test', write_log=True)
        v_world.shutdown()
        t_world.shutdown()
        return v_report, t_report
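# Illustrative aside: the '-vmt subtask/metric' lookup used in validate() above
# resolves against the nested per-task report. The report contents are made up:
valid_report = {'accuracy': 0.55, 'tasks': {'squad': {'f1': 0.71}}}
validation_metric = 'squad/f1'                  # e.g. passed via -vmt squad/f1
if '/' in validation_metric:
    subtask, metric = validation_metric.split('/', 1)
    new_valid = valid_report['tasks'][subtask][metric]
else:
    new_valid = valid_report[validation_metric]
print(new_valid)                                # -> 0.71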
Ejemplo n.º 36
0
class TrainLoop():
    def __init__(self, parser):
        opt = parser.parse_args()
        # Possibly build a dictionary (not all models do this).
        if opt['dict_build_first'] and 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt)
        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.world = create_task(opt, self.agent)
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys = 0
        self.max_num_epochs = opt['num_epochs'] if opt['num_epochs'] > 0 else float('inf')
        self.max_train_time = opt['max_train_time'] if opt['max_train_time'] > 0 else float('inf')
        self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf')
        self.val_every_n_secs = opt['validation_every_n_secs'] if opt['validation_every_n_secs'] > 0 else float('inf')
        self.save_every_n_secs = opt['save_every_n_secs'] if opt['save_every_n_secs'] > 0 else float('inf')
        self.best_valid = 0
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt

    def validate(self):
        opt = self.opt
        valid_report, self.valid_world = run_eval(
            self.agent, opt, 'valid', opt['validation_max_exs'],
            valid_world=self.valid_world)
        if valid_report[opt['validation_metric']] > self.best_valid:
            self.best_valid = valid_report[opt['validation_metric']]
            self.impatience = 0
            print('[ new best {}: {} ]'.format(
                opt['validation_metric'], self.best_valid))
            self.world.save_agents()
            self.saved = True
            if opt['validation_metric'] == 'accuracy' and self.best_valid >= opt['validation_cutoff']:
                print('[ task solved! stopping. ]')
                return True
        else:
            self.impatience += 1
            print('[ did not beat best {}: {} impatience: {} ]'.format(
                    opt['validation_metric'], round(self.best_valid, 4),
                    self.impatience))
        self.validate_time.reset()
        if opt['validation_patience'] > 0 and self.impatience >= opt['validation_patience']:
            print('[ ran out of patience! stopping training. ]')
            return True
        return False

    def log(self):
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []
        # get report
        if hasattr(self.agent, 'report'):
            train_report = self.agent.report()
            self.agent.reset_metrics()
        else:
            train_report = self.world.report(compute_time=True)
            self.world.reset_metrics()

        # time elapsed
        logs.append('time:{}s'.format(math.floor(self.train_time.time())))
        logs.append('parleys:{}'.format(self.parleys))

        if 'time_left' in train_report:
            logs.append('time_left:{}s'.format(
                         math.floor(train_report.pop('time_left', ""))))
        if 'num_epochs' in train_report:
            logs.append('num_epochs:{}'.format(
                         train_report.pop('num_epochs', '')))
        log = '[ {} ] {}'.format(' '.join(logs), train_report)
        print(log)
        self.log_time.reset()

    def train(self):
        opt = self.opt
        world = self.world
        with world:
            while True:
                world.parley()
                self.parleys += 1

                if world.get_total_epochs() >= self.max_num_epochs:
                    self.log()
                    print('[ num_epochs completed:{} time elapsed:{}s ]'.format(
                        self.max_num_epochs, self.train_time.time()))
                    break
                if self.train_time.time() > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(self.train_time.time()))
                    break
                if self.log_time.time() > self.log_every_n_secs:
                    self.log()
                if self.validate_time.time() > self.val_every_n_secs:
                    stop_training = self.validate()
                    if stop_training:
                        break
                if self.save_time.time() > self.save_every_n_secs:
                    print("[ saving model: " + opt['model_file'] + " ]")
                    world.save_agents()
                    self.save_time.reset()

        if not self.saved:
            # save agent
            world.save_agents()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        _rep, wrld = run_eval(self.agent, opt, 'valid', write_log=True)
        wrld.shutdown()  # may need to shut down threads, remote connections
        _rep, wrld = run_eval(self.agent, opt, 'test', write_log=True)
        wrld.shutdown()  # may need to shut down threads, remote connections
Ejemplo n.º 37
0
class World(object):
    """Empty parent providing null definitions of API functions for Worlds.
    All children can override these to provide more detailed functionality."""

    def __init__(self, opt, agents=None, shared=None):
        self.id = opt['task']
        self.opt = copy.deepcopy(opt)
        if shared:
            # Create agents based on shared data.
            self.agents = create_agents_from_shared(shared['agents'])
        else:
            # Add passed in agents to world directly.
            self.agents = agents
        self.max_exs = None
        self.total_exs = 0
        self.total_epochs = 0
        self.total_parleys = 0
        self.time = Timer()

    def parley(self):
        """The main method, that does one step of actions for the agents
        in the world. This is empty in the base class.
        """
        pass

    def getID(self):
        """Return the name of the world, typically the task the world encodes."""
        return self.id

    def display(self):
        """Returns a string describing the current state of the world.

        Useful for monitoring and debugging.
        By default, display the messages between the agents."""
        if not hasattr(self, 'acts'):
            return ''
        return display_messages(self.acts)

    def episode_done(self):
        """Whether the episode is done or not."""
        return False

    def epoch_done(self):
        """Whether the epoch is done or not.

        Not all worlds have the notion of an epoch, but this is useful
        for fixed training, validation or test sets.
        """
        return False

    def share(self):
        shared_data = {}
        shared_data['world_class'] = type(self)
        shared_data['opt'] = self.opt
        shared_data['agents'] = self._share_agents()
        return shared_data

    def _share_agents(self):
        """Create shared data for agents so other classes can create the same
        agents without duplicating the data (i.e. sharing parameters).
        """
        if not hasattr(self, 'agents'):
            return None
        shared_agents = [a.share() for a in self.agents]
        return shared_agents

    def get_agents(self):
        """Return the list of agents."""
        return self.agents

    def get_acts(self):
        """Return the last act of each agent."""
        return self.acts

    def get_time(self):
        """Return total training time"""
        return self.time.time()

    def get_total_exs(self):
        """Return total amount of examples seen by world."""
        return self.total_exs

    def get_total_epochs(self):
        """Return total amount of epochs on which the world has trained."""
        return self.total_epochs

    def __enter__(self):
        """Empty enter provided for use with ``with`` statement.

        e.g:

        .. code-block:: python

            with World() as world:
                for _ in range(10):
                    world.parley()
        """
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """After ``with`` statement, call shutdown."""
        silent_exit = isinstance(exc_value, KeyboardInterrupt)
        self.shutdown()
        return silent_exit

    def num_examples(self):
        return 0

    def num_episodes(self):
        return 0

    def reset(self):
        for a in self.agents:
            a.reset()
        self.max_exs = None
        self.total_exs = 0
        self.total_epochs = 0
        self.total_parleys = 0
        self.time.reset()

    def reset_metrics(self):
        for a in self.agents:
            a.reset_metrics()

    def save_agents(self):
        """Saves all of the agents in the world by calling their respective
        save() methods.
        """
        for a in self.agents:
            a.save()

    def shutdown(self):
        """Perform any cleanup, if appropriate."""
        pass

    def update_counters(self):
        """Update how many epochs have completed"""
        self.total_parleys += 1
        if self.max_exs is None:
            if 'num_epochs' in self.opt and self.opt['num_epochs'] > 0:
                self.max_exs = (self.num_examples() * self.opt['num_epochs']
                                if self.num_examples() else -1)
            else:
                self.max_exs = -1
        # when we know the size of the data
        if self.max_exs > 0:
            self.total_epochs = self.total_parleys * self.opt.get('batchsize', 1) / self.num_examples()
        # when we do not know the size of the data
        else:
            if self.epoch_done():
                self.total_epochs += 1