Example #1
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True,
                              'compute statistics from model predictions')
    DictionaryAgent.add_cmdline_args(parser)
    # Get command line arguments
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    parser.add_argument('-ed',
                        '--external-dict',
                        type=str,
                        default=None,
                        help='External dictionary for stat computation')
    parser.add_argument('-fb',
                        '--freq-bins',
                        type=str,
                        default='0,100,1000,10000',
                        help='Bins boundaries for rare words stat')
    parser.add_argument('-dup',
                        '--dump-predictions-path',
                        type=str,
                        default=None,
                        help='Dump predictions into file')
    parser.add_argument('-cun',
                        '--compute-unique',
                        type=bool,
                        default=True,
                        help='Compute %% of unique responses from the model')
    parser.set_defaults(datatype='valid', model='repeat_label')
    TensorboardLogger.add_cmdline_args(parser)
    return parser
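
A minimal usage sketch (not part of the scraped example; it assumes the same ParlaiParser/TensorboardLogger imports the snippet above relies on): the returned parser is parsed into an opt dict, which then gates construction of the logger.

def main():
    # Hypothetical driver, for illustration only.
    parser = setup_args()
    opt = parser.parse_args()
    writer = None
    if opt.get('tensorboard_log'):
        # construct the logger only when --tensorboard-log is enabled
        writer = TensorboardLogger(opt)
    return opt, writer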
Example #2
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    parser.add_pytorch_datateacher_args()
    # Get command line arguments
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    parser.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='If multitasking, average metrics over the '
        'number of examples. If false, averages over the '
        'number of tasks.',
    )
    parser.add_argument(
        '--metrics',
        type=str,
        default='all',
        help='list of metrics to show/compute, e.g. '
        'ppl, f1, accuracy, hits@1. '
        'If `all` is specified [default] all are shown.',
    )
    TensorboardLogger.add_cmdline_args(parser)
    parser.set_defaults(datatype='valid')
    return parser
Example #3
def setup_args(parser=None) -> ParlaiParser:
    """
    Build the ParlAI parser, adding command line args if necessary.
    :param ParlaiParser parser:
        Preexisting parser to append options to. Will be created if needed.
    :returns:
        the ParlaiParser with CLI options added.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Train a model')
    train = parser.add_argument_group('FTML Training Loop Arguments')

#     train.add_argument('--n_grad', type='bool', default=False, hidden=True)
#     train.add_argument('-ngrad', '--num-grad', type=int, default=2)
#     train.add_argument('-nadd', '--num-added-data', type=int, default=50)
    train.add_argument('-mbchsztr', '--meta-batchsize_tr', type=int, default=10)
    train.add_argument('-mbchszval', '--meta-batchsize_val', type=int, default=10)
    train.add_argument('-nmmetastep', '--num-meta-steps', type=int, default=50)
    train.add_argument('-nepb', '--num-episode-batch', type=int, default=3)
    train.add_argument('-ebs', '--eval-batch-size', type=int, default=32)
    train.add_argument('-mnts', '--max-num-turns', type=int, default=15)
    
    TensorboardLogger.add_cmdline_args(parser)

    parser = setup_train_args(parser)
    return parser
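
As the docstring above notes, setup_args can either extend a preexisting parser or build a new one. A short illustrative sketch of both call patterns (assumes the same imports as the example):

# Append the FTML options to an existing parser...
base = ParlaiParser(True, True, 'Train a model')
parser = setup_args(base)

# ...or let setup_args create a fresh parser internally.
parser = setup_args()
opt = parser.parse_args()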
Example #4
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    parser.add_pytorch_datateacher_args()
    # Get command line arguments

    # Other command line arguments
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    parser.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='If multitasking, average metrics over the '
        'number of examples. If false, averages over the '
        'number of tasks.',
    )
    parser.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='list of metrics to show/compute, e.g. all, default, '
        'or give a list split by , like '
        'ppl,f1,accuracy,hits@1,rouge,bleu. '
        'The rouge metrics will be computed as rouge-1, rouge-2 and rouge-l',
    )
    TensorboardLogger.add_cmdline_args(parser)
    parser.set_defaults(datatype='valid')
    return parser
Example #5
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True)
    train = parser.add_argument_group('Training Loop Arguments')
    train.add_argument('-dbf',
                       '--dict-build-first',
                       type='bool',
                       default=True,
                       help='build dictionary first before training agent')
    train.add_argument('-eps', '--num-epochs', type=float, default=-1)
    train.add_argument('-ttim', '--max-train-time', type=float, default=-1)
    train.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    train.add_argument('-tden',
                       '--train-display-every-n-secs',
                       type=float,
                       default=-1,
                       help='Display training information every n seconds')
    train.add_argument('-vtim',
                       '--validation-every-n-secs',
                       type=float,
                       default=-1,
                       help='Validate every n seconds. Whenever the best '
                       'validation metric is found, saves the model to '
                       'the model_file path if set.')
    train.add_argument('-stim',
                       '--save-every-n-secs',
                       type=float,
                       default=-1,
                       help='Saves the model to model_file.checkpoint after '
                       'every n seconds (default -1, never).')

    TensorboardLogger.add_cmdline_args(parser)
    parser = setup_dict_args(parser)
    return parser
Example #6
 def __init__(self, opt):
     if isinstance(opt, ParlaiParser):
         print(
             '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]'
         )
         opt = opt.parse_args()
     # Possibly load from checkpoint
     if opt['load_from_checkpoint'] and opt.get(
             'model_file') and os.path.isfile(opt['model_file'] +
                                              '.checkpoint'):
         opt['init_model'] = opt['model_file'] + '.checkpoint'
     # Possibly build a dictionary (not all models do this).
     if opt['dict_build_first'] and 'dict_file' in opt:
         # If data built via pytorch data teacher, we need to load prebuilt dict
         if opt.get('pytorch_teacher_task'):
             opt['dict_file'] = get_pyt_dict_file(opt)
         elif opt['dict_file'] is None and opt.get('model_file'):
             opt['dict_file'] = opt['model_file'] + '.dict'
         print("[ building dictionary first... ]")
         build_dict(opt, skip_if_built=True)
     # Create model and assign it to the specified task
     self.agent = create_agent(opt)
     self.world = create_task(opt, self.agent)
     self.train_time = Timer()
     self.validate_time = Timer()
     self.log_time = Timer()
     self.save_time = Timer()
     print('[ training... ]')
     self.parleys = 0
     self.max_num_epochs = opt[
         'num_epochs'] if opt['num_epochs'] > 0 else float('inf')
     self.max_train_time = opt['max_train_time'] if opt['max_train_time'] > 0 \
         else float('inf')
     self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 \
         else float('inf')
     self.val_every_n_secs = \
         opt['validation_every_n_secs'] if opt['validation_every_n_secs'] > 0 \
         else float('inf')
     self.save_every_n_secs = opt['save_every_n_secs'] if opt['save_every_n_secs'] \
         > 0 else float('inf')
     self.val_every_n_epochs = \
         opt['validation_every_n_epochs'] if opt['validation_every_n_epochs'] > 0 \
         else float('inf')
     self.last_valid_epoch = 0
     self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
     self.best_valid = None
     if opt.get('model_file') and os.path.isfile(opt['model_file'] +
                                                 '.best_valid'):
         with open(opt['model_file'] + ".best_valid", 'r') as f:
             x = f.readline()
             self.best_valid = float(x)
             f.close()
     self.impatience = 0
     self.saved = False
     self.valid_world = None
     self.opt = opt
     if opt['tensorboard_log'] is True:
         self.writer = TensorboardLogger(opt)
Example #7
    def __init__(self, opt):
        if isinstance(opt, ParlaiParser):
            opt = opt.parse_args()
        # Possibly build a dictionary (not all models do this).
        if opt['dict_build_first'] and 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file_transmitter') and opt.get('model_file_receiver'):
                opt['dict_file'] = opt['model_file_transmitter'] + '_' + opt['model_file_receiver']  + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=False)

        # Create model and assign it to the specified task
        print("[ create meta-agent ... ]")
        self.agent = create_agent(opt)
        print("[ create agent A ... ]")
        shared = self.agent.share()
        self.agent_a = create_agent_from_shared(shared)
        self.agent_a.set_id(suffix=' A')
        print("[ create agent B ... ]")
        self.agent_b = create_agent_from_shared(shared)
        # self.agent_b = create_agent(opt)
        self.agent_b.set_id(suffix=' B')
        # self.agent_a.copy(self.agent, 'transmitter')
        # self.agent_b.copy(self.agent, 'transmitter')
        self.world = create_selfplay_world(opt, [self.agent_a, self.agent_b])

        # TODO: if batch, it is also not parallel
        # self.world = BatchSelfPlayWorld(opt, self_play_world)

        self.train_time = Timer()
        self.train_dis_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys_episode = 0
        self.max_num_epochs = opt['num_epochs'] if opt['num_epochs'] > 0 else float('inf')
        self.max_train_time = opt['max_train_time'] if opt['max_train_time'] > 0 else float('inf')
        self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf')
        self.train_dis_every_n_secs = opt['train_display_every_n_secs'] if opt['train_display_every_n_secs'] > 0 else float('inf')
        self.val_every_n_secs = opt['validation_every_n_secs'] if opt['validation_every_n_secs'] > 0 else float('inf')
        self.save_every_n_secs = opt['save_every_n_secs'] if opt['save_every_n_secs'] > 0 else float('inf')
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.best_valid = None
        if opt.get('model_file_transmitter') and os.path.isfile(opt['model_file_transmitter'] + '.best_valid'):
            with open(opt['model_file_transmitter'] + ".best_valid", 'r') as f:
                x = f.readline()
                self.best_valid = float(x)
                f.close()
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt
        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)
Example #8
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True)
    train = parser.add_argument_group('Training Loop Arguments')
    train.add_argument('-et', '--evaltask',
                       help=('task to use for valid/test (defaults to the '
                             'one used for training if not set)'))
    train.add_argument('--display-examples', type='bool', default=False)
    train.add_argument('-eps', '--num-epochs', type=float, default=-1)
    train.add_argument('-ttim', '--max-train-time',
                       type=float, default=-1)
    train.add_argument('-ltim', '--log-every-n-secs',
                       type=float, default=2)
    train.add_argument('-vtim', '--validation-every-n-secs',
                       type=float, default=-1,
                       help='Validate every n seconds. Whenever the best '
                            'validation metric is found, saves the model to '
                            'the model_file path if set.')
    train.add_argument('-stim', '--save-every-n-secs',
                       type=float, default=-1,
                       help='Saves the model to model_file.checkpoint after '
                            'every n seconds (default -1, never).')
    train.add_argument('-sval', '--save-after-valid', type='bool',
                       default=False,
                       help='Saves the model to model_file.checkpoint after '
                            'every validation (default False).')
    train.add_argument('-vme', '--validation-max-exs',
                       type=int, default=-1,
                       help='max examples to use during validation (default '
                            '-1 uses all)')
    train.add_argument('-vp', '--validation-patience',
                       type=int, default=10,
                       help=('number of iterations of validation where result'
                             ' does not improve before we stop training'))
    train.add_argument('-vmt', '--validation-metric', default='accuracy',
                       help='key into report table for selecting best '
                            'validation')
    train.add_argument('-vmm', '--validation-metric-mode', default='max',
                       type=str, choices=['max', 'min'],
                       help='how to optimize validation metric (max or min)')
    train.add_argument('-vcut', '--validation-cutoff',
                       type=float, default=1.0,
                       help='value at which training will stop if exceeded by '
                            'training metric')
    train.add_argument('-dbf', '--dict-build-first',
                       type='bool', default=True,
                       help='build dictionary first before training agent')
    train.add_argument('-lfc', '--load-from-checkpoint',
                       type='bool', default=False,
                       help='load model from checkpoint if available')
    TensorboardLogger.add_cmdline_args(parser)
    parser = setup_dict_args(parser)
    return parser
Example #9
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    parser.add_pytorch_datateacher_args()
    # Get command line arguments

    # Probing command line arguments
    parser.add_argument(
        '--probe',
        type=str,
        default=None,
        choices=['word_embeddings', 'encoder_state', 'combined'],
        help="Specify the type of representations to generate for probing. "
        "See 'Probing Neural Dialog for Conversational Understanding' for more details."
    )

    parser.add_argument(
        '-t',
        '--tasks',
        type=str,
        nargs='+',
        required=True,
        help='Usage: -t trecquestion or -t trecquestion wnli multiwoz'
        '\nOnly compatible with names in probing/tasks')
    # Other command line arguments
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    parser.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='If multitasking, average metrics over the '
        'number of examples. If false, averages over the '
        'number of tasks.',
    )
    parser.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='list of metrics to show/compute, e.g. all, default, '
        'or give a list split by , like '
        'ppl,f1,accuracy,hits@1,rouge,bleu. '
        'The rouge metrics will be computed as rouge-1, rouge-2 and rouge-l',
    )
    TensorboardLogger.add_cmdline_args(parser)
    parser.set_defaults(datatype='valid')
    parser.set_defaults(batchsize=256)
    return parser
Example #10
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    parser.add_pytorch_datateacher_args()
    # Get command line arguments
    parser.add_argument('-rp',
                        '--report',
                        type=str,
                        default="/tmp/eval_model.json")
    parser.add_argument(
        '-rf',
        '--report-filename',
        type=str,
        default='',
        help='Saves a json file of the evaluation report either as an '
        'extension to the model-file (if begins with a ".") or a whole '
        'file path. Set to the empty string to not save at all.',
    )
    parser.add_argument(
        '--save-world-logs',
        type='bool',
        default=False,
        help='Saves a jsonl file containing all of the task examples and '
        'model replies. Must also specify --report-filename.',
    )
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    parser.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='If multitasking, average metrics over the '
        'number of examples. If false, averages over the '
        'number of tasks.',
    )
    parser.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='list of metrics to show/compute, e.g. all, default, '
        'or give a list split by , like '
        'ppl,f1,accuracy,hits@1,rouge,bleu. '
        'The rouge metrics will be computed as rouge-1, rouge-2 and rouge-l',
    )
    WorldLogger.add_cmdline_args(parser)
    TensorboardLogger.add_cmdline_args(parser)
    parser.set_defaults(datatype='valid')
    return parser
Example #11
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    # Get command line arguments
    parser.add_argument(
        '-rf',
        '--report-filename',
        type=str,
        default='',
        help='Saves a json file of the evaluation report either as an '
        'extension to the model-file (if begins with a ".") or a whole '
        'file path. Set to the empty string to not save at all.',
    )
    parser.add_argument(
        '--world-logs',
        type=str,
        default='',
        help='Saves a jsonl file of the world logs. '
        'Set to the empty string to not save at all.',
    )
    parser.add_argument(
        '--save-format',
        type=str,
        default='conversations',
        choices=['conversations', 'parlai'],
    )
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=10)
    parser.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='list of metrics to show/compute, e.g. all, default, '
        'or give a list split by , like '
        'ppl,f1,accuracy,hits@1,rouge,bleu. '
        'The rouge metrics will be computed as rouge-1, rouge-2 and rouge-l',
    )
    parser.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='Report micro-averaged metrics instead of macro averaged metrics.',
        recommended=False,
    )
    WorldLogger.add_cmdline_args(parser, partial_opt=None)
    TensorboardLogger.add_cmdline_args(parser, partial_opt=None)
    parser.set_params(datatype='valid')
    return parser
Example #12
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    parser.add_pytorch_datateacher_args()
    # Get command line arguments
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    parser.add_argument('--metrics', type=str, default="all",
                        help="list of metrics to show/compute, e.g. "
                             "ppl,f1,accuracy,hits@1."
                             "If 'all' is specified [default] all are shown.")
    TensorboardLogger.add_cmdline_args(parser)
    parser.set_defaults(datatype='valid')
    return parser
Example #13
 def __init__(self, optAgent, shared=None):
     init_model, is_finetune = self._get_init_model(optAgent, shared)
     super().__init__(optAgent, shared)
     if optAgent.get('numthreads', 1) > 1:
         torch.set_num_threads(1)
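     # NOTE: 'opt' and 'args' below are module-level objects from the source
     # project (not the 'optAgent' argument), so this snippet is not standalone.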
     optAgent['gradient_clip'] = opt.maxgrad
     self.criterion = opt.criterion
     self.loss = opt.loss
     self.drawVars = opt.drawVars
     opt.edim = optAgent['embeddingsize']
     opt.vocabsize = len(self.dict)
     opt.__dict__.update(optAgent)
     opt.agent = self
     opt.fp16 = self.fp16
     torch.manual_seed(args.rank)
     np.random.seed(args.rank)
     self.writeVars = 0
     self.vars = {}
     if optAgent['tensorboard_log']:
         self.writeVars, *_ = getWriter(writer=TensorboardLogger(optAgent))
     if self.fp16:
         try:
             from apex import amp
         except ImportError:
             raise ImportError(
                 'No fp16 support without apex. Please install it from '
                 'https://github.com/NVIDIA/apex')
         self.getParameters = lambda: amp.master_params(self.optimizer)
         self.amp = amp
     else:
         self.getParameters = lambda: self.model.parameters()
     if not shared:
         model = Model(opt)
         self.model = model
         if init_model:
             print('Loading existing model parameters from ' + init_model)
             states = self.load(init_model)
         else:
             states = {}
             initParameters(opt, self.model)
         if self.use_cuda:
             self.model.cuda()
         self.model.train()
         if optAgent.get('numthreads', 1) > 1:
             self.model.share_memory()
         paramOptions = getParamOptions(opt, self.model)
         self.init_optim(paramOptions, states.get('optimizer'),
                         states.get('saved_optim_type', None))
         self.build_lr_scheduler(states, hard_reset=is_finetune)
         if is_distributed():
             self.model = nn.parallel.DistributedDataParallel(
                 self.model,
                 device_ids=[self.opt['gpu']],
                 broadcast_buffers=False)
         self.reset()
     else:
         self.model = shared['model']
         self.dict = shared['dict']
         if 'optimizer' in shared:
             self.optimizer = shared['optimizer']
Example #14
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True,
                              'compute statistics from model predictions')
    DictionaryAgent.add_cmdline_args(parser)

    # These defaults can be overridden by both .opt file and user's command line flags
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    parser.add_argument(
        '-ed',
        '--external-dict',
        type=str,
        default=None,
        help='External dictionary for stat computation',
    )
    parser.add_argument(
        '-fb',
        '--freq-bins',
        type=str,
        default='0,100,1000,10000',
        help='Bins boundaries for rare words stat',
    )
    parser.add_argument(
        '-gr',
        '--gold-response',
        type=bool,
        default=False,
        help='Compute stats for gold response',
    )

    # These settings override .opt file but not user's command line flags
    parser.set_params(
        datatype='valid',
        task='projects.controllable_dialogue.tasks.agents',
        model=
        'projects.controllable_dialogue.controllable_seq2seq.controllable_seq2seq:ControllableSeq2seqAgent',  # noqa: E501
        batchsize=64,
        beam_size=20,
        beam_min_n_best=10,
        use_reply='model',
    )
    TensorboardLogger.add_cmdline_args(parser)
    return parser
Example #15
    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)

        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)

        # default one does not average
        self.rank_loss = torch.nn.CrossEntropyLoss(reduce=True,
                                                   size_average=True)
        torch.autograd.set_detect_anomaly(True)
        torch.manual_seed(123)
Example #16
def setup_args(parser=None) -> ParlaiParser:
    """
    Build the ParlAI parser, adding command line args if necessary.
    :param ParlaiParser parser:
        Preexisting parser to append options to. Will be created if needed.
    :returns:
        the ParlaiParser with CLI options added.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Train a model')
    train = parser.add_argument_group('FTML Training Loop Arguments')

    #     train.add_argument('--n_grad', type='bool', default=False, hidden=True)
    train.add_argument('-nomt', '--no-multi-task', type='bool', default=False)
    train.add_argument('-nepb', '--num-episode-batch', type=int, default=3)
    train.add_argument('-ebs', '--eval-batch-size', type=int, default=32)

    TensorboardLogger.add_cmdline_args(parser)

    parser = setup_train_args(parser)
    return parser
Example #17
    def __init__(self, opt, shared=None):

        super().__init__(opt, shared)

        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)

        self.dictionnary_size = 177
        self.embedding_dim = 100
        self.batch_size = opt["batchsize"]

        self.criterion = nn.CrossEntropyLoss()

        def weight_init(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)

        self.recurrent_entity_network = RecurrentEntityNetwork(
            self.dictionnary_size, self.embedding_dim, sequence_length=7)
        self.recurrent_entity_network.apply(weight_init)
        self.optimizer = optim.Adam(self.recurrent_entity_network.parameters())
        #self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 25, 0.5)
        self.batch_iter = 0
Example #18
class TrainLoop:
    """
    TrainLoop contains the core training loop logic.
    """

    def __init__(self, opt):
        # if python is called from a non-interactive shell, like a bash script,
        # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are
        # not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if (
            opt['load_from_checkpoint']
            and opt.get('model_file')
            and PathManager.exists(opt['model_file'] + '.checkpoint')
        ):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # Possibly build a dictionary (not all models do this).
        if not (opt.get('dict_file') or opt.get('model_file')):
            raise RuntimeError(
                'WARNING: For train_model, please specify either a '
                'model_file or dict_file.'
            )
        if 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            logging.info("building dictionary first...")
            build_dict(opt, skip_if_built=True)

        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.agent.opt.log()
        self.world = create_task(opt, self.agent)
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()

        self.parleys = 0
        self._train_steps = 0
        self._last_log_steps = 0
        self.update_freq = opt.get('update_freq', 1)

        self.max_num_epochs = _num_else_inf(opt, 'num_epochs', distributed_warn=True)
        self.max_train_time = _num_else_inf(
            opt, 'max_train_time', distributed_warn=True
        )
        self.max_train_steps = _num_else_inf(opt, 'max_train_steps')
        self.log_every_n_secs = _num_else_inf(
            opt, 'log_every_n_secs', distributed_warn=True
        )
        self.log_every_n_steps = _num_else_inf(opt, 'log_every_n_steps')
        self.val_every_n_secs = _num_else_inf(
            opt, 'validation_every_n_secs', distributed_warn=True
        )
        self.val_every_n_epochs = _num_else_inf(
            opt, 'validation_every_n_epochs', distributed_warn=True
        )
        self.val_every_n_steps = _num_else_inf(opt, 'validation_every_n_steps')
        self.save_every_n_secs = _num_else_inf(
            opt, 'save_every_n_secs', distributed_warn=True
        )

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'}:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self._last_valid_steps = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.train_reports = []
        self.valid_reports = []
        self.final_valid_report = {}
        self.final_test_report = {}
        self.final_extra_valid_report = {}
        self.best_valid = None

        self.impatience = 0
        self.saved = False
        self.valid_worlds = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if opt.get('model_file') and PathManager.exists(
            opt['model_file'] + trainstats_suffix
        ):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with PathManager.open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self.parleys = obj.get('parleys', 0)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self._train_steps = obj.get('train_steps', 0)
                self.impatience = obj.get('impatience', 0)
                self.valid_reports = obj.get('valid_reports', [])
                if self.valid_reports:
                    self.last_valid_epoch = self.valid_reports[-1].get(
                        'total_epochs', 0.0
                    )
                self.train_reports = obj.get('train_reports', [])
                if 'best_valid' in obj:
                    self.best_valid = obj['best_valid']
                else:
                    # old method
                    if opt.get('model_file') and PathManager.exists(
                        opt['model_file'] + '.best_valid'
                    ):
                        with PathManager.open(
                            opt['model_file'] + ".best_valid", 'r'
                        ) as f:
                            x = f.readline()
                            self.best_valid = float(x)
                            f.close()

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger = TensorboardLogger(opt)
        if opt['wandb_log'] and is_primary_worker():
            model = self.agent.model if hasattr(self.agent, 'model') else None
            self.wb_logger = WandbLogger(opt, model)

    def save_model(self, suffix=None):
        """
        Save the model to disk, possibly with a suffix.
        """
        if not self.opt.get('model_file'):
            # nothing to save to, just exit
            return

        fn = self.opt['model_file']
        if suffix:
            fn += suffix

        if not is_primary_worker():
            # never do IO as a non-primary worker
            if hasattr(self.agent, 'save_nonprimary'):
                self.agent.save_nonprimary(fn)
            return

        while True:
            # don't ever let a ctrl-c interrupt saving
            try:
                self.agent.save(fn)
                self._save_train_stats(suffix)
                break
            except KeyboardInterrupt:
                pass

    def _save_train_stats(self, suffix=None):
        if not is_primary_worker():
            # never do IO as a non-primary worker
            return
        fn = self.opt.get('model_file', None)
        if not fn:
            return
        if suffix:
            fn += suffix
        fn += '.trainstats'
        with PathManager.open(fn, 'w') as f:
            json.dump(
                {
                    'parleys': self.parleys,
                    'train_time': self.train_time.time(),
                    'train_steps': self._train_steps,
                    'total_epochs': self._total_epochs,
                    'train_reports': self.train_reports,
                    'valid_reports': self.valid_reports,
                    'best_valid': self.best_valid,
                    'impatience': self.impatience,
                    'final_valid_report': dict_report(self.final_valid_report),
                    'final_test_report': dict_report(self.final_test_report),
                    'final_extra_valid_report': dict_report(
                        self.final_extra_valid_report
                    ),
                },
                f,
                indent=4,
            )

    def validate(self):
        """
        Perform a validation run, checking whether we should stop training.

        :return: boolean indicating whether training should stop
        :rtype: bool
        """
        opt = self.opt

        if self.valid_worlds is None:
            # we need to load the world now
            self.valid_worlds = load_eval_worlds(self.agent, opt, 'valid')

        # run evaluation on valid set
        valid_report = self._run_eval(
            self.valid_worlds, opt, 'valid', opt['validation_max_exs']
        )
        v = dict_report(valid_report)
        v['train_time'] = self.train_time.time()
        v['parleys'] = self.parleys
        v['train_steps'] = self._train_steps
        v['total_exs'] = self._total_exs
        v['total_epochs'] = self._total_epochs
        self.valid_reports.append(v)
        # logging
        if opt['tensorboard_log'] and is_primary_worker():
            valid_report['total_exs'] = self._total_exs
            self.tb_logger.log_metrics('valid', self.parleys, valid_report)
            # flush on a validation
            self.tb_logger.flush()
        if opt['wandb_log'] and is_primary_worker():
            valid_report['total_exs'] = self._total_exs
            self.wb_logger.log_metrics('valid', self.parleys, valid_report)

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # check which metric to look at
        new_valid = valid_report[opt['validation_metric']]

        if isinstance(new_valid, Metric):
            new_valid = new_valid.value()

        # check if this is the best validation so far
        if (
            self.best_valid is None
            or self.valid_optim * new_valid > self.valid_optim * self.best_valid
        ):
            logging.success(
                'new best {}: {:.4g}{}'.format(
                    opt['validation_metric'],
                    new_valid,
                    ' (previous best was {:.4g})'.format(self.best_valid)
                    if self.best_valid is not None
                    else '',
                )
            )
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file'):
                logging.info(f"saving best valid model: {opt['model_file']}")
                self.save_model()
                self.saved = True
            if (
                opt['validation_metric_mode'] == 'max'
                and self.best_valid >= opt['validation_cutoff']
            ) or (
                opt['validation_metric_mode'] == 'min'
                and self.best_valid <= opt['validation_cutoff']
            ):
                logging.info('task solved! stopping.')
                return True
        else:
            self.impatience += 1
            logging.report(
                'did not beat best {}: {} impatience: {}'.format(
                    opt['validation_metric'], round(self.best_valid, 4), self.impatience
                )
            )
        self.validate_time.reset()

        # saving
        if opt.get('model_file') and opt.get('save_after_valid'):
            logging.info(f"saving model checkpoint: {opt['model_file']}.checkpoint")
            self.save_model('.checkpoint')

        # check if we are out of patience
        if (
            opt['validation_patience'] > 0
            and self.impatience >= opt['validation_patience']
        ):
            logging.info('ran out of patience! stopping training.')
            return True
        return False

    def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask, task):

        # run evaluation on a single world
        valid_world.reset()

        world_logger = None
        task_opt = opt.copy()
        # set up world logger for the "test" fold
        if opt['world_logs'] and datatype == 'test':
            task_opt['world_logs'] = get_task_world_logs(
                task, opt['world_logs'], is_multitask
            )
            world_logger = WorldLogger(task_opt)

        cnt = 0
        max_cnt = max_exs if max_exs > 0 else float('inf')
        while not valid_world.epoch_done() and cnt < max_cnt:
            valid_world.parley()
            if world_logger is not None:
                world_logger.log(valid_world)
            if cnt == 0 and opt['display_examples']:
                print(valid_world.display() + '\n~~')
                print(valid_world.report())
            cnt = valid_world.report().get('exs') or 0

        if world_logger is not None:
            # dump world acts to file
            world_logger.reset()  # add final acts to logs
            if is_distributed():
                rank = get_rank()
                base_outfile, extension = os.path.splitext(task_opt['world_logs'])
                outfile = base_outfile + f'_{rank}' + extension
            else:
                outfile = task_opt['world_logs']
            world_logger.write(outfile, valid_world, file_format=opt['save_format'])

        valid_report = valid_world.report()
        if opt.get('validation_share_agent', False):
            valid_world.reset()  # make sure world doesn't remember valid data

        return valid_report

    def _run_eval(
        self,
        valid_worlds,
        opt,
        datatype,
        max_exs=-1,
        write_log=False,
        extra_log_suffix="",
    ):
        """
        Eval on validation/test data.

        :param valid_world:
            list of the pre-created validation worlds.
        :param opt:
            the options that specific the task, eval_task, etc
        :param datatype:
            the datatype to use, such as "valid" or "test"
        :param bool write_log:
            specifies to write metrics to file if the model_file is set
        :param int max_exs:
            limits the number of examples if max_exs > 0
        """

        logging.info(f'running eval: {datatype}')
        timer = Timer()
        reports = []

        max_exs_per_worker = max_exs / (len(valid_worlds) * num_workers())
        is_multitask = len(valid_worlds) > 1
        for index, v_world in enumerate(valid_worlds):
            if opt.get('evaltask'):
                task = opt['evaltask'].split(',')[index]
            else:
                task = opt['task'].split(',')[index]
            task_report = self._run_single_eval(
                opt, v_world, max_exs_per_worker, datatype, is_multitask, task
            )
            reports.append(task_report)

        tasks = [world.getID() for world in valid_worlds]
        named_reports = dict(zip(tasks, reports))
        report = aggregate_named_reports(
            named_reports, micro_average=self.opt.get('aggregate_micro', False)
        )
        # get the results from all workers
        report = self._sync_metrics(report)

        metrics = f'{datatype}:\n{nice_report(report)}\n'
        logging.info(f'eval completed in {timer.time():.2f}s')
        logging.report(metrics)

        # write to file
        if write_log and opt.get('model_file') and is_primary_worker():
            # Write out metrics
            with PathManager.open(
                opt['model_file'] + extra_log_suffix + '.' + datatype, 'a'
            ) as f:
                f.write(f'{metrics}\n')

        return report

    def _run_final_extra_eval(self, opt):
        final_valid_opt = copy.deepcopy(opt)
        final_valid_opt_raw = Opt.load_init(opt['final_extra_opt'])
        final_datatype = final_valid_opt_raw["datatype"]
        for k, v in final_valid_opt_raw.items():
            final_valid_opt[k] = v
        final_max_exs = (
            final_valid_opt['validation_max_exs']
            if final_valid_opt.get('short_final_eval')
            else -1
        )
        final_valid_world = load_eval_worlds(
            self.agent, final_valid_opt, final_datatype
        )
        final_valid_report = self._run_eval(
            final_valid_world,
            final_valid_opt,
            final_datatype,
            final_max_exs,
            write_log=True,
            extra_log_suffix="_extra",
        )
        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.log_final(final_datatype, final_valid_report)

        return final_valid_report

    def _sync_metrics(self, metrics):
        """
        Sync training metrics across workers.

        A handful of special cases are handled as exceptions, and the remaining metrics
        are simply averaged across workers.
        """
        if not is_distributed():
            # nothing special needed
            return metrics
        all_versions = all_gather_list(metrics)
        return aggregate_unnamed_reports(all_versions)

    def _compute_eta(
        self, epochs_completed: float, time_elapsed: float, steps_taken: int
    ):
        """
        Compute the estimated seconds remaining in training.

        :param float epochs_completed: number of epochs already completed.
        :param float time_elapsed: total time spent already, in seconds.
        :return: ETA in seconds, or None if not computable
        """
        # start off with no estimate
        eta = None

        # Determine time_left and num_epochs
        max_epochs = self.opt.get('num_epochs', 0)
        if max_epochs > 0 and epochs_completed > 0:
            epoch_progress = epochs_completed / max_epochs
            eta = (1 - epoch_progress) * time_elapsed / epoch_progress

        max_training_time = self.opt.get('max_training_time', -1)
        if max_training_time > 0:
            time_left = max_training_time - time_elapsed
            if eta is None or time_left < eta:
                eta = time_left

        max_train_steps = self.opt.get('max_train_steps', -1)
        if max_train_steps > 0 and steps_taken > 0:
            steps_progress = steps_taken / max_train_steps
            eta = (1 - steps_progress) * time_elapsed / steps_progress

        return eta

    def _get_time(self, world: World) -> Tuple[float, float, float]:
        """
        Return train, log, and validate timing.

        If relying on the time for validation/logging/max train time purposes,
        we sync and return primary worker's time.

        Otherwise, it's not super relevant what we do here.

        **SIDE EFFECT**: Update _total_epochs trained.

        :param world:
            current running world

        :return (train, log, valid):
            return time for each of train, log, and validation
        """
        if (
            self.max_train_time < float('inf')
            or self.log_every_n_secs < float('inf')
            or self.val_every_n_secs < float('inf')
            or self.val_every_n_epochs < float('inf')
            or self.max_num_epochs < float('inf')
        ):
            self._total_epochs = self._preempted_epochs + sum(
                all_gather_list(world.get_total_epochs())
            )
            train_time, log_time, validate_time, save_time = sync_object(
                (
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                    self.save_time.time(),
                )
            )
        else:
            train_time, log_time, validate_time, save_time = (
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time(),
                self.save_time.time(),
            )
            self._total_epochs = self._preempted_epochs + (
                num_workers() * world.get_total_epochs()
            )

        return train_time, log_time, validate_time, save_time

    def log(self):
        """
        Output a training log entry.
        """
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []
        # get report
        train_report = self.world.report()
        train_report = self._sync_metrics(train_report)
        self.world.reset_metrics()

        train_report_trainstats = dict_report(train_report)
        train_report_trainstats['total_epochs'] = self._total_epochs
        train_report_trainstats['total_exs'] = self._total_exs
        train_report_trainstats['parleys'] = self.parleys
        train_report_trainstats['train_steps'] = self._train_steps
        train_report_trainstats['train_time'] = self.train_time.time()
        self.train_reports.append(train_report_trainstats)

        # time elapsed
        logs.append(f'time:{self.train_time.time():.0f}s')
        logs.append(f'total_exs:{self._total_exs}')
        logs.append(f'total_steps:{self._train_steps}')

        if self._total_epochs >= 0:
            # only if it's unbounded
            logs.append(f'epochs:{self._total_epochs:.2f}')

        time_left = self._compute_eta(
            self._total_epochs, self.train_time.time(), self._train_steps
        )
        if time_left is not None:
            logs.append(f'time_left:{max(0,time_left):.0f}s')

        log = '{}\n{}\n'.format(' '.join(logs), nice_report(train_report))
        logging.info(log)
        self.log_time.reset()
        self._last_log_steps = 0

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger.log_metrics('train', self.parleys, train_report)
        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.log_metrics('train', self.parleys, train_report)

        return train_report

    def train_steps(self):
        """
        Core training loop.

        Yields a metrics dict with each log.
        """
        logging.info('training...')
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException as e:
                    logging.info(f"Stopping from {e}")
                    break

                self.parleys += 1
                self._train_steps = self.parleys // self.update_freq
                self._last_log_steps += 1 / self.update_freq

                # the following additionally updates self._total_epochs
                train_time, log_time, validate_time, save_time = self._get_time(world)
                # get the total training examples done, compute epochs
                exs_per_epoch = world.num_examples()
                self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    yield self.log()
                    logging.info(
                        f'num_epochs completed:{self.max_num_epochs} time elapsed:{train_time}s'
                    )
                    break
                if train_time > self.max_train_time:
                    logging.info(f'max_train_time elapsed:{train_time}s')
                    break
                if self._train_steps >= self.max_train_steps:
                    logging.info(
                        f'max_train_steps elapsed:{self._train_steps} '
                        f'time elapsed:{train_time}s'
                    )
                    break
                if (
                    log_time > self.log_every_n_secs
                    or self._last_log_steps >= self.log_every_n_steps
                ):
                    yield self.log()
                if (
                    validate_time > self.val_every_n_secs
                    or self._total_epochs - self.last_valid_epoch
                    >= self.val_every_n_epochs
                    or self._train_steps - self._last_valid_steps
                    >= self.val_every_n_steps
                ):
                    try:
                        # log before we validate
                        if self._last_log_steps:
                            yield self.log()
                        world.reset_metrics()
                        stop_training = self.validate()
                    except StopTrainException:
                        break
                    # reset the log time because we logged right before validating
                    self.log_time.reset()
                    self.last_valid_epoch = self._total_epochs
                    self._last_valid_steps = self._train_steps
                    if stop_training:
                        break
                    # make sure metrics are clean before we log
                    world.reset_metrics()
                if save_time > self.save_every_n_secs and opt.get('model_file'):
                    logging.info(
                        f"saving model checkpoint: {opt['model_file']}.checkpoint"
                    )
                    if opt['tensorboard_log'] and is_primary_worker():
                        self.tb_logger.flush()
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not sync_object(self.saved):
            # save agent
            self.save_model()

        # there's a rare edge case where the we never saved the model, and we try
        # # to reload it. This sync_object ensures all workers wait for the primary
        # worker to finish flushing before loading from disk.
        sync_object(None)
        if opt.get('model_file'):
            # clean up all our memory, just to make sure we don't OOM on GPU when
            # reloading the world
            del world
            del self.world
            del self.agent
            del self.valid_worlds
            # reload best validation model
            self.agent = create_agent(opt)

    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        opt = self.opt
        for _train_log in self.train_steps():
            # we've already done what we need in these
            pass

        # perform final validation/testing
        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
        self.final_valid_report = self._run_eval(
            valid_worlds, opt, 'valid', max_exs, write_log=True
        )
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        self.final_test_report = self._run_eval(
            test_worlds, opt, 'test', max_exs, write_log=True
        )

        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.log_final('valid', self.final_valid_report)
            self.wb_logger.log_final('test', self.final_test_report)
            self.wb_logger.finish()

        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        if opt['final_extra_opt'] != '':
            self.final_extra_valid_report = self._run_final_extra_eval(opt)

        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.finish()

        self._save_train_stats()

        return self.final_valid_report, self.final_test_report
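
A hedged sketch of how a TrainLoop like the one above is typically driven from a script entry point (the setup_args call assumes a training-script parser such as the one in Example #8; the guard below is illustrative, not part of the original page):

if __name__ == '__main__':
    opt = setup_args().parse_args()
    # train() exhausts train_steps() and returns the final valid/test reports
    valid_report, test_report = TrainLoop(opt).train()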
Example #19
File: train_model.py Project: ying-A/RED
    def __init__(self, opt):
        # if python is called from a non-interactive shell, like a bash script,
        # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are
        # not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)

        if isinstance(opt, ParlaiParser):
            print(
                '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]'
            )
            opt = opt.parse_args()
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if opt['load_from_checkpoint'] and opt.get(
                'model_file') and os.path.isfile(opt['model_file'] +
                                                 '.checkpoint'):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # Possibly build a dictionary (not all models do this).
        if opt['dict_build_first'] and 'dict_file' in opt:
            # If data built via pytorch data teacher, we need to load prebuilt dict
            if opt.get('pytorch_teacher_task'):
                opt['dict_file'] = get_pyt_dict_file(opt)
            elif opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=True)
        # Create model and assign it to the specified task
        self.agent = create_agent(opt)  # specify model such as seq2seq
        self.world = create_task(opt, self.agent)  # batch world or other world
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys = 0
        self.max_num_epochs = opt[
            'num_epochs'] if opt['num_epochs'] > 0 else float('inf')
        self.max_train_time = opt['max_train_time'] if opt['max_train_time'] > 0 \
            else float('inf')
        self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 \
            else float('inf')
        self.val_every_n_secs = \
            opt['validation_every_n_secs'] if opt['validation_every_n_secs'] > 0 \
            else float('inf')
        self.save_every_n_secs = opt['save_every_n_secs'] if opt['save_every_n_secs'] \
            > 0 else float('inf')
        self.val_every_n_epochs = \
            opt['validation_every_n_epochs'] if opt['validation_every_n_epochs'] > 0 \
            else float('inf')

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {
                'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'
        }:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.best_valid = None
        if opt.get('model_file') and os.path.isfile(opt['model_file'] +
                                                    '.best_valid'):
            with open(opt['model_file'] + ".best_valid", 'r') as f:
                self.best_valid = float(f.readline())
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if (opt.get('model_file')
                and os.path.isfile(opt['model_file'] + trainstats_suffix)):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self.impatience = obj.get('impatience', 0)

        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)
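As a side note, the "smart defaults" block above only picks a direction when --validation-metric-mode is not given explicitly. A tiny standalone sketch of that mapping (illustrative; the helper name is made up):

def default_metric_mode(metric):
    # mirrors the branches above: loss-like metrics are minimized,
    # accuracy-like metrics (and anything unrecognized) are maximized
    if metric in {'loss', 'ppl', 'mean_rank'}:
        return 'min'
    if metric in {'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'}:
        return 'max'
    return 'max'

assert default_metric_mode('ppl') == 'min'
assert default_metric_mode('f1') == 'max'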
Example #20
0
File: train_model.py Project: ying-A/RED
class TrainLoop():
    def __init__(self, opt):
        # if python is called from a non-interactive shell, like a bash script,
        # it will by default ignore SIGINTs, so KeyboardInterrupt exceptions are
        # not raised. This line brings them back.
        signal.signal(signal.SIGINT, signal.default_int_handler)

        if isinstance(opt, ParlaiParser):
            print(
                '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]'
            )
            opt = opt.parse_args()
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if opt['load_from_checkpoint'] and opt.get(
                'model_file') and os.path.isfile(opt['model_file'] +
                                                 '.checkpoint'):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # Possibly build a dictionary (not all models do this).
        if opt['dict_build_first'] and 'dict_file' in opt:
            # If data built via pytorch data teacher, we need to load prebuilt dict
            if opt.get('pytorch_teacher_task'):
                opt['dict_file'] = get_pyt_dict_file(opt)
            elif opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=True)
        # Create model and assign it to the specified task
        self.agent = create_agent(opt)  # specify the model, e.g. seq2seq
        self.world = create_task(opt, self.agent)  # a BatchWorld or other world
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys = 0
        self.max_num_epochs = opt[
            'num_epochs'] if opt['num_epochs'] > 0 else float('inf')
        self.max_train_time = opt['max_train_time'] if opt['max_train_time'] > 0 \
            else float('inf')
        self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 \
            else float('inf')
        self.val_every_n_secs = \
            opt['validation_every_n_secs'] if opt['validation_every_n_secs'] > 0 \
            else float('inf')
        self.save_every_n_secs = opt['save_every_n_secs'] if opt['save_every_n_secs'] \
            > 0 else float('inf')
        self.val_every_n_epochs = \
            opt['validation_every_n_epochs'] if opt['validation_every_n_epochs'] > 0 \
            else float('inf')

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {
                'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'
        }:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.best_valid = None
        if opt.get('model_file') and os.path.isfile(opt['model_file'] +
                                                    '.best_valid'):
            with open(opt['model_file'] + ".best_valid", 'r') as f:
                self.best_valid = float(f.readline())
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if (opt.get('model_file')
                and os.path.isfile(opt['model_file'] + trainstats_suffix)):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self.impatience = obj.get('impatience', 0)

        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)

    def save_model(self, suffix=None):
        if not is_primary_worker():
            # never do IO as a non-primary worker
            return
        if not self.opt.get('model_file'):
            # nothing to save to, just exit
            return

        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        while True:
            # don't ever let a ctrl-c interrupt saving
            try:
                self.agent.save(fn)
                self._save_train_stats(suffix)
                break
            except KeyboardInterrupt:
                pass

    def _save_train_stats(self, suffix=None):
        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        fn += '.trainstats'
        with open(fn, 'w') as f:
            json.dump(
                {
                    'train_time':
                    self.train_time.time(),
                    'total_epochs':
                    (self._preempted_epochs +
                     num_workers() * self.world.get_total_epochs()),
                    'impatience':
                    self.impatience,
                }, f)

    def validate(self):
        opt = self.opt

        if self.valid_world is None:
            # we need to load the world now
            self.valid_world = _maybe_load_eval_world(self.agent, opt, 'valid')

        # run evaluation on valid set
        valid_report = sync_object(
            run_eval(self.valid_world, opt, 'valid', opt['validation_max_exs'],
                     True))

        # logging
        if opt['tensorboard_log'] is True and is_primary_worker():
            self.writer.add_metrics('valid', int(self.train_time.time()),
                                    valid_report)
        # saving
        if (opt.get('model_file') and opt.get('save_after_valid')
                and is_primary_worker()):
            print("[ saving model checkpoint: " + opt['model_file'] +
                  ".checkpoint ]")
            self.save_model('.checkpoint')

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # check which metric to look at
        if '/' in opt['validation_metric']:
            # if you are multitasking and want your validation metric to be
            # a metric specific to a subtask, specify your validation metric
            # as -vmt subtask/metric
            subtask = opt['validation_metric'].split('/')[0]
            validation_metric = opt['validation_metric'].split('/')[1]
            new_valid = valid_report['tasks'][subtask][validation_metric]
        else:
            new_valid = valid_report[opt['validation_metric']]

        # check if this is the best validation so far
        if (self.best_valid is None or self.valid_optim * new_valid >
                self.valid_optim * self.best_valid):
            print('[ new best {}: {}{} ]'.format(
                opt['validation_metric'], new_valid,
                ' (previous best was {})'.format(self.best_valid)
                if self.best_valid is not None else ''))
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file') and is_primary_worker():
                print("[ saving best valid model: " + opt['model_file'] + " ]")
                self.save_model()
                print("[ saving best valid metric: " + opt['model_file'] +
                      ".best_valid ]")
                save_best_valid(opt['model_file'], self.best_valid)
                self.saved = True
            if (opt['validation_metric'] == 'accuracy'
                    and self.best_valid >= opt['validation_cutoff']):
                print('[ task solved! stopping. ]')
                return True
        else:
            self.impatience += 1
            print('[ did not beat best {}: {} impatience: {} ]'.format(
                opt['validation_metric'], round(self.best_valid, 4),
                self.impatience))
        self.validate_time.reset()

        # check if we are out of patience
        if (opt['validation_patience'] > 0
                and self.impatience >= opt['validation_patience']):
            print('[ ran out of patience! stopping training. ]')
            return True
        return False

    def _average_dicts(self, all_versions):
        # instead of a list-of-dicts with like keys, make a dict-of-lists with
        # keys to reduce
        to_reduce = {}
        for d in all_versions:
            for k, v in d.items():
                to_reduce.setdefault(k, []).append(v)
        # now perform the reduction
        finalized = {}
        for k, values in to_reduce.items():
            if k == 'exs' or k == 'total_skipped_batches':
                # sum across workers
                finalized[k] = np.sum(values)
            elif isinstance(values[0], dict):
                # do the same procedure recursively
                finalized[k] = self._average_dicts(values)
            else:
                # all other cases, take the mean across the workers
                finalized[k] = np.mean(values)
        return finalized
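    # Illustrative reduction by _average_dicts across two workers:
    # [{'exs': 10, 'loss': 2.0}, {'exs': 30, 'loss': 4.0}]
    #   -> {'exs': 40, 'loss': 3.0}  ('exs' is summed, other metrics averaged).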

    def _sync_training_metrics(self, metrics):
        """
        Sync training metrics across workers. A handful of special cases are handled
        as exceptions, and the remaining metrics are simply averaged across workers.
        """
        if not is_distributed():
            # nothing special needed
            return metrics
        all_versions = all_gather_list(metrics)
        return self._average_dicts(all_versions)

    def _nice_format(self, dictionary):
        rounded = {}
        for k, v in dictionary.items():
            if isinstance(v, dict):
                rounded[k] = self._nice_format(v)
            elif isinstance(v, float):
                rounded[k] = round_sigfigs(v, 4)
            else:
                rounded[k] = v
        return rounded

    def _compute_eta(self, epochs_completed, time_elapsed):
        """
        Computes the estimated seconds remaining in training.

        :param float epochs_completed: number of epochs already completed.
        :param float time_elapsed: total time spent already, in seconds.
        :return: ETA in seconds, or None if not computable
        """
        # start off with no estimate
        eta = None

        # Determine time_left and num_epochs
        max_epochs = self.opt.get('num_epochs', 0)
        if max_epochs > 0 and epochs_completed > 0:
            epoch_progress = epochs_completed / max_epochs
            eta = (1 - epoch_progress) * time_elapsed / epoch_progress
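            # illustrative arithmetic: with num_epochs=10, 2.0 epochs finished
            # in 600s gives epoch_progress=0.2 and eta=(1-0.2)*600/0.2=2400s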

        max_training_time = self.opt.get('max_training_time', -1)
        if max_training_time > 0:
            time_left = max_training_time - time_elapsed
            if eta is None or time_left < eta:
                eta = time_left

        return eta

    def log(self):
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []
        # get report
        train_report = self._sync_training_metrics(self.world.report())
        self.world.reset_metrics()

        # time elapsed
        logs.append('time:{}s'.format(np.floor(self.train_time.time())))
        logs.append('total_exs:{}'.format(self._total_exs))

        if self._total_epochs >= 0:
            # log epoch progress
            logs.append('epochs:{}'.format(round(self._total_epochs, 2)))

        time_left = self._compute_eta(self._total_epochs,
                                      self.train_time.time())
        if time_left is not None:
            logs.append('time_left:{}s'.format(max(0, np.ceil(time_left))))

        log = '[ {} ] {}'.format(' '.join(logs),
                                 self._nice_format(train_report))
        print(log)
        self.log_time.reset()

        if opt['tensorboard_log'] is True and is_primary_worker():
            self.writer.add_metrics('train', self._total_exs, train_report)

    def train(self):
        if is_distributed():
            warn_once(
                "Distributed training outputs average-per-worker metrics during "
                "training, and may be slightly distorted. Validation/test are "
                "unadulterated.")
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                world.parley()
                self.parleys += 1
                # print(world.display())

                # get the total training examples done, compute epochs
                self._total_epochs = (
                    self._preempted_epochs +
                    num_workers() * self.world.get_total_epochs())
                exs_per_epoch = self.world.num_examples()
                self._total_exs = int(
                    np.round(self._total_epochs * exs_per_epoch))

                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object(
                    (self.train_time.time(), self.log_time.time(),
                     self.validate_time.time()))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, train_time))
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (validate_time > self.val_every_n_secs
                        or self._total_epochs - self.last_valid_epoch >=
                        self.val_every_n_epochs):
                    stop_training = self.validate()
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break
                if (self.save_time.time() > self.save_every_n_secs
                        and opt.get('model_file') and is_primary_worker()):
                    print("[ saving model checkpoint: {}.checkpoint".format(
                        opt['model_file']))
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        valid_world = _maybe_load_eval_world(self.agent, opt, 'valid')
        v_report = run_eval(valid_world, opt, 'valid', write_log=True)
        test_world = _maybe_load_eval_world(self.agent, opt, 'test')
        t_report = run_eval(test_world, opt, 'test', write_log=True)
        if valid_world:
            valid_world.shutdown()
        if test_world:
            test_world.shutdown()

        return v_report, t_report
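A minimal driver sketch for the class above (illustrative; it assumes a ParlAI-style setup_args like the one in the next example and that TrainLoop is importable from this module):

if __name__ == '__main__':
    opt = setup_args().parse_args()
    valid_report, test_report = TrainLoop(opt).train()
    print(valid_report, test_report)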
Example #21
0
def setup_args(parser=None) -> ParlaiParser:
    """
    Build the ParlAI parser, adding command line args if necessary.

    :param ParlaiParser parser:
        Preexisting parser to append options to. Will be created if needed.

    :returns:
        the ParlaiParser with CLI options added.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Train a model')
    train = parser.add_argument_group('Training Loop Arguments')
    train.add_argument(
        '-et',
        '--evaltask',
        help=
        'task to use for valid/test (defaults to the one used for training)',
    )
    train.add_argument(
        '--eval-batchsize',
        type=int,
        hidden=True,
        help='Eval time batch size (defaults to same as -bs)',
    )
    train.add_argument('--display-examples',
                       type='bool',
                       default=False,
                       hidden=True)
    train.add_argument('-eps', '--num-epochs', type=float, default=-1)
    train.add_argument('-ttim', '--max-train-time', type=float, default=-1)
    train.add_argument('-ltim', '--log-every-n-secs', type=float, default=10)
    train.add_argument(
        '-vtim',
        '--validation-every-n-secs',
        type=float,
        default=-1,
        help='Validate every n seconds. Saves model to model_file '
        '(if set) whenever best val metric is found',
    )
    train.add_argument(
        '-stim',
        '--save-every-n-secs',
        type=float,
        default=-1,
        help='Saves the model to model_file.checkpoint after '
        'every n seconds (default -1, never).',
    )
    train.add_argument(
        '-sval',
        '--save-after-valid',
        type='bool',
        default=False,
        help='Saves the model to model_file.checkpoint after '
        'every validation (default %(default)s).',
    )
    train.add_argument(
        '-veps',
        '--validation-every-n-epochs',
        type=float,
        default=-1,
        help='Validate every n epochs. Saves model to model_file '
        '(if set) whenever best val metric is found',
    )
    train.add_argument(
        '-vme',
        '--validation-max-exs',
        type=int,
        default=-1,
        hidden=True,
        help='max examples to use during validation (default -1 uses all)',
    )
    train.add_argument(
        '--short-final-eval',
        default=False,
        hidden=True,
        type='bool',
        help='If true, obeys --validation-max-exs in the final '
        'validation and test evaluations.',
    )
    train.add_argument(
        '-vp',
        '--validation-patience',
        type=int,
        default=10,
        help=('number of iterations of validation where result'
              ' does not improve before we stop training'),
    )
    train.add_argument(
        '-vmt',
        '--validation-metric',
        default='accuracy',
        help='key into report table for selecting best validation',
    )
    train.add_argument(
        '-vmm',
        '--validation-metric-mode',
        type=str,
        choices=['max', 'min'],
        help='how to optimize validation metric (max or min)',
    )
    train.add_argument(
        '-vcut',
        '--validation-cutoff',
        type=float,
        default=1.0,
        hidden=True,
        help='value at which training will stop if exceeded by metric',
    )
    train.add_argument(
        '-lfc',
        '--load-from-checkpoint',
        type='bool',
        default=True,
        hidden=True,
        help='load model from checkpoint if available',
    )
    train.add_argument(
        '-vshare',
        '--validation-share-agent',
        default=False,
        hidden=True,
        help='use a shared copy of the agent for validation. '
        'this will eventually default to True, but '
        'currently defaults to False.',
    )
    train.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='list of metrics to show/compute, e.g. all, default, '
        'or give a list split by , like '
        'ppl,f1,accuracy,hits@1,rouge,bleu. '
        'The rouge metrics will be computed as rouge-1, rouge-2 and rouge-l',
    )
    train.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='Report micro-averaged metrics instead of macro averaged metrics.',
        recommended=False,
    )
    TensorboardLogger.add_cmdline_args(parser)

    parser = setup_dict_args(parser)
    return parser
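As a usage sketch, the parser above can also be driven programmatically with overrides (illustrative; -t is ParlAI's standard task flag and the values below are placeholders):

parser = setup_args()
opt = parser.parse_args([
    '-t', 'babi:task10k:1',   # placeholder task name
    '-vmt', 'ppl',            # validate on perplexity ...
    '-vmm', 'min',            # ... and minimize it
    '-stim', '3600',          # checkpoint roughly every hour
])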
Example #22
0
    def __init__(self, opt, shared=None):
        opt = copy.deepcopy(opt)
        super().__init__(opt, shared)
        self.use_cuda = not opt['no_cuda'] and torch.cuda.is_available()
        self.is_combine_attr = (hasattr(self, 'other_task_datafiles')
                                and self.other_task_datafiles)
        self.random_policy = opt.get('random_policy', False)
        self.count_sample = opt.get('count_sample', False)
        self.anti = opt.get('anti', False)

        if self.random_policy:
            random.seed(17)

        if not shared:
            if not self.stream and opt.get('pace_by', 'sample') == 'bucket':
                score_list = [episode[0][2] for episode in self.data.data]
                assert score_list == sorted(score_list)
                num_buckets = opt.get('num_buckets',
                                      int(self.num_episodes() / 10))
                lb_indices = [
                    int(len(score_list) * i / num_buckets)
                    for i in range(num_buckets)
                ]
                lbs = [score_list[idx] for idx in lb_indices]
                bucket_ids = [
                    self.sort_into_bucket(ctrl_val, lbs)
                    for ctrl_val in score_list
                ]
                bucket_cnt = [0 for _ in range(num_buckets)]
                for i in range(num_buckets):
                    bucket_cnt[i] = bucket_ids.count(i)
                self.bucket_cnt = bucket_cnt
            self.lastYs = [None] * self.bsz
            # build multiple task data
            self.tasks = [self.data]

            if self.is_combine_attr:
                print('[ build multiple task data ... ]')
                for datafile in self.other_task_datafiles:
                    task_opt = copy.deepcopy(opt)
                    task_opt['datafile'] = datafile
                    self.tasks.append(
                        DialogData(task_opt,
                                   data_loader=self.setup_data,
                                   cands=self.label_candidates()))
                print('[ build multiple task data done! ]')

                # record the selections of each subtask
                self.subtasks = opt['subtasks'].split(':')
                self.subtask_counter = OrderedDict()
                self.p_selections = OrderedDict()
                self.c_selections = OrderedDict()
                for t in self.subtasks:
                    self.subtask_counter[t] = 0
                    self.p_selections[t] = []
                    self.c_selections[t] = []

                if self.count_sample and not self.stream:
                    self.sample_counter = OrderedDict()
                    for idx, t in enumerate(self.subtasks):
                        self.sample_counter[t] = [
                            0 for _ in self.tasks[idx].data
                        ]

            # setup the tensorboard log
            if opt['tensorboard_log_teacher'] is True:
                opt['tensorboard_tag'] = 'task'
                teacher_metrics = 'reward,policy_loss,critic_loss,mean_advantage_reward,action_ent'.split(
                    ',')
                opt['tensorboard_metrics'] = ','.join(
                    opt['tensorboard_metrics'].split(',') + teacher_metrics)
                self.writer = TensorboardLogger(opt)

        else:
            self.lastYs = shared['lastYs']
            self.tasks = shared['tasks']
            if not self.stream and opt.get('pace_by', 'sample') == 'bucket':
                self.bucket_cnt = shared['bucket_cnt']
            if 'writer' in shared:
                self.writer = shared['writer']
            if 'subtask_counter' in shared:
                self.subtask_counter = shared['subtask_counter']
            if 'p_selections' in shared:
                self.p_selections = shared['p_selections']
            if 'c_selections' in shared:
                self.c_selections = shared['c_selections']

        # build the policy net, criterion and optimizer here
        self.state_dim = 32 + len(self.tasks)  # hand-crafted features
        self.action_dim = len(self.tasks)

        if not shared:
            self.policy = PolicyNet(self.state_dim, self.action_dim)
            self.critic = CriticNet(self.state_dim, self.action_dim)

            init_teacher = get_init_teacher(opt, shared)
            if init_teacher is not None:
                # load teacher parameters if available
                print('[ Loading existing teacher params from {} ]'
                      ''.format(init_teacher))
                states = self.load(init_teacher)
            else:
                states = {}
        else:
            self.policy = shared['policy']
            self.critic = shared['critic']
            states = shared['states']

        if (
                # only build an optimizer if we're training
                'train' in opt.get('datatype', '') and
                # and this is the main model
                shared is None):
            # for policy net
            self.optimizer = self.init_optim(
                [p for p in self.policy.parameters() if p.requires_grad],
                lr=opt['learningrate_teacher'],
                optim_states=states.get('optimizer'),
                saved_optim_type=states.get('optimizer_type'))
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                'min',
                factor=0.8,  # 0.5 --> 0.8
                patience=5,  # 3 --> 5
                verbose=True)
            if 'lr_scheduler' in states:
                self.scheduler.load_state_dict(states['lr_scheduler'])

            # for critic net
            self.optimizer_critic = self.init_optim(
                [p for p in self.critic.parameters() if p.requires_grad],
                lr=opt['learningrate_teacher_critic'],
                optim_states=states.get('optimizer_critic'),
                saved_optim_type=states.get('optimizer_type'))
            self.scheduler_critic = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer_critic,
                'min',
                factor=0.8,  # 0.5 --> 0.8
                patience=5,  # 3 --> 5
                verbose=True)
            if 'lr_scheduler_critic' in states:
                self.scheduler_critic.load_state_dict(
                    states['lr_scheduler_critic'])

            self.critic_criterion = torch.nn.SmoothL1Loss()

        self.reward_metric = opt.get('reward_metric', 'total_metric')
        self.reward_metric_mode = opt.get('reward_metric_mode', 'max')

        self.prev_prev_valid_report = states[
            'prev_prev_valid_report'] if 'prev_prev_valid_report' in states else None
        self.prev_valid_report = states[
            'prev_valid_report'] if 'prev_valid_report' in states else None
        self.current_valid_report = states[
            'current_valid_report'] if 'current_valid_report' in states else None
        self.saved_actions = states[
            'saved_actions'] if 'saved_actions' in states else OrderedDict()
        self.saved_state_actions = states[
            'saved_state_actions'] if 'saved_state_actions' in states else OrderedDict(
            )
        if self.use_cuda:
            for k, v in self.saved_actions.items():
                self.saved_actions[k] = v.cuda()
            for k, v in self.saved_state_actions.items():
                self.saved_state_actions[k] = v.cuda()
        self._number_teacher_updates = states[
            '_number_teacher_updates'] if '_number_teacher_updates' in states else 0

        # enable the batch_act
        self.use_batch_act = self.bsz > 1

        self.T = self.opt.get('T', 1000)
        self.c0 = self.opt.get('c0', 0.01)
        self.p = self.opt.get('p', 2)

        # setup the timer
        self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 \
            else float('inf')
        self.action_log_time = Timer()

        self.move_to_cuda()
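For clarity, a standalone sketch of the bucketing logic used above, assuming an already-sorted score_list (toy values; the real lower bounds come from the task data):

score_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]   # sorted ascending
num_buckets = 4
lb_indices = [int(len(score_list) * i / num_buckets) for i in range(num_buckets)]
lbs = [score_list[idx] for idx in lb_indices]            # bucket lower bounds
bucket_ids = [max(i for i, lb in enumerate(lbs) if v >= lb) for v in score_list]
bucket_cnt = [bucket_ids.count(i) for i in range(num_buckets)]   # [2, 2, 2, 2]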
Example #23
0
class DefaultTeacher(FbDialogTeacher):
    def __init__(self, opt, shared=None):
        opt = copy.deepcopy(opt)
        super().__init__(opt, shared)
        self.use_cuda = not opt['no_cuda'] and torch.cuda.is_available()
        self.is_combine_attr = (hasattr(self, 'other_task_datafiles')
                                and self.other_task_datafiles)
        self.random_policy = opt.get('random_policy', False)
        self.count_sample = opt.get('count_sample', False)
        self.anti = opt.get('anti', False)

        if self.random_policy:
            random.seed(17)

        if not shared:
            if not self.stream and opt.get('pace_by', 'sample') == 'bucket':
                score_list = [episode[0][2] for episode in self.data.data]
                assert score_list == sorted(score_list)
                num_buckets = opt.get('num_buckets',
                                      int(self.num_episodes() / 10))
                lb_indices = [
                    int(len(score_list) * i / num_buckets)
                    for i in range(num_buckets)
                ]
                lbs = [score_list[idx] for idx in lb_indices]
                bucket_ids = [
                    self.sort_into_bucket(ctrl_val, lbs)
                    for ctrl_val in score_list
                ]
                bucket_cnt = [0 for _ in range(num_buckets)]
                for i in range(num_buckets):
                    bucket_cnt[i] = bucket_ids.count(i)
                self.bucket_cnt = bucket_cnt
            self.lastYs = [None] * self.bsz
            # build multiple task data
            self.tasks = [self.data]

            if self.is_combine_attr:
                print('[ build multiple task data ... ]')
                for datafile in self.other_task_datafiles:
                    task_opt = copy.deepcopy(opt)
                    task_opt['datafile'] = datafile
                    self.tasks.append(
                        DialogData(task_opt,
                                   data_loader=self.setup_data,
                                   cands=self.label_candidates()))
                print('[ build multiple task data done! ]')

                # record the selections of each subtask
                self.subtasks = opt['subtasks'].split(':')
                self.subtask_counter = OrderedDict()
                self.p_selections = OrderedDict()
                self.c_selections = OrderedDict()
                for t in self.subtasks:
                    self.subtask_counter[t] = 0
                    self.p_selections[t] = []
                    self.c_selections[t] = []

                if self.count_sample and not self.stream:
                    self.sample_counter = OrderedDict()
                    for idx, t in enumerate(self.subtasks):
                        self.sample_counter[t] = [
                            0 for _ in self.tasks[idx].data
                        ]

            # setup the tensorboard log
            if opt['tensorboard_log_teacher'] is True:
                opt['tensorboard_tag'] = 'task'
                teacher_metrics = 'reward,policy_loss,critic_loss,mean_advantage_reward,action_ent'.split(
                    ',')
                opt['tensorboard_metrics'] = ','.join(
                    opt['tensorboard_metrics'].split(',') + teacher_metrics)
                self.writer = TensorboardLogger(opt)

        else:
            self.lastYs = shared['lastYs']
            self.tasks = shared['tasks']
            if not self.stream and opt.get('pace_by', 'sample') == 'bucket':
                self.bucket_cnt = shared['bucket_cnt']
            if 'writer' in shared:
                self.writer = shared['writer']
            if 'subtask_counter' in shared:
                self.subtask_counter = shared['subtask_counter']
            if 'p_selections' in shared:
                self.p_selections = shared['p_selections']
            if 'c_selections' in shared:
                self.c_selections = shared['c_selections']

        # build the policy net, criterion and optimizer here
        self.state_dim = 32 + len(self.tasks)  # hand-crafted features
        self.action_dim = len(self.tasks)

        if not shared:
            self.policy = PolicyNet(self.state_dim, self.action_dim)
            self.critic = CriticNet(self.state_dim, self.action_dim)

            init_teacher = get_init_teacher(opt, shared)
            if init_teacher is not None:
                # load teacher parameters if available
                print('[ Loading existing teacher params from {} ]'
                      ''.format(init_teacher))
                states = self.load(init_teacher)
            else:
                states = {}
        else:
            self.policy = shared['policy']
            self.critic = shared['critic']
            states = shared['states']

        if (
                # only build an optimizer if we're training
                'train' in opt.get('datatype', '') and
                # and this is the main model
                shared is None):
            # for policy net
            self.optimizer = self.init_optim(
                [p for p in self.policy.parameters() if p.requires_grad],
                lr=opt['learningrate_teacher'],
                optim_states=states.get('optimizer'),
                saved_optim_type=states.get('optimizer_type'))
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                'min',
                factor=0.8,  # 0.5 --> 0.8
                patience=5,  # 3 --> 5
                verbose=True)
            if 'lr_scheduler' in states:
                self.scheduler.load_state_dict(states['lr_scheduler'])

            # for critic net
            self.optimizer_critic = self.init_optim(
                [p for p in self.critic.parameters() if p.requires_grad],
                lr=opt['learningrate_teacher_critic'],
                optim_states=states.get('optimizer_critic'),
                saved_optim_type=states.get('optimizer_type'))
            self.scheduler_critic = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer_critic,
                'min',
                factor=0.8,  # 0.5 --> 0.8
                patience=5,  # 3 --> 5
                verbose=True)
            if 'lr_scheduler_critic' in states:
                self.scheduler_critic.load_state_dict(
                    states['lr_scheduler_critic'])

            self.critic_criterion = torch.nn.SmoothL1Loss()

        self.reward_metric = opt.get('reward_metric', 'total_metric')
        self.reward_metric_mode = opt.get('reward_metric_mode', 'max')

        self.prev_prev_valid_report = states[
            'prev_prev_valid_report'] if 'prev_prev_valid_report' in states else None
        self.prev_valid_report = states[
            'prev_valid_report'] if 'prev_valid_report' in states else None
        self.current_valid_report = states[
            'current_valid_report'] if 'current_valid_report' in states else None
        self.saved_actions = states[
            'saved_actions'] if 'saved_actions' in states else OrderedDict()
        self.saved_state_actions = states[
            'saved_state_actions'] if 'saved_state_actions' in states else OrderedDict(
            )
        if self.use_cuda:
            for k, v in self.saved_actions.items():
                self.saved_actions[k] = v.cuda()
            for k, v in self.saved_state_actions.items():
                self.saved_state_actions[k] = v.cuda()
        self._number_teacher_updates = states[
            '_number_teacher_updates'] if '_number_teacher_updates' in states else 0

        # enable the batch_act
        self.use_batch_act = self.bsz > 1

        self.T = self.opt.get('T', 1000)
        self.c0 = self.opt.get('c0', 0.01)
        self.p = self.opt.get('p', 2)

        # setup the timer
        self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 \
            else float('inf')
        self.action_log_time = Timer()

        self.move_to_cuda()

    def move_to_cuda(self):
        if self.use_cuda:
            self.policy.cuda()
            self.critic.cuda()

    @classmethod
    def optim_opts(cls):
        """
        Fetch optimizer selection.

        By default, collects everything in torch.optim, as well as importing:
        - qhm / qhmadam if installed from github.com/facebookresearch/qhoptim

        Override this (and probably call super()) to add your own optimizers.
        """
        # first pull torch.optim in
        optims = {
            k.lower(): v
            for k, v in optim.__dict__.items()
            if not k.startswith('__') and k[0].isupper()
        }
        try:
            import apex.optimizers.fused_adam as fused_adam
            optims['fused_adam'] = fused_adam.FusedAdam
        except ImportError:
            pass

        try:
            # https://openreview.net/pdf?id=S1fUpoR5FQ
            from qhoptim.pyt import QHM, QHAdam
            optims['qhm'] = QHM
            optims['qhadam'] = QHAdam
        except ImportError:
            # no QHM installed
            pass

        return optims

    def init_optim(self, params, lr, optim_states=None, saved_optim_type=None):
        """
        Initialize optimizer with teacher parameters.

        :param params:
            parameters from the teacher

        :param optim_states:
            optional argument providing states of optimizer to load

        :param saved_optim_type:
            type of optimizer being loaded, if changed will skip loading
            optimizer states
        """

        opt = self.opt

        # set up optimizer args
        kwargs = {'lr': lr}
        if opt.get('momentum_teacher') > 0 and opt['optimizer_teacher'] in [
                'sgd', 'rmsprop', 'qhm'
        ]:
            # turn on momentum for optimizers that use it
            kwargs['momentum'] = opt['momentum_teacher']
            if opt['optimizer_teacher'] == 'sgd' and opt.get(
                    'nesterov_teacher', True):
                # for sgd, maybe nesterov
                kwargs['nesterov'] = opt.get('nesterov_teacher', True)
            elif opt['optimizer_teacher'] == 'qhm':
                # qhm needs a nu
                kwargs['nu'] = opt.get('nus_teacher', (0.7, ))[0]
        elif opt['optimizer_teacher'] == 'adam':
            # turn on amsgrad for adam
            # amsgrad paper: https://openreview.net/forum?id=ryQu7f-RZ
            kwargs['amsgrad'] = True
        elif opt['optimizer_teacher'] == 'qhadam':
            # set nus for qhadam
            kwargs['nus'] = opt.get('nus_teacher', (0.7, 1.0))
        if opt['optimizer_teacher'] in [
                'adam', 'sparseadam', 'adamax', 'qhadam'
        ]:
            # set betas for optims that use it
            kwargs['betas'] = opt.get('betas_teacher', (0.9, 0.999))

        optim_class = self.optim_opts()[opt['optimizer_teacher']]
        optimizer = optim_class(params, **kwargs)

        if optim_states:
            if saved_optim_type != opt['optimizer_teacher']:
                print('WARNING: not loading optim state since optim class '
                      'changed.')
            else:
                try:
                    optimizer.load_state_dict(optim_states)
                except ValueError:
                    print('WARNING: not loading optim state since model '
                          'params changed.')
                if self.use_cuda:
                    for state in optimizer.state.values():
                        for k, v in state.items():
                            if isinstance(v, torch.Tensor):
                                state[k] = v.cuda()
        return optimizer

    def load(self, path):
        """
        Return opt and teacher states.

        TODO: load behaviors should be consistent with function state_dict().
        """
        states = torch.load(path, map_location=lambda cpu, _: cpu)
        if 'policy' in states:
            self.policy.load_state_dict(states['policy'])
        if 'critic' in states:
            self.critic.load_state_dict(states['critic'])
        if 'optimizer' in states and hasattr(self, 'optimizer'):
            self.optimizer.load_state_dict(states['optimizer'])
        if 'optimizer_critic' in states and hasattr(self, 'optimizer_critic'):
            self.optimizer_critic.load_state_dict(states['optimizer_critic'])
        return states

    def share(self):
        shared = super().share()
        if hasattr(self, 'bucket_cnt'):
            shared['bucket_cnt'] = self.bucket_cnt

        shared['tasks'] = self.tasks
        shared['policy'] = self.policy
        shared['critic'] = self.critic

        shared['states'] = {
            'optimizer_type': self.opt['optimizer_teacher'],
            'prev_prev_valid_report': self.prev_prev_valid_report,
            'prev_valid_report': self.prev_valid_report,
            'current_valid_report': self.current_valid_report,
            'saved_actions': self.saved_actions,
            'saved_state_actions': self.saved_state_actions,
        }
        if hasattr(self, 'writer'):
            shared['writer'] = self.writer
        if hasattr(self, 'subtask_counter'):
            shared['subtask_counter'] = self.subtask_counter
        if hasattr(self, 'p_selections'):
            shared['p_selections'] = self.p_selections
        if hasattr(self, 'c_selections'):
            shared['c_selections'] = self.c_selections
        return shared

    @staticmethod
    def sort_into_bucket(val, bucket_lbs):
        """
        Returns the highest bucket such that val >= lower bound for that bucket.

        Inputs:
          val: float. The value to be sorted into a bucket.
          bucket_lbs: list of floats, sorted ascending.

        Returns:
          bucket_id: int in range(num_buckets); the bucket that val belongs to.
        """
        num_buckets = len(bucket_lbs)
        for bucket_id in range(num_buckets - 1, -1, -1):  # iterate descending
            lb = bucket_lbs[bucket_id]
            if val >= lb:
                return bucket_id
        raise ValueError('val %f is not >= any of the lower bounds: %s' %
                         (val, bucket_lbs))

    def pace_function(self, states, sum_num, T=1000, c0=0.01, p=2):
        train_step = states['train_step']
        progress = self.root_p_pace(train_step, T, c0, p)
        return int(sum_num * progress)

    @staticmethod
    def root_p_pace(timestep, T=1000, c0=0.01, p=2):
        root_p = math.pow(
            timestep * (1 - math.pow(c0, p)) / T + math.pow(c0, p), 1.0 / p)
        return min(1.0, root_p)
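    # Illustrative pace values with the defaults T=1000, c0=0.01, p=2:
    # timestep 0 -> 0.01, timestep 250 -> ~0.50, timestep >= 1000 -> 1.0
    # (capped); pace_function then exposes int(sum_num * progress) samples.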

    def act(self, observation=None, task_idx=0):
        """Send new dialog message."""
        if not hasattr(self, 'epochDone'):
            # reset if haven't yet
            self.reset()

        # get next example, action is episode_done dict if already out of exs
        action, self.epochDone = self.next_example(observation=observation,
                                                   task_idx=task_idx)
        action['id'] = self.getID()

        # remember correct answer if available
        self.lastY = action.get('labels', action.get('eval_labels', None))
        if ((not self.datatype.startswith('train')
             or 'evalmode' in self.datatype) and 'labels' in action):
            # move labels to eval field so not used for training
            # but this way the model can use the labels for perplexity or loss
            action = action.copy()
            labels = action.pop('labels')
            if not self.opt.get('hide_labels', False):
                action['eval_labels'] = labels

        return action

    def _cry_for_missing_in_obs(self, something):
        raise RuntimeError(
            "{} must be included in observations to build states!".format(
                something))

    def _build_states(self, observations):
        for key in ['train_step', 'train_report', 'loss_desc', 'prob_desc']:
            if key not in observations[0]:
                self._cry_for_missing_in_obs(key)

        train_step = observations[0]['train_step']  # scalar
        train_step = min(train_step / self.T, 1)
        train_report = observations[0]['train_report']
        nll_loss = train_report.get('nll_loss', 0) / 10  # scalar
        loss_desc = observations[0]['loss_desc']
        loss_desc = F.normalize(loss_desc, p=2, dim=-1)

        prob_desc = observations[0]['prob_desc']
        prob_desc = F.normalize(prob_desc, p=2, dim=-1)

        if hasattr(self, 'subtask_counter'):
            subtask_progress = self.subtask_counter.values()
            max_min = max(subtask_progress) - min(subtask_progress)
            subtask_progress = [
                (item - min(subtask_progress)) / max_min if max_min > 0 else 0
                for item in subtask_progress
            ]
        else:
            subtask_progress = [0]
        subtask_progress = torch.FloatTensor(subtask_progress)
        if self.use_cuda:
            subtask_progress = subtask_progress.cuda()

        prev_valid_report = self.prev_valid_report
        if prev_valid_report is None:
            prev_valid_report = {}

        bleu = prev_valid_report.get('bleu', 0)
        valid_nll_loss = prev_valid_report.get('nll_loss', 0) / 10
        dist_1_ratio = prev_valid_report.get('dist_1_ratio', 0)
        dist_2_ratio = prev_valid_report.get('dist_2_ratio', 0)
        dist_3_ratio = prev_valid_report.get('dist_3_ratio', 0)
        embed_avg = prev_valid_report.get('embed_avg', 0)
        embed_greedy = prev_valid_report.get('embed_greedy', 0)
        embed_extrema = prev_valid_report.get('embed_extrema', 0)
        embed_coh = prev_valid_report.get('embed_coh', 0)
        intra_dist_1 = prev_valid_report.get('intra_dist_1', 0) / 10
        intra_dist_2 = prev_valid_report.get('intra_dist_2', 0) / 10
        intra_dist_3 = prev_valid_report.get('intra_dist_3', 0) / 10
        response_length = prev_valid_report.get(
            'response_length', 0) / self.opt.get('label_truncate', 100)
        # sent_entropy_uni = prev_valid_report.get('sent_entropy_uni', 0) / 100
        # sent_entropy_bi = prev_valid_report.get('sent_entropy_bi', 0) / 100
        # sent_entropy_tri = prev_valid_report.get('sent_entropy_tri', 0) / 100
        word_entropy_uni = prev_valid_report.get('word_entropy_uni', 0) / 100
        word_entropy_bi = prev_valid_report.get('word_entropy_bi', 0) / 100
        word_entropy_tri = prev_valid_report.get('word_entropy_tri', 0) / 100
        states = torch.FloatTensor([
            train_step,
            nll_loss,
            bleu,
            valid_nll_loss,
            dist_1_ratio,
            dist_2_ratio,
            dist_3_ratio,
            embed_avg,
            embed_greedy,
            embed_extrema,
            embed_coh,
            intra_dist_1,
            intra_dist_2,
            intra_dist_3,
            response_length,
            # sent_entropy_uni, sent_entropy_bi, sent_entropy_tri,
            word_entropy_uni,
            word_entropy_bi,
            word_entropy_tri
        ])
        if self.use_cuda:
            states = states.cuda()
        states = torch.cat([states, loss_desc, prob_desc, subtask_progress],
                           dim=-1).unsqueeze(dim=0)
        return states

    def __uniform_weights(self):
        w = 1 / len(self.tasks)
        weights = torch.FloatTensor([w] * len(self.tasks))
        if self.use_cuda:
            weights = weights.cuda()
        return weights.unsqueeze(dim=0)

    def __load_training_batch(self, observations):
        if observations and len(
                observations) > 0 and observations[0] and self.is_combine_attr:
            if not self.random_policy:
                with torch.no_grad():
                    current_states = self._build_states(observations)
                action_probs = self.policy(current_states)
                if self.action_log_time.time() > self.log_every_n_secs and len(
                        self.tasks) > 1:
                    with torch.no_grad():
                        # log the action distributions
                        action_p = ','.join([
                            str(round_sigfigs(x, 4))
                            for x in action_probs[0].data.tolist()
                        ])
                        log = '[ {} {} ]'.format('Action probs:', action_p)
                        print(log)
                        self.action_log_time.reset()
                sample_from = Categorical(action_probs[0])
                action = sample_from.sample()
                train_step = observations[0]['train_step']
                self.saved_actions[train_step] = sample_from.log_prob(action)
                self.saved_state_actions[train_step] = torch.cat(
                    [current_states, action_probs], dim=1)
                selected_task = action.item()
                self.subtask_counter[self.subtasks[selected_task]] += 1

                probs = action_probs[0].tolist()
                selection_report = {}
                for idx, t in enumerate(self.subtasks):
                    selection_report['p_{}'.format(t)] = probs[idx]
                    self.p_selections[t].append(probs[idx])
                    selection_report['c_{}'.format(
                        t)] = self.subtask_counter[t]
                    self.c_selections[t].append(self.subtask_counter[t])
                self.writer.add_metrics(setting='Teacher/task_selection',
                                        step=train_step,
                                        report=selection_report)
            else:
                selected_task = random.choice(range(len(self.tasks)))
                self.subtask_counter[self.subtasks[selected_task]] += 1
        else:
            selected_task = 0

        return self.__load_batch(observations, task_idx=selected_task)

    def __load_batch(self, observations, task_idx=0):
        if observations is None:
            observations = [None] * self.bsz
        bsz = len(observations)

        batch = []
        # Sample from multiple tasks using the policy net
        for idx in range(bsz):
            batch.append(self.act(observations[idx], task_idx=task_idx))
        return batch

    def batch_act(self, observations):
        """
        Returns an entire batch of examples instead of just one.
        """
        if not hasattr(self, 'epochDone'):
            # reset if haven't yet
            self.reset()
        if self.opt['datatype'] == 'train':
            batch = self.__load_training_batch(observations)
        else:
            batch = self.__load_batch(observations)

        # pad batch
        if len(batch) < self.bsz:
            batch += [{
                'episode_done': True,
                'id': self.getID()
            }] * (self.bsz - len(batch))

        # remember correct answer if available (for padding, None)
        for i, ex in enumerate(batch):
            if 'labels' in ex:
                labels = ex['labels']
                self.lastYs[i] = labels
                if not self.datatype.startswith(
                        'train') or 'evalmode' in self.datatype:
                    del ex['labels']
                    if not self.opt.get('hide_labels', False):
                        ex['eval_labels'] = labels
            else:
                self.lastYs[i] = ex.get('eval_labels', None)

        return batch

    def next_example(self, observation=None, task_idx=0):
        """
        Returns the next example.

        If there are multiple examples in the same episode, returns the next
        one in that episode. If that episode is over, gets a new episode index
        and returns the first example of that episode.
        """
        if self.stream:
            action, epoch_done = self.tasks[task_idx].get()
        else:
            if self.episode_done:
                self.episode_idx = self.next_episode_idx()
                self.entry_idx = 0
            else:
                self.entry_idx += 1

            if self.episode_idx >= self.num_episodes():
                return {'episode_done': True}, True

            if observation is None or self.opt['datatype'] != 'train':
                # The first step of the training or validation mode
                sampled_episode_idx = self.episode_idx
                sampled_entry_idx = self.entry_idx
            else:
                # --------------- pick the sample according to the pace function -----------
                pace_by = self.opt.get('pace_by', 'sample')

                if pace_by == 'sample':
                    sum_num = self.num_episodes()
                elif pace_by == 'bucket':
                    sum_num = len(self.bucket_cnt)
                else:
                    raise ValueError('pace_by must be {} or {}!'.format(
                        'sample', 'bucket'))

                states4pace_func = observation
                if hasattr(self, 'subtask_counter'):
                    states4pace_func = {
                        'train_step':
                        self.subtask_counter[self.subtasks[task_idx]]
                    }

                threshold = self.pace_function(states4pace_func, sum_num,
                                               self.T, self.c0, self.p)
                if pace_by == 'sample':
                    stop_step = threshold
                elif pace_by == 'bucket':
                    stop_step = sum(self.bucket_cnt[:threshold])
                else:
                    raise ValueError('pace_by must be {} or {}!'.format(
                        'sample', 'bucket'))

                stop_step = self.num_episodes(
                ) if stop_step > self.num_episodes() else stop_step
                # sampled_episode_idx = random.choice(list(range(self.num_episodes()))[:stop_step])
                sampled_episode_idx = np.random.choice(stop_step)
                sampled_entry_idx = 0  # make sure the episode only contains one entry

                if self.anti:
                    sampled_episode_idx = self.num_episodes(
                    ) - 1 - sampled_episode_idx

            if self.count_sample:
                self.sample_counter[
                    self.subtasks[task_idx]][sampled_episode_idx] += 1

            ex = self.get(sampled_episode_idx,
                          sampled_entry_idx,
                          task_idx=task_idx)

            if observation is None or self.opt['datatype'] != 'train':
                self.episode_done = ex.get('episode_done', False)
                if (not self.random and self.episode_done
                        and self.episode_idx + self.opt.get("batchsize", 1) >=
                        self.num_episodes()):
                    epoch_done = True
                else:
                    epoch_done = False
            else:
                # in the curriculum learning setting, samples are not picked
                # uniformly from the training set, so the epoch count recorded here is not meaningful.
                epoch_done = False

            action = ex

        return action, epoch_done

    def get(self, episode_idx, entry_idx=0, task_idx=0):
        return self.tasks[task_idx].get(episode_idx, entry_idx)[0]

    def update_params(self):
        self._number_teacher_updates += 1
        if self.opt.get('gradient_clip_teacher', -1) > 0:
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(),
                                           self.opt['gradient_clip_teacher'])

        self.optimizer.step()

    def update_critic_params(self):
        if self.opt.get('gradient_clip_teacher', -1) > 0:
            torch.nn.utils.clip_grad_norm_(self.critic.parameters(),
                                           self.opt['gradient_clip_teacher'])

        self.optimizer_critic.step()

    def receive_metrics(self, metrics_dict):
        if self.is_combine_attr and not self.random_policy:
            assert self.reward_metric in metrics_dict, '{} is not in the metrics_dict!'.format(
                self.reward_metric)
            self.prev_prev_valid_report = self.prev_valid_report
            self.prev_valid_report = self.current_valid_report
            self.current_valid_report = metrics_dict
            delt_reward = None
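            # delta-of-deltas reward: ratio of the latest metric improvement to
            # the previous improvement, minus 1 (both deltas are sign-flipped
            # when the metric is minimized)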
            if self.prev_prev_valid_report and self.prev_valid_report and self.current_valid_report:
                delt_reward1 = self.current_valid_report[
                    self.reward_metric] - self.prev_valid_report[
                        self.reward_metric]
                delt_reward0 = self.prev_valid_report[
                    self.reward_metric] - self.prev_prev_valid_report[
                        self.reward_metric]
                if self.reward_metric_mode == 'min':
                    delt_reward1 = -delt_reward1
                    delt_reward0 = -delt_reward0
                delt_reward = delt_reward1 / (delt_reward0 + 1e-6) - 1
            if delt_reward and len(self.saved_actions) > 0 and len(
                    self.saved_state_actions) > 0:
                reward = torch.clamp(torch.FloatTensor([delt_reward]), -10, 10)
                if self.use_cuda:
                    reward = reward.cuda()

                with torch.no_grad():
                    batch_state_actions = torch.cat(list(
                        self.saved_state_actions.values()),
                                                    dim=0)
                    if self.use_cuda:
                        batch_state_actions = batch_state_actions.cuda()
                    estimate_rewards = self.critic(
                        batch_state_actions).squeeze()
                    advantages = reward - estimate_rewards

                    # rescale the rewards by ranking
                    episode_len = len(advantages)
                    ranks = torch.FloatTensor(
                        list(
                            reversed(
                                ss.rankdata(advantages.cpu(),
                                            method='dense')))).unsqueeze(dim=1)
                    rescaled_rewards = torch.sigmoid(
                        12 * (0.5 - ranks / episode_len))

                rescaled_rewards = [r.item() for r in rescaled_rewards]
                policy_loss = []
                idx = 0
                for model_train_step, log_prob in self.saved_actions.items():
                    policy_loss.append(-log_prob.unsqueeze(dim=0) *
                                       rescaled_rewards[idx])
                    idx += 1
                policy_loss = torch.cat(policy_loss).sum()

                # regularization term regarding action distribution
                bsz = batch_state_actions.size(0)
                action_probs = torch.cat(list(
                    self.saved_state_actions.values()),
                                         dim=0).narrow(1, self.state_dim,
                                                       self.action_dim)
                action_ent = torch.sum(
                    -action_probs * torch.log(action_probs)) / bsz

                self.policy.train()
                self.optimizer.zero_grad()
                policy_loss = policy_loss + self.opt.get('reg_action',
                                                         0.001) * (-action_ent)
                policy_loss.backward()
                self.update_params()

                # lr_scheduler step on teacher loss
                policy_loss_item = policy_loss.item()
                if self.opt.get('optimizer_teacher', '') == 'sgd':
                    self.scheduler.step(policy_loss_item)

                # training on the critic
                self.critic.train()
                self.optimizer_critic.zero_grad()

                batch_values = self.critic(batch_state_actions)
                critic_target = torch.FloatTensor(bsz, 1)
                critic_target = critic_target.fill_(reward.item())
                if self.use_cuda:
                    critic_target = critic_target.cuda()
                critic_loss = self.critic_criterion(batch_values,
                                                    critic_target)
                critic_loss.backward()
                self.update_critic_params()
                critic_loss_item = critic_loss.item()
                if self.opt.get('optimizer_teacher', '') == 'sgd':
                    self.scheduler_critic.step(critic_loss_item)

                # log something
                print(
                    '[ reward: {}; mean_advantage_reward: {}; policy loss: {};'
                    ' critic loss: {}; action ent: {}; episode length: {} ]'.
                    format(reward.item(), np.mean(advantages.tolist()),
                           policy_loss_item, critic_loss_item,
                           action_ent.item(), len(self.saved_actions)))

                report = {
                    'reward': reward.item(),
                    'mean_advantage_reward': np.mean(advantages.tolist()),
                    'policy_loss': policy_loss_item,
                    'critic_loss': critic_loss_item,
                    'action_ent': action_ent.item(),
                }
                self.writer.add_metrics(setting='Teacher/receive_metrics',
                                        step=self._number_teacher_updates,
                                        report=report)
                # clear history actions
                self.saved_actions.clear()
                self.saved_state_actions.clear()

    def state_dict(self):
        """
        Get the state dict for saving

        TODO: save more teacher-related states for reloading
        """
        states = {}
        if hasattr(self, 'policy'):  # save model params
            if hasattr(self.policy, 'module'):
                # did we wrap in a DistributedDataParallel
                states['policy'] = self.policy.module.state_dict()
            else:
                states['policy'] = self.policy.state_dict()

        if hasattr(self, 'critic'):  # save model params
            if hasattr(self.critic, 'module'):
                # did we wrap in a DistributedDataParallel
                states['critic'] = self.critic.module.state_dict()
            else:
                states['critic'] = self.critic.state_dict()

        if hasattr(self, 'optimizer'):  # save optimizer params
            states['optimizer'] = self.optimizer.state_dict()
            states['optimizer_type'] = self.opt['optimizer_teacher']
        if hasattr(self, 'optimizer_critic'):
            states['optimizer_critic'] = self.optimizer_critic.state_dict()

        if getattr(self, 'scheduler', None):
            states['lr_scheduler'] = self.scheduler.state_dict()
        if getattr(self, 'scheduler_critic', None):
            states['lr_scheduler_critic'] = self.scheduler_critic.state_dict()

        states['prev_prev_valid_report'] = self.prev_prev_valid_report
        states['prev_valid_report'] = self.prev_valid_report
        states['current_valid_report'] = self.current_valid_report
        states['saved_actions'] = self.saved_actions
        states['saved_state_actions'] = self.saved_state_actions

        states['_number_teacher_updates'] = self._number_teacher_updates

        return states

    def save(self, path=None):
        if path:
            teacher_path = path
        else:
            model_file = self.opt.get('model_file', None)
            if model_file:
                teacher_path = model_file + '.teacher'
            else:
                teacher_path = None

        if teacher_path:
            states = self.state_dict()
            if states:
                with open(teacher_path, 'wb') as write:
                    torch.save(states, write)
                # save opt file
                with open(teacher_path + '.opt', 'w',
                          encoding='utf-8') as handle:
                    json.dump(self.opt, handle)
                    # for convenience of working with jq, make sure there's a newline
                    handle.write('\n')

            if self.count_sample:
                # save sample count info
                for task_name, task_val in self.sample_counter.items():
                    with open(teacher_path +
                              '.sample_count.{}'.format(task_name),
                              'w',
                              encoding='utf-8') as f:
                        f.write('\n'.join([str(item) for item in task_val]))

            self.write_selections('p_selections', teacher_path)
            self.write_selections('c_selections', teacher_path)

    def write_selections(self, selections, teacher_path):
        if hasattr(self, selections):
            with open(teacher_path + '.{}'.format(selections),
                      'w',
                      encoding='utf-8') as f:
                f.write('\t'.join(self.subtasks))
                f.write('\n')
                for idx in range(
                        len(getattr(self, selections)[self.subtasks[0]])):
                    p_line = []
                    for t in self.subtasks:
                        p_line.append(str(getattr(self, selections)[t][idx]))
                    f.write('\t'.join(p_line))
                    f.write('\n')
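The reward rescaling buried in receive_metrics above is easy to miss in the long method body; the following self-contained sketch (the function name is chosen here for illustration and is not part of the original listing) isolates it: per-step advantages are dense-ranked with scipy, the rank sequence is reversed, and the ranks are squashed through a sigmoid so the resulting rewards fall in (0, 1).

import torch
import scipy.stats as ss


def rescale_by_rank(advantages):
    """Map a 1-D advantage tensor to per-step rewards in (0, 1), mirroring receive_metrics."""
    episode_len = len(advantages)
    ranks = torch.FloatTensor(
        list(reversed(ss.rankdata(advantages.cpu(), method='dense')))
    ).unsqueeze(dim=1)
    return torch.sigmoid(12 * (0.5 - ranks / episode_len))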
Example #24
class TrainLoop():
    def __init__(self, opt):
        if isinstance(opt, ParlaiParser):
            print(
                '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]'
            )
            opt = opt.parse_args()
        # Possibly load from checkpoint
        if opt['load_from_checkpoint'] and opt.get(
                'model_file') and os.path.isfile(opt['model_file'] +
                                                 '.checkpoint'):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
        # Possibly build a dictionary (not all models do this).
        if opt['dict_build_first'] and 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=True)
        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.world = create_task(opt, self.agent)
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys = 0
        self.max_num_epochs = opt[
            'num_epochs'] if opt['num_epochs'] > 0 else float('inf')
        self.max_train_time = opt[
            'max_train_time'] if opt['max_train_time'] > 0 else float('inf')
        self.log_every_n_secs = opt['log_every_n_secs'] if opt[
            'log_every_n_secs'] > 0 else float('inf')
        self.val_every_n_secs = opt['validation_every_n_secs'] if opt[
            'validation_every_n_secs'] > 0 else float('inf')
        self.save_every_n_secs = opt['save_every_n_secs'] if opt[
            'save_every_n_secs'] > 0 else float('inf')
        self.val_every_n_epochs = opt['validation_every_n_epochs'] if opt[
            'validation_every_n_epochs'] > 0 else float('inf')
        self.last_valid_epoch = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.best_valid = None
        if opt.get('model_file') and os.path.isfile(opt['model_file'] +
                                                    '.best_valid'):
            with open(opt['model_file'] + ".best_valid", 'r') as f:
                x = f.readline()
                self.best_valid = float(x)
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt
        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)

    def validate(self):
        opt = self.opt
        # run evaluation on valid set
        valid_report, self.valid_world = run_eval(self.agent,
                                                  opt,
                                                  'valid',
                                                  opt['validation_max_exs'],
                                                  valid_world=self.valid_world)

        # logging
        if opt['tensorboard_log'] is True:
            self.writer.add_metrics('valid',
                                    int(math.floor(self.train_time.time())),
                                    valid_report)
        # saving
        if opt.get('model_file') and opt.get('save_after_valid'):
            print("[ saving model checkpoint: " + opt['model_file'] +
                  ".checkpoint ]")
            self.agent.save(opt['model_file'] + '.checkpoint')

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # check which metric to look at
        if '/' in opt['validation_metric']:
            # if you are multitasking and want your validation metric to be
            # a metric specific to a subtask, specify your validation metric
            # as -vmt subtask/metric
            subtask = opt['validation_metric'].split('/')[0]
            validation_metric = opt['validation_metric'].split('/')[1]
            new_valid = valid_report['tasks'][subtask][validation_metric]
        else:
            new_valid = valid_report[opt['validation_metric']]

        # check if this is the best validation so far
        if self.best_valid is None or self.valid_optim * new_valid > self.valid_optim * self.best_valid:
            print('[ new best {}: {}{} ]'.format(
                opt['validation_metric'], new_valid,
                ' (previous best was {})'.format(self.best_valid)
                if self.best_valid is not None else ''))
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file'):
                print("[ saving best valid model: " + opt['model_file'] + " ]")
                self.agent.save(opt['model_file'])
                print("[ saving best valid metric: " + opt['model_file'] +
                      ".best_valid ]")
                save_best_valid(opt['model_file'], self.best_valid)
                self.saved = True
            if opt['validation_metric'] == 'accuracy' and self.best_valid >= opt[
                    'validation_cutoff']:
                print('[ task solved! stopping. ]')
                return True
        else:
            self.impatience += 1
            print('[ did not beat best {}: {} impatience: {} ]'.format(
                opt['validation_metric'], round(self.best_valid, 4),
                self.impatience))
        self.validate_time.reset()

        # check if we are out of patience
        if opt['validation_patience'] > 0 and self.impatience >= opt[
                'validation_patience']:
            print('[ ran out of patience! stopping training. ]')
            return True
        return False

    def log(self):
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []
        # get report
        train_report = self.world.report(compute_time=True)
        self.world.reset_metrics()

        # time elapsed
        logs.append('time:{}s'.format(math.floor(self.train_time.time())))
        total_exs = self.world.get_total_exs()
        logs.append('total_exs:{}'.format(total_exs))

        exs_per_ep = self.world.num_examples()
        if exs_per_ep:
            logs.append('epochs:{}'.format(round(total_exs / exs_per_ep, 2)))

        if 'time_left' in train_report:
            logs.append('time_left:{}s'.format(
                math.floor(train_report.pop('time_left', ""))))

        log = '[ {} ] {}'.format(' '.join(logs), train_report)
        print(log)
        self.log_time.reset()

        if opt['tensorboard_log'] is True:
            self.writer.add_metrics('train', int(logs[1].split(":")[1]),
                                    train_report)

    def train(self):
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                world.parley()
                self.parleys += 1

                # check counters and timers
                if world.get_total_epochs() >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, self.train_time.time()))
                    break
                if self.train_time.time() > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(
                        self.train_time.time()))
                    break
                if self.log_time.time() > self.log_every_n_secs:
                    self.log()
                if self.validate_time.time() > self.val_every_n_secs:
                    stop_training = self.validate()
                    if stop_training:
                        break
                if (world.get_total_epochs() - self.last_valid_epoch >=
                        self.val_every_n_epochs):
                    stop_training = self.validate()
                    self.last_valid_epoch = world.get_total_epochs()
                    if stop_training:
                        break
                if self.save_time.time() > self.save_every_n_secs and opt.get(
                        'model_file'):
                    print("[ saving model checkpoint: " + opt['model_file'] +
                          ".checkpoint ]")
                    self.agent.save(opt['model_file'] + '.checkpoint')
                    self.save_time.reset()

        if not self.saved:
            # save agent
            self.agent.save(opt['model_file'])
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        v_report, v_world = run_eval(self.agent, opt, 'valid', write_log=True)
        t_report, t_world = run_eval(self.agent, opt, 'test', write_log=True)
        v_world.shutdown()
        t_world.shutdown()
        return v_report, t_report
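A brief usage sketch (not part of the original listing; the setup_args entry point is assumed to be one of the parser-building examples shown earlier):

if __name__ == '__main__':
    opt = setup_args().parse_args()
    valid_report, test_report = TrainLoop(opt).train()
    print(valid_report)
    print(test_report)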
Example #25
class TrainLoop:
    """
    TrainLoop contains the core training loop logic.
    """
    def __init__(self, opt):
        # if python is called from a non-interactive shell, like a bash script,
        # it will by default ignore SIGINTs, and KeyboardInterrupt exceptions are
        # not produced. This line brings them back.
        signal.signal(signal.SIGINT, signal.default_int_handler)

        if isinstance(opt, ParlaiParser):
            print(
                '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]'
            )
            opt = opt.parse_args()
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if (opt['load_from_checkpoint'] and opt.get('model_file')
                and os.path.isfile(opt['model_file'] + '.checkpoint')):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # Possibly build a dictionary (not all models do this).
        if not (opt.get('dict_file') or opt.get('model_file')):
            raise RuntimeError(
                'WARNING: For train_model, please specify either a '
                'model_file or dict_file.')
        if 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=True)
        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.world = create_task(opt, self.agent)
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys = 0
        self.max_num_epochs = (opt['num_epochs']
                               if opt['num_epochs'] > 0 else float('inf'))
        self.max_train_time = (opt['max_train_time']
                               if opt['max_train_time'] > 0 else float('inf'))
        self.log_every_n_secs = (opt['log_every_n_secs'] if
                                 opt['log_every_n_secs'] > 0 else float('inf'))
        self.val_every_n_secs = (opt['validation_every_n_secs']
                                 if opt['validation_every_n_secs'] > 0 else
                                 float('inf'))
        self.save_every_n_secs = (opt['save_every_n_secs']
                                  if opt['save_every_n_secs'] > 0 else
                                  float('inf'))
        self.val_every_n_epochs = (opt['validation_every_n_epochs']
                                   if opt['validation_every_n_epochs'] > 0 else
                                   float('inf'))

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {
                'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'
        }:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.valid_reports = []
        self.best_valid = None

        self.impatience = 0
        self.saved = False
        self.valid_worlds = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if opt.get('model_file') and os.path.isfile(opt['model_file'] +
                                                    trainstats_suffix):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self.parleys = obj.get('parleys', 0)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self.impatience = obj.get('impatience', 0)
                self.valid_reports = obj.get('valid_reports', [])
                if 'best_valid' in obj:
                    self.best_valid = obj['best_valid']
                else:
                    # old method
                    if opt.get('model_file') and os.path.isfile(
                            opt['model_file'] + '.best_valid'):
                        with open(opt['model_file'] + ".best_valid", 'r') as f:
                            x = f.readline()
                            self.best_valid = float(x)

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger = TensorboardLogger(opt)

    def save_model(self, suffix=None):
        """
        Save the model to disk, possibly with a suffix.
        """
        if not is_primary_worker():
            # never do IO as a non-primary worker
            return
        if not self.opt.get('model_file'):
            # nothing to save to, just exit
            return

        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        while True:
            # don't ever let a ctrl-c interrupt saving
            try:
                self.agent.save(fn)
                self._save_train_stats(suffix)
                break
            except KeyboardInterrupt:
                pass

    def _safe_report(self, report):
        return {
            k: v.value() if isinstance(v, Metric) else v
            for k, v in report.items()
        }

    def _save_train_stats(self, suffix=None):
        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        fn += '.trainstats'
        with open(fn, 'w') as f:
            json.dump(
                {
                    'parleys':
                    self.parleys,
                    'train_time':
                    self.train_time.time(),
                    'total_epochs':
                    (self._preempted_epochs +
                     num_workers() * self.world.get_total_epochs()),
                    'impatience':
                    self.impatience,
                    'valid_reports':
                    [self._safe_report(v) for v in self.valid_reports],
                    'best_valid':
                    self.best_valid,
                },
                f,
            )

    def validate(self):
        """
        Perform a validation run, checking whether we should stop training.

        :return: boolean indicating whether training should stop
        :rtype: bool
        """
        opt = self.opt

        if self.valid_worlds is None:
            # we need to load the world now
            self.valid_worlds = load_eval_worlds(self.agent, opt, 'valid')

        # run evaluation on valid set
        # TODO(MW): replace sync_object with self._sync_metrics. You'll need some
        # logic to handle 'validation_max_exs' properly
        valid_report = run_eval(self.valid_worlds, opt, 'valid',
                                opt['validation_max_exs'])
        v = valid_report.copy()
        v['train_time'] = self.train_time.time()
        self.valid_reports.append(v)
        # logging
        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger.log_metrics('valid', self.parleys, valid_report)
            # flush on a validation
            self.tb_logger.flush()
        # saving
        if (opt.get('model_file') and opt.get('save_after_valid')
                and is_primary_worker()):
            print("[ saving model checkpoint: " + opt['model_file'] +
                  ".checkpoint ]")
            self.save_model('.checkpoint')

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # check which metric to look at
        new_valid = valid_report[opt['validation_metric']]

        if isinstance(new_valid, Metric):
            new_valid = new_valid.value()

        # check if this is the best validation so far
        if (self.best_valid is None or self.valid_optim * new_valid >
                self.valid_optim * self.best_valid):
            print('[ new best {}: {}{} ]'.format(
                opt['validation_metric'],
                new_valid,
                ' (previous best was {})'.format(self.best_valid)
                if self.best_valid is not None else '',
            ))
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file') and is_primary_worker():
                print("[ saving best valid model: " + opt['model_file'] + " ]")
                self.save_model()
                self.saved = True
            if (opt['validation_metric'] == 'accuracy'
                    and self.best_valid >= opt['validation_cutoff']):
                print('[ task solved! stopping. ]')
                return True
        else:
            self.impatience += 1
            print('[ did not beat best {}: {} impatience: {} ]'.format(
                opt['validation_metric'], round(self.best_valid, 4),
                self.impatience))
        self.validate_time.reset()

        # check if we are out of patience
        if (opt['validation_patience'] > 0
                and self.impatience >= opt['validation_patience']):
            print('[ ran out of patience! stopping training. ]')
            return True
        return False

    def _sync_metrics(self, metrics):
        """
        Sync training metrics across workers.

        A handful of special cases are handled as exceptions, and the remaining metrics
        are simply averaged across workers.
        """
        if not is_distributed():
            # nothing special needed
            return metrics
        all_versions = all_gather_list(metrics)
        return aggregate_unnamed_reports(all_versions)

    def _compute_eta(self, epochs_completed, time_elapsed):
        """
        Compute the estimated seconds remaining in training.

        :param float epochs_completed: number of epochs already completed.
        :param float time_elapsed: total time spent already, in seconds.
        :return: ETA in seconds, or None if not computable
        """
        # start off with no estimate
        eta = None

        # Determine time_left and num_epochs
        max_epochs = self.opt.get('num_epochs', 0)
        if max_epochs > 0 and epochs_completed > 0:
            epoch_progress = epochs_completed / max_epochs
            eta = (1 - epoch_progress) * time_elapsed / epoch_progress

        max_training_time = self.opt.get('max_train_time', -1)
        if max_training_time > 0:
            time_left = max_training_time - time_elapsed
            if eta is None or time_left < eta:
                eta = time_left

        return eta

    def log(self):
        """
        Output a training log entry.
        """
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []
        # get report
        train_report = self.world.report()
        train_report = self._sync_metrics(train_report)
        self.world.reset_metrics()

        # time elapsed
        logs.append('time:{}s'.format(np.floor(self.train_time.time())))
        logs.append('total_exs:{}'.format(self._total_exs))

        if self._total_epochs >= 0:
            # log epoch progress when it could be computed
            logs.append('epochs:{}'.format(round(self._total_epochs, 2)))

        time_left = self._compute_eta(self._total_epochs,
                                      self.train_time.time())
        if time_left is not None:
            logs.append('time_left:{}s'.format(max(0, np.ceil(time_left))))

        log = '[ {} ] {}'.format(' '.join(logs), nice_report(train_report))
        print(log)
        self.log_time.reset()

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger.log_metrics('train', self.parleys, train_report)

    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException:
                    if is_distributed():
                        raise RuntimeError(
                            "StopTrainException not supported for "
                            "distributed mode")
                    break

                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = (
                    self._preempted_epochs +
                    num_workers() * self.world.get_total_epochs())
                exs_per_epoch = self.world.num_examples()
                self._total_exs = int(
                    np.round(self._total_epochs * exs_per_epoch))

                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object((
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                ))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, train_time))
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (validate_time > self.val_every_n_secs
                        or self._total_epochs - self.last_valid_epoch >=
                        self.val_every_n_epochs):
                    try:
                        stop_training = self.validate()
                    except StopTrainException:
                        if is_distributed():
                            raise RuntimeError(
                                "StopTrainException not "
                                "supported for distributed mode")
                        break
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break
                if (self.save_time.time() > self.save_every_n_secs
                        and opt.get('model_file') and is_primary_worker()):
                    print("[ saving model checkpoint: {}.checkpoint".format(
                        opt['model_file']))
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = (opt['validation_max_exs']
                   if opt.get('short_final_eval') else -1)
        v_report = run_eval(valid_worlds,
                            opt,
                            'valid',
                            max_exs,
                            write_log=True)
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        t_report = run_eval(test_worlds, opt, 'test', max_exs, write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        return v_report, t_report
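The ETA logic in _compute_eta above is a plain linear extrapolation; the standalone sketch below (illustrative names, not the original helper) shows the arithmetic: with a fraction p of the epochs finished after t seconds, roughly (1 - p) * t / p seconds remain, optionally capped by a wall-clock budget.

def estimate_eta(epochs_completed, time_elapsed, max_epochs, max_train_time=-1):
    """Return the estimated seconds remaining, or None if it cannot be computed."""
    eta = None
    if max_epochs > 0 and epochs_completed > 0:
        progress = epochs_completed / max_epochs
        eta = (1 - progress) * time_elapsed / progress
    if max_train_time > 0:
        time_left = max_train_time - time_elapsed
        if eta is None or time_left < eta:
            eta = time_left
    return eta

# e.g. 2 of 10 epochs finished in 600s -> roughly 2400s left
# print(estimate_eta(2, 600, 10))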
Example #26
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    # Get command line arguments
    parser.add_argument(
        '-rf',
        '--report-filename',
        type=str,
        default='',
        help='Saves a json file of the evaluation report either as an '
        'extension to the model-file (if it begins with a ".") or a whole '
        'file path. Set to the empty string to not save at all.',
    )
    parser.add_argument(
        '--world-logs',
        type=str,
        default='',
        help='Saves a jsonl file of the world logs. '
        'Set to the empty string to not save at all.',
    )
    parser.add_argument(
        '--save-format',
        type=str,
        default='conversations',
        choices=['conversations', 'parlai'],
    )
    parser.add_argument(
        '--area-under-curve-digits',
        '-auc',
        type=int,
        default=-1,
        help='a positive number enables computing the area under the ROC curve '
        'and determines how many decimal digits of the predictions to keep '
        '(higher numbers -> more precise)',
    )
    parser.add_argument(
        '--area-under-curve-class',
        '-auclass',
        type=str,
        default=None,
        nargs='*',
        help='the name(s) of the class to calculate the auc for',
    )
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=10)
    parser.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='list of metrics to show/compute, e.g. all, default, '
        'or give a list split by , like '
        'ppl,f1,accuracy,hits@1,rouge,bleu. '
        'The rouge metrics will be computed as rouge-1, rouge-2 and rouge-l',
    )
    parser.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='Report micro-averaged metrics instead of macro averaged metrics.',
        recommended=False,
    )
    WorldLogger.add_cmdline_args(parser, partial_opt=None)
    TensorboardLogger.add_cmdline_args(parser, partial_opt=None)
    parser.set_params(datatype='valid')
    return parser
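A minimal sketch of how a parser built by the setup_args above is typically consumed (the downstream eval entry point is not shown in this listing and is only assumed here):

parser = setup_args()
opt = parser.parse_args()  # e.g. invoked with --task ... --model-file ... -rf report.json
# opt is then handed to the evaluation routine, e.g. eval_model(opt)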
Example #27
    def __init__(self, opt):
        # if python is called from a non-interactive shell, like a bash script,
        # it will by default ignore SIGINTs, and KeyboardInterrupt exceptions are
        # not produced. This line brings them back.
        signal.signal(signal.SIGINT, signal.default_int_handler)
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if (
            opt['load_from_checkpoint']
            and opt.get('model_file')
            and PathManager.exists(opt['model_file'] + '.checkpoint')
        ):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # Possibly build a dictionary (not all models do this).
        if not (opt.get('dict_file') or opt.get('model_file')):
            raise RuntimeError(
                'WARNING: For train_model, please specify either a '
                'model_file or dict_file.'
            )
        if 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            logging.info("building dictionary first...")
            build_dict(opt, skip_if_built=True)

        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.agent.opt.log()
        self.world = create_task(opt, self.agent)
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()

        self.parleys = 0
        self._train_steps = 0
        self._last_log_steps = 0
        self.update_freq = opt.get('update_freq', 1)

        self.max_num_epochs = _num_else_inf(opt, 'num_epochs', distributed_warn=True)
        self.max_train_time = _num_else_inf(
            opt, 'max_train_time', distributed_warn=True
        )
        self.max_train_steps = _num_else_inf(opt, 'max_train_steps')
        self.log_every_n_secs = _num_else_inf(
            opt, 'log_every_n_secs', distributed_warn=True
        )
        self.log_every_n_steps = _num_else_inf(opt, 'log_every_n_steps')
        self.val_every_n_secs = _num_else_inf(
            opt, 'validation_every_n_secs', distributed_warn=True
        )
        self.val_every_n_epochs = _num_else_inf(
            opt, 'validation_every_n_epochs', distributed_warn=True
        )
        self.val_every_n_steps = _num_else_inf(opt, 'validation_every_n_steps')
        self.save_every_n_secs = _num_else_inf(
            opt, 'save_every_n_secs', distributed_warn=True
        )

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'}:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self._last_valid_steps = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.train_reports = []
        self.valid_reports = []
        self.final_valid_report = {}
        self.final_test_report = {}
        self.final_extra_valid_report = {}
        self.best_valid = None

        self.impatience = 0
        self.saved = False
        self.valid_worlds = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if opt.get('model_file') and PathManager.exists(
            opt['model_file'] + trainstats_suffix
        ):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with PathManager.open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self.parleys = obj.get('parleys', 0)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self._train_steps = obj.get('train_steps', 0)
                self.impatience = obj.get('impatience', 0)
                self.valid_reports = obj.get('valid_reports', [])
                if self.valid_reports:
                    self.last_valid_epoch = self.valid_reports[-1].get(
                        'total_epochs', 0.0
                    )
                self.train_reports = obj.get('train_reports', [])
                if 'best_valid' in obj:
                    self.best_valid = obj['best_valid']
                else:
                    # old method
                    if opt.get('model_file') and PathManager.exists(
                        opt['model_file'] + '.best_valid'
                    ):
                        with PathManager.open(
                            opt['model_file'] + ".best_valid", 'r'
                        ) as f:
                            x = f.readline()
                            self.best_valid = float(x)

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger = TensorboardLogger(opt)
        if opt['wandb_log'] and is_primary_worker():
            model = self.agent.model if hasattr(self.agent, 'model') else None
            self.wb_logger = WandbLogger(opt, model)
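The __init__ above relies on a helper _num_else_inf that is not included in this listing; a plausible sketch is shown below (an assumption for readability, not the original implementation):

def _num_else_inf(opt, key, distributed_warn=False):
    # treat non-positive settings as "no limit"
    val = opt[key]
    return val if val > 0 else float('inf')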
Example #28
class TrainLoop():
    def __init__(self, opt):
        if isinstance(opt, ParlaiParser):
            opt = opt.parse_args()
        # Possibly build a dictionary (not all models do this).
        if opt['dict_build_first'] and 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get(
                    'model_file_transmitter') and opt.get(
                        'model_file_receiver'):
                opt['dict_file'] = opt['model_file_transmitter'] + '_' + opt[
                    'model_file_receiver'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=False)

        # Create model and assign it to the specified task
        print("[ create meta-agent ... ]")
        self.agent = create_agent(opt)
        print("[ create agent A ... ]")
        shared = self.agent.share()
        self.agent_a = create_agent_from_shared(shared)
        self.agent_a.set_id(suffix=' A')
        print("[ create agent B ... ]")
        self.agent_b = create_agent_from_shared(shared)
        # self.agent_b = create_agent(opt)
        self.agent_b.set_id(suffix=' B')
        # self.agent_a.copy(self.agent, 'transmitter')
        # self.agent_b.copy(self.agent, 'transmitter')
        self.world = create_selfplay_world(opt, [self.agent_a, self.agent_b])

        # TODO: if batch, it is also not parallel
        # self.world = BatchSelfPlayWorld(opt, self_play_world)

        self.train_time = Timer()
        self.train_dis_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys_episode = 0
        self.max_num_epochs = opt[
            'num_epochs'] if opt['num_epochs'] > 0 else float('inf')
        self.max_train_time = opt[
            'max_train_time'] if opt['max_train_time'] > 0 else float('inf')
        self.log_every_n_secs = opt['log_every_n_secs'] if opt[
            'log_every_n_secs'] > 0 else float('inf')
        self.train_dis_every_n_secs = opt['train_display_every_n_secs'] if opt[
            'train_display_every_n_secs'] > 0 else float('inf')
        self.val_every_n_secs = opt['validation_every_n_secs'] if opt[
            'validation_every_n_secs'] > 0 else float('inf')
        self.save_every_n_secs = opt['save_every_n_secs'] if opt[
            'save_every_n_secs'] > 0 else float('inf')
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.best_valid = None
        if opt.get('model_file_transmitter') and os.path.isfile(
                opt['model_file_transmitter'] + '.best_valid'):
            with open(opt['model_file_transmitter'] + ".best_valid", 'r') as f:
                x = f.readline()
                self.best_valid = float(x)
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt
        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)

    def validate(self):
        opt = self.opt
        valid_report, self.valid_world = run_eval(self.agent,
                                                  opt,
                                                  'valid',
                                                  opt['validation_max_exs'],
                                                  valid_world=self.valid_world)
        if opt['tensorboard_log'] is True:
            self.writer.add_metrics('valid', self.parleys_episode,
                                    valid_report)
        if opt.get('model_file_transmitter') and opt.get('save_after_valid'):
            print("[ saving transmitter checkpoint: " +
                  opt['model_file_transmitter'] + ".checkpoint ]")
            self.agent.save(component='transmitter')
        # if opt.get('model_file_receiver') and opt.get('save_after_valid'):
        #     print("[ saving receiver checkpoint: " + opt['model_file_receiver'] + ".checkpoint ]")
        #     self.agent.save(component='receiver')
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)
        if '/' in opt['validation_metric']:
            # if you are multitasking and want your validation metric to be
            # a metric specific to a subtask, specify your validation metric
            # as -vmt subtask/metric
            subtask = opt['validation_metric'].split('/')[0]
            validation_metric = opt['validation_metric'].split('/')[1]
            new_valid = valid_report['tasks'][subtask][validation_metric]
        else:
            new_valid = valid_report[opt['validation_metric']]
        if self.best_valid is None or self.valid_optim * new_valid > self.valid_optim * self.best_valid:
            print('[ new best {}: {}{} ]'.format(
                opt['validation_metric'], new_valid,
                ' (previous best was {})'.format(self.best_valid)
                if self.best_valid is not None else ''))
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file'):
                print("[ saving best valid model: " + opt['model_file'] + " ]")
                # the fine-tuned transmitter part is actually what we want for PSquare bot
                self.agent.save()
                print("[ saving best valid metric: " + opt['model_file'] +
                      ".best_valid ]")
                save_best_valid(opt['model_file'], self.best_valid)
                self.saved = True

            if opt['validation_metric'] == 'accuracy' and self.best_valid >= opt[
                    'validation_cutoff']:
                print('[ task solved! stopping. ]')
                return True
        else:
            self.impatience += 1
            print('[ did not beat best {}: {} impatience: {} ]'.format(
                opt['validation_metric'], round(self.best_valid, 4),
                self.impatience))
        self.validate_time.reset()
        if 0 < opt['validation_patience'] <= self.impatience:
            print('[ ran out of patience! stopping training. ]')
            return True
        return False

    def log(self):
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []
        # get report
        train_report = self.world.report()
        self.world.reset_metrics()

        # time elapsed
        logs.append('time:{}s'.format(math.floor(self.train_time.time())))
        logs.append('parleys:{}'.format(self.parleys_episode))

        if 'time_left' in train_report:
            logs.append('time_left:{}s'.format(
                math.floor(train_report.pop('time_left', ""))))
        if 'num_epochs' in train_report:
            logs.append('num_epochs:{}'.format(
                train_report.pop('num_epochs', '')))
        log = '[ {} ] {}'.format(' '.join(logs), train_report)
        print(log)
        self.log_time.reset()

        if opt['tensorboard_log'] is True:
            self.writer.add_metrics('train', self.parleys_episode,
                                    train_report)

    def train(self):
        # print('#### Validating at {} training episode '.format(self.parleys_episode))
        # self.validate()
        opt = self.opt
        world = self.world
        with world:
            while True:
                self.parleys_episode += 1
                if self.parleys_episode % 100 == 0:
                    print('#### Training {} episode '.format(
                        self.parleys_episode))

                if self.train_dis_time.time() > self.train_dis_every_n_secs:
                    is_display = True
                    # clear to zero
                    self.train_dis_time.reset()
                else:
                    is_display = False

                world.parley_episode(is_training=True, is_display=is_display)

                if world.get_total_epochs() >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, self.train_time.time()))
                    break

                if self.train_time.time() > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(
                        self.train_time.time()))
                    break

                if self.log_time.time() > self.log_every_n_secs:
                    self.log()

                if self.validate_time.time() > self.val_every_n_secs:
                    print('#### Validating at {} training episode '.format(
                        self.parleys_episode))
                    stop_training = self.validate()
                    if stop_training:
                        break

                if self.save_time.time() > self.save_every_n_secs:
                    if opt.get('model_file_transmitter'):
                        print("[ saving transmitter checkpoint: " +
                              opt['model_file_transmitter'] + ".checkpoint ]")
                        self.agent.save(opt['model_file_transmitter'] +
                                        '.checkpoint',
                                        component='transmitter')
                    if opt.get('model_file_receiver'):
                        print("[ saving receiver checkpoint: " +
                              opt['model_file_receiver'] + ".checkpoint ]")
                        self.agent.save(opt['model_file_receiver'] +
                                        '.checkpoint',
                                        component='receiver')
                    self.save_time.reset()

        if not self.saved:
            # save agent
            # self.agent.save(component='transmitter')
            self.agent.save()
            # self.agent.save(component='receiver') # TODO: API for save all components
        elif opt.get('model_file_transmitter') and opt.get(
                'model_file_receiver'
        ):  # TODO: check if both components are necessary
            # reload best validation model
            self.agent = create_agent(opt)

        v_report, v_world = run_eval(self.agent, opt, 'valid', write_log=True)
        t_report, t_world = run_eval(self.agent, opt, 'test', write_log=True)
        v_world.shutdown()
        t_world.shutdown()
        return v_report, t_report
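All of these train loops share the same timer-driven scheduling pattern: each periodic action owns a Timer, the loop compares the elapsed time against a threshold, and the timer is reset once the action fires. A stripped-down, self-contained illustration (hypothetical names, independent of ParlAI) follows.

import time


class SimpleTimer:
    def __init__(self):
        self.start = time.time()

    def time(self):
        return time.time() - self.start

    def reset(self):
        self.start = time.time()


log_timer, save_timer = SimpleTimer(), SimpleTimer()
log_every_n_secs, save_every_n_secs = 5, 60

for _ in range(1000):
    time.sleep(0.01)  # stand-in for world.parley()
    if log_timer.time() > log_every_n_secs:
        print('[ logging ]')
        log_timer.reset()
    if save_timer.time() > save_every_n_secs:
        print('[ saving checkpoint ]')
        save_timer.reset()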
Example #29
def setup_args(parser=None) -> ParlaiParser:
    """
    Build the ParlAI parser, adding command line args if necessary.

    :param ParlaiParser parser:
        Preexisting parser to append options to. Will be created if needed.

    :returns:
        the ParlaiParser with CLI options added.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Train a model')
    train = parser.add_argument_group('Training Loop Arguments')
    train.add_argument(
        '-et',
        '--evaltask',
        help='task to use for valid/test (defaults to the one used for training)',
    )
    train.add_argument(
        '--final-extra-opt',
        type=str,
        default='',
        help="A '.opt' file that is used for final eval. Useful for setting skip-generation to false. 'datatype' must be included as part of the opt.",
    )
    train.add_argument(
        '--eval-batchsize',
        type=int,
        hidden=True,
        help='Eval time batch size (defaults to same as -bs)',
    )
    train.add_argument(
        '--eval-dynamic-batching',  # FIXME: see https://github.com/facebookresearch/ParlAI/issues/3367
        default=None,
        type='nonestr',
        choices={None, 'off', 'full', 'batchsort'},
        help=(
            'Set dynamic batching at evaluation time. Set to off for '
            'train-only dynamic batching. Set to none (default) to use same '
            'setting as --dynamic-batching.'
        ),
    )
    train.add_argument(
        '--num-workers',
        default=0,
        type=int,
        help='Number of background workers (training only)',
    )
    train.add_argument('--display-examples', type='bool', default=False, hidden=True)
    train.add_argument('-eps', '--num-epochs', type=float, default=-1)
    train.add_argument('-ttim', '--max-train-time', type=float, default=-1)
    train.add_argument(
        '-tstep',
        '--max-train-steps',
        '--max-lr-steps',
        type=int,
        default=-1,
        help='End training after n model updates',
    )
    train.add_argument('-ltim', '--log-every-n-secs', type=float, default=-1)
    train.add_argument(
        '-lstep',
        '--log-every-n-steps',
        type=int,
        default=50,
        help='Log every n training steps',
    )
    train.add_argument(
        '-vtim',
        '--validation-every-n-secs',
        type=float,
        default=-1,
        help='Validate every n seconds. Saves model to model_file '
        '(if set) whenever best val metric is found',
    )
    train.add_argument(
        '-vstep',
        '--validation-every-n-steps',
        type=int,
        default=-1,
        help='Validate every n training steps. Saves model to model_file '
        '(if set) whenever best val metric is found',
    )
    train.add_argument(
        '-stim',
        '--save-every-n-secs',
        type=float,
        default=-1,
        help='Saves the model to model_file.checkpoint after '
        'every n seconds (default -1, never).',
    )
    train.add_argument(
        '-sval',
        '--save-after-valid',
        type='bool',
        default=False,
        help='Saves the model to model_file.checkpoint after '
        'every validation (default %(default)s).',
    )
    train.add_argument(
        '-veps',
        '--validation-every-n-epochs',
        type=float,
        default=-1,
        help='Validate every n epochs. Saves model to model_file '
        '(if set) whenever best val metric is found',
    )
    train.add_argument(
        '-vme',
        '--validation-max-exs',
        type=int,
        default=-1,
        hidden=True,
        help='max examples to use during validation (default -1 uses all)',
    )
    train.add_argument(
        '--short-final-eval',
        default=False,
        hidden=True,
        type='bool',
        help='If true, obeys --validation-max-exs in the final '
        'validation and test evaluations.',
    )
    train.add_argument(
        '-vp',
        '--validation-patience',
        type=int,
        default=10,
        help=(
            'number of iterations of validation where result'
            ' does not improve before we stop training'
        ),
    )
    train.add_argument(
        '-vmt',
        '--validation-metric',
        default='accuracy',
        help='key into report table for selecting best validation',
    )
    train.add_argument(
        '-vmm',
        '--validation-metric-mode',
        type=str,
        choices=['max', 'min'],
        help='the direction in which to optimize the validation metric, i.e. maximize or minimize',
    )
    train.add_argument(
        '-vcut',
        '--validation-cutoff',
        type=float,
        default=1.0,
        hidden=True,
        help='value at which training will stop if exceeded by metric',
    )
    train.add_argument(
        '-lfc',
        '--load-from-checkpoint',
        type='bool',
        default=True,
        hidden=True,
        help='load model from checkpoint if available',
    )
    train.add_argument(
        '-vshare',
        '--validation-share-agent',
        default=False,
        hidden=True,
        help='use a shared copy of the agent for validation. '
        'this will eventually default to True, but '
        'currently defaults to False.',
    )
    train.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='list of metrics to show/compute, e.g. all, default, '
        'or give a comma-separated list like '
        'ppl,f1,accuracy,hits@1,rouge,bleu. '
        'The rouge metrics will be computed as rouge-1, rouge-2 and rouge-l',
    )
    train.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='Report micro-averaged metrics instead of macro averaged metrics.',
        recommended=False,
    )
    train.add_argument(
        '--world-logs',
        type=str,
        default='',
        help='Saves a jsonl file of the world logs. '
        'Set to the empty string to not save at all.',
    )
    train.add_argument(
        '--save-format',
        type=str,
        default='conversations',
        choices=['conversations', 'parlai'],
    )
    WorldLogger.add_cmdline_args(parser, partial_opt=None)
    TensorboardLogger.add_cmdline_args(parser, partial_opt=None)
    WandbLogger.add_cmdline_args(parser, partial_opt=None)

    parser = setup_dict_args(parser)
    return parser
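
This example only builds the parser, so as a minimal usage sketch (assuming ParlAI's standard TrainLoop from parlai.scripts.train_model; the task and model-file values below are placeholders) the options could be parsed and handed to the training loop like this:

# Hypothetical driver for the setup_args above; task/model values are placeholders.
from parlai.scripts.train_model import TrainLoop

parser = setup_args()
opt = parser.parse_args([
    '--task', 'babi:task10k:1',
    '--model', 'seq2seq',
    '--model-file', '/tmp/babi_model',
    '--validation-every-n-secs', '60',
    '--validation-metric', 'accuracy',
])
TrainLoop(opt).train()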
class EntNetAgent(TorchAgent):
    @staticmethod
    def add_cmdline_args(argparser):
        TorchAgent.add_cmdline_args(argparser)
        agent = argparser.add_argument_group("EntNet Arguments")

        agent.add_argument("-wt",
                           "--weight-tying",
                           type=str,
                           default="layer-wise",
                           help="Type of weight tying")
        agent.add_argument("-nmh",
                           "--num-memory-hops",
                           type=int,
                           default=3,
                           help="Number of memory hops")

        EntNetAgent.dictionary_class().add_cmdline_args(argparser)

        return agent

    def __init__(self, opt, shared=None):

        super().__init__(opt, shared)

        if opt.get('tensorboard_log'):
            self.writer = TensorboardLogger(opt)
        else:
            self.writer = None

        self.dictionary_size = 177
        self.embedding_dim = 100
        self.batch_size = opt["batchsize"]
        self.losses = []

        self.criterion = nn.CrossEntropyLoss()

        def weight_init(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)

        self.recurrent_entity_network = RecurrentEntityNetwork(
            self.dictionary_size, self.embedding_dim, sequence_length=7)
        self.recurrent_entity_network.apply(weight_init)
        self.optimizer = optim.Adam(self.recurrent_entity_network.parameters())
        #self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 25, 0.5)
        self.batch_iter = 0

    def vectorize(self, *args, **kwargs):
        """Override options in vectorize from parent."""
        kwargs['add_start'] = False
        kwargs['add_end'] = False
        kwargs['split_lines'] = True
        return super().vectorize(*args, **kwargs)

    def train_step(self, batch):

        #self.scheduler.step()

        self.recurrent_entity_network.train()

        questions, answers = batch.text_vec, batch.label_vec
        contexts = padded_3d(batch.memory_vecs)

        self.optimizer.zero_grad()

        output = self.recurrent_entity_network(questions, contexts)
        pred = output.argmax(dim=1)

        loss = self.criterion(output, answers.squeeze(1))
        self.losses.append(loss.item())

        if self.writer is not None:
            self.writer.add_scalar("data/loss", loss, self.batch_iter)

            for name, param in self.recurrent_entity_network.named_parameters():
                self.writer.add_histogram(name,
                                          param.clone().cpu().data.numpy(),
                                          self.batch_iter)
                # self.writer.add_histogram(name + "_grad", param.grad.clone().cpu().data.numpy(), self.batch_iter)


        # for memory_hop_layer in self.stacked_memory_hop.memory_hop_layers:
        #     for name_in, param_in in memory_hop_layer.named_parameters():
        #         self.writer.add_histogram(name_in, param_in.clone().cpu().data.numpy(), self.batch_iter)
        #         # self.writer.add_histogram(name_in + "_grad", param_in.grad.clone().cpu().data.numpy(), self.batch_iter)

        # print("Loss : ", loss.item())
        # self.writer.add_histogram("predictions", output.clone().cpu().data.numpy(), self.batch_iter)
        loss.backward(retain_graph=True)
        self.optimizer.step()

        self.batch_iter += 1

        return Output(self.dict.vec2txt(pred).split(" "))

    def eval_step(self, batch):
        questions = batch.text_vec
        contexts = padded_3d(batch.memory_vecs)

        # If the last batch is smaller than the configured batch size, fall back
        # to random token predictions instead of resizing the network inputs.
        if contexts.shape[0] != self.batch_size:
            return Output(
                self.dict.vec2txt(
                    np.random.choice(self.dictionary_size,
                                     size=contexts.shape[0])).split(" "))

        output = self.recurrent_entity_network(questions, contexts)
        pred = output.argmax(dim=1)

        return Output(self.dict.vec2txt(pred).split(" "))
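
To exercise an agent like this end to end, one hedged option (the agent module path below is a placeholder and would need to match wherever EntNetAgent is actually saved) is to point ParlAI's standard training script at it and enable Tensorboard logging, since train_step above logs loss and parameter histograms through it:

# Hypothetical launch sketch; the --model path is a placeholder.
from parlai.scripts.train_model import TrainLoop, setup_args

parser = setup_args()
opt = parser.parse_args([
    '--model', 'parlai_internal.agents.entnet.entnet:EntNetAgent',  # placeholder path
    '--task', 'babi:task10k:1',
    '--batchsize', '32',
    '--tensorboard-log', 'true',
    '--model-file', '/tmp/entnet_babi',
])
TrainLoop(opt).train()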