Ejemplo n.º 1
0
    def predict(self, test_task_iters: BatchPreparationPipeline, model: Model,
                metric_reporter: MetaLearnMetricReporter):
        """Yield per-task prediction records for every meta-batch.

        For each task the model is first shown the support set in train
        mode (so it memorizes the support responses), then queried on the
        target inputs in eval mode under ``torch.no_grad``.
        """
        for support, target, context in test_task_iters:
            per_task = zip(support, target, context)
            for (s_inputs, t_inputs), (s_targets, t_targets), \
                    (s_context, t_context) in per_task:
                task = t_context['task_id'][0]

                # Adaptation phase: memorize the support responses.
                model.train()
                model.contextualize(s_context)
                model(*s_inputs, responses=s_targets)

                # Prediction phase: no gradients needed.
                model.eval()
                with torch.no_grad():
                    # resps: predicted responses as embedded tensors;
                    # resp_lens: their corresponding lengths.
                    resps, resp_lens = model(*t_inputs)
                    yield {
                        'task': task,
                        'resps': resps,
                        'resp_lens': resp_lens,
                        's_inputs': s_inputs,
                        's_targets': s_targets,
                        's_context': s_context,
                        't_inputs': t_inputs,
                        't_targets': t_targets,
                        't_context': t_context,
                    }
Ejemplo n.º 2
0
    def test(self, test_task_iters: BatchPreparationPipeline, model: Model,
             metric_reporter: MetaLearnMetricReporter):
        """Score the model on every meta-test task and report TEST metrics.

        Each task adapts the model on its support set (train mode, the
        model memorizes the responses) before computing the target-set
        loss in eval mode.
        """
        for support, target, context in test_task_iters:
            for task_batch in zip(support, target, context):
                (s_inputs, t_inputs), (s_targets, t_targets), \
                    (s_context, t_context) = task_batch
                task = t_context['task_id'][0]

                # Adaptation: memorize the support responses.
                model.train()
                model.contextualize(s_context)
                model(*s_inputs, responses=s_targets)

                # Evaluation on the target split, gradients disabled.
                model.eval()
                with torch.no_grad():
                    t_pred = model(*t_inputs)
                    t_loss = model.get_loss(t_pred, t_targets,
                                            t_context).item()
                    metric_reporter.add_batch_stats(task,
                                                    t_loss,
                                                    s_inputs,
                                                    t_predictions=t_pred,
                                                    t_targets=t_targets)

        metric_reporter.report_metric(stage=Stage.TEST, epoch=0, reset=False)
Ejemplo n.º 3
0
    def train(
        self,
        train_iter: BatchIterator,
        eval_iter: BatchIterator,
        model: Model,
        metric_reporter: MetricReporter,
        train_config: PyTextConfig,
        optimizer: torch.optim.Optimizer,
        scheduler=None,
        rank: int = 0,
    ) -> Tuple[torch.nn.Module, Any]:
        """
        Train and eval a model, the model states will be modified. This function
        iterates epochs specified in config, and for each epoch do:

            1. Train model using training data, aggregate and report training results
            2. Adjust learning rate if scheduler is specified
            3. Evaluate model using evaluation data
            4. Calculate metrics based on evaluation results and select best model

        Args:
            train_iter (BatchIterator): batch iterator of training data
            eval_iter (BatchIterator): batch iterator of evaluation data
            model (Model): model to be trained
            metric_reporter (MetricReporter): compute metric based on training
                output and report results to console, file.. etc
            train_config (PyTextConfig): training config
            optimizer (torch.optim.Optimizer): torch optimizer to be used
            scheduler (Optional[torch.optim.lr_scheduler]): learning rate scheduler,
                default is None
            rank (int): only used in distributed training, the rank of the current
                training thread, evaluation will only be done in rank 0

        Returns:
            model, best_metric: the trained model together with the best metric
        """
        timer = time_utils.StageTimer()
        world_size = 1
        if cuda_utils.CUDA_ENABLED:
            model = model.cuda()
            world_size = cuda_utils.DISTRIBUTED_WORLD_SIZE
            if world_size > 1:
                device_id = torch.cuda.current_device()
                model = DistributedModel(
                    module=model,
                    device_ids=[device_id],
                    output_device=device_id,
                    broadcast_buffers=False,
                )
            timer.add_stage(stage="init_distributed_model")

        best_metric = None
        last_best_epoch = 0
        # Bug fix: initialize up front. Previously this name was only bound
        # inside the "found a better model" branch, so the final restore
        # raised NameError when no epoch ever improved the metric.
        best_model_state = None
        scheduler = self._prepare_scheduler(train_iter, scheduler)
        timer.add_stage(stage="pre_training")

        def training_pre_batch_callback():
            if world_size > 1:
                # replace optimizer.zero_grad() here to work with DDP
                # in cases where some parameters don't receive grads at each step
                # loss.backward will set grad for params in the computation graph
                # we can thus follow which params are left out and call .backward
                # on them manually
                for p in model.parameters():
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad = None
            else:
                optimizer.zero_grad()

        def training_backprop(loss, timer):
            loss.backward()
            if world_size > 1:
                # DDP fix when some parameters don't receive grads
                for p in model.parameters():
                    if p.requires_grad and p.grad is None:
                        p.backward(torch.zeros_like(p.data))
            timer.add_stage("backward")

            if scheduler:
                scheduler.step_batch()

            if self.config.max_clip_norm is not None:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), self.config.max_clip_norm)
            else:
                grad_norm = None

            optimizer.step()
            timer.add_stage("update_grads")
            # grad_norm could be used to check grads sync in distributed training
            return grad_norm

        time_start = time.time()
        for epoch in range(1, self.config.epochs + 1):
            # Optional wall-clock budget: stop before starting an epoch that
            # would (on average) overshoot target_time_limit_seconds.
            if self.config.target_time_limit_seconds > 0 and epoch > 1:
                time_elapsed = time.time() - time_start
                mean_epoch_time = time_elapsed / float(epoch - 1)
                expected_next_epoch_time = time_elapsed + mean_epoch_time
                if expected_next_epoch_time > self.config.target_time_limit_seconds:
                    print(
                        f"Training stopped after {epoch - 1} epochs and "
                        f"{int(time_elapsed)} seconds, due to the target max training "
                        f"time of {self.config.target_time_limit_seconds} seconds."
                    )
                    break

            print(f"Rank {rank} worker: Starting epoch #{epoch}")
            model.train()
            lrs = (str(lr) for lr in learning_rates(optimizer))
            print(f"Learning rate(s): {', '.join(lrs)}")
            self._run_epoch(
                Stage.TRAIN,
                epoch,
                train_iter,
                model,
                metric_reporter,
                pre_batch=training_pre_batch_callback,
                backprop=training_backprop,
                rank=rank,
            )
            timer.add_stage(stage="epoch_train")

            model.eval(Stage.EVAL)
            with torch.no_grad():
                eval_metric = self._run_epoch(Stage.EVAL,
                                              epoch,
                                              eval_iter,
                                              model,
                                              metric_reporter,
                                              rank=rank)
            timer.add_stage(stage="epoch_eval")

            # Step the learning rate scheduler(s)
            if scheduler:
                assert eval_metric is not None
                scheduler.step(
                    metrics=metric_reporter.get_model_select_metric(
                        eval_metric),
                    epoch=epoch,
                )

            # choose best model.
            if metric_reporter.compare_metric(eval_metric, best_metric):
                last_best_epoch = epoch
                best_metric = eval_metric
                # Only rank = 0 trainer saves modules.
                if train_config.save_module_checkpoints and rank == 0:
                    model.save_modules(base_path=train_config.modules_save_dir,
                                       suffix=f"-ep{epoch}")

                if rank == 0:
                    print(f"Rank {rank} worker: Found a better model!")
                    model_state = model.state_dict()
                    # save to cpu to avoid multiple model copies in gpu memory
                    if cuda_utils.CUDA_ENABLED:
                        for key, state in model_state.items():
                            model_state[key] = state.cpu()
                    best_model_state = model_state
                timer.add_stage(stage="epoch_save/load_module")

            if self.config.early_stop_after > 0 and (
                    epoch - last_best_epoch == self.config.early_stop_after):
                print(f"Rank {rank} worker: Eval metric hasn't changed for " +
                      f"{self.config.early_stop_after} epochs. Stopping now.")
                break
            sys.stdout.flush()

        # Restore the best snapshot on rank 0 (the only rank that recorded
        # one). Guard against the case where no epoch improved the metric.
        if rank == 0 and best_model_state is not None:
            if cuda_utils.CUDA_ENABLED:
                for key, state in best_model_state.items():
                    best_model_state[key] = state.cuda()
            model.load_state_dict(best_model_state)

        timer.report("Trainer train timer")
        return model, best_metric
Ejemplo n.º 4
0
    def train(
        self,
        train_iter: BatchIterator,
        eval_iter: BatchIterator,
        model: Model,
        metric_reporter: MetricReporter,
        train_config: PyTextConfig,
        optimizers: List[torch.optim.Optimizer],
        scheduler=None,
        rank: int = 0,
    ) -> Tuple[torch.nn.Module, Any]:
        """
        Train and eval a model, the model states will be modified. This function
        iterates epochs specified in config, and for each epoch do:

            1. Train model using training data, aggregate and report training results
            2. Adjust learning rate if scheduler is specified
            3. Evaluate model using evaluation data
            4. Calculate metrics based on evaluation results and select best model

        Args:
            train_iter (BatchIterator): batch iterator of training data
            eval_iter (BatchIterator): batch iterator of evaluation data
            model (Model): model to be trained
            metric_reporter (MetricReporter): compute metric based on training
                output and report results to console, file.. etc
            train_config (PyTextConfig): training config
            optimizers (List[torch.optim.Optimizer]): a list of torch optimizers, in
                most of the case only contains one optimizer
            scheduler (Optional[torch.optim.lr_scheduler]): learning rate scheduler,
                default is None
            rank (int): only used in distributed training, the rank of the current
                training thread, evaluation will only be done in rank 0

        Returns:
            model, best_metric: the trained model together with the best metric
        """
        if cuda_utils.CUDA_ENABLED:
            model = model.cuda()
            if cuda_utils.DISTRIBUTED_WORLD_SIZE > 1:
                device_id = torch.cuda.current_device()
                model = DistributedModel(
                    module=model,
                    device_ids=[device_id],
                    output_device=device_id,
                    broadcast_buffers=False,
                )

        best_metric = None
        last_best_epoch = 0
        best_model_state = None
        scheduler = self._prepare_scheduler(train_iter, scheduler)

        def training_pre_batch_callback():
            optimizer_zero_grad(optimizers)

        def training_backprop(loss):
            loss.backward()
            if scheduler:
                scheduler.step_batch()

            if self.config.max_clip_norm is not None:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), self.config.max_clip_norm)
            else:
                grad_norm = None

            optimizer_step(optimizers)
            # grad_norm could be used to check grads sync in distributed training
            return grad_norm

        for epoch in range(1, self.config.epochs + 1):
            print(f"Rank {rank} worker: Starting epoch #{epoch}")
            model.train()
            lrs = (str(lr) for lr in learning_rates(optimizers))
            print(f"Learning rate(s): {', '.join(lrs)}")

            self._run_epoch(
                Stage.TRAIN,
                epoch,
                train_iter,
                model,
                metric_reporter,
                pre_batch=training_pre_batch_callback,
                backprop=training_backprop,
                rank=rank,
            )

            model.eval(Stage.EVAL)
            eval_metric = self._run_epoch(Stage.EVAL,
                                          epoch,
                                          eval_iter,
                                          model,
                                          metric_reporter,
                                          rank=rank)

            # Step the learning rate scheduler(s)
            if scheduler:
                assert eval_metric is not None
                scheduler.step(
                    metrics=metric_reporter.get_model_select_metric(
                        eval_metric),
                    epoch=epoch,
                )

            # choose best model.
            if metric_reporter.compare_metric(eval_metric, best_metric):
                print(
                    f"Rank {rank} worker: Found a better model! Saving the model state."
                )
                last_best_epoch = epoch
                best_metric = eval_metric
                # Only rank = 0 trainer saves modules.
                if train_config.save_module_checkpoints and rank == 0:
                    model.save_modules(base_path=train_config.modules_save_dir,
                                       suffix=f"-ep{epoch}")
                # deepcopy: state_dict tensors alias live parameters and
                # would otherwise keep changing as training continues.
                best_model_state = copy.deepcopy(model.state_dict())

            if self.config.early_stop_after > 0 and (
                    epoch - last_best_epoch == self.config.early_stop_after):
                print(f"Rank {rank} worker: Eval metric hasn't changed for " +
                      f"{self.config.early_stop_after} epochs. Stopping now.")
                break
            sys.stdout.flush()

        # Bug fix: best_model_state stays None when no epoch ever improved
        # the metric (e.g. epochs == 0); load_state_dict(None) would crash.
        if best_model_state is not None:
            model.load_state_dict(best_model_state)
        return model, best_metric
Ejemplo n.º 5
0
    def train(
        self,
        train_task_iters: Optional[BatchPreparationPipeline],
        eval_task_iters: BatchPreparationPipeline,
        model: Model,
        metric_reporter: MetaLearnMetricReporter,
        train_config: PyTextConfig,
        rank: int = 0,
    ) -> Tuple[torch.nn.Module, Any]:
        """Meta-train the model on the train tasks, then evaluate and save it.

        Per task, the model is adapted on the support set while in train
        mode (it memorizes the support responses) and then scored on the
        target set. After the (optional) training phase the model is run
        over ``eval_task_iters`` under ``torch.no_grad`` and its final
        state dict is written to ``train_config.modules_save_dir``.

        Args:
            train_task_iters: meta-train batches; falsy means the model
                needs no meta-training and the whole train/eval phase is
                skipped.
            eval_task_iters: meta-eval batches (only consumed when
                ``train_task_iters`` is truthy).
            model: model to adapt, evaluate and save.
            metric_reporter: accumulates per-task loss/prediction stats.
            train_config: supplies ``modules_save_dir`` for the saved model.
            rank: not used by this implementation; kept for trainer
                interface compatibility.

        Returns:
            The (possibly CUDA-moved) model and ``None`` in place of a
            best metric.
        """

        if cuda_utils.CUDA_ENABLED:
            model = model.cuda()

        best_model_path = None

        # Start outer loop (meta learner "epochs") #############################################
        if not train_task_iters:
            LOG.warning("Model does not need meta-training")
        else:
            for epoch in range(1, 2):  # single epoch
                # zip with range(100) caps meta-training at 100 meta-batches.
                for bidx, (support, target,
                           context) in zip(range(100), train_task_iters):
                    for (s_inputs,
                         t_inputs), (s_targets,
                                     t_targets), (s_context, t_context) in zip(
                                         support, target, context):
                        task = t_context['task_id'][0]

                        # Adapt the model using the support set
                        model.train()
                        for step in range(1):
                            model.contextualize(s_context)
                            model(*s_inputs, responses=s_targets
                                  )  # model remembers responses

                        # Evaluate the model using the target set
                        model.eval(
                        )  # model now retrieves from examples seen so far
                        model.contextualize(t_context)
                        t_pred = model(*t_inputs)
                        t_loss = model.get_loss(t_pred, t_targets,
                                                t_context).item()
                        metric_reporter.add_batch_stats(task,
                                                        t_loss,
                                                        s_inputs,
                                                        t_predictions=t_pred,
                                                        t_targets=t_targets)

                metric_reporter.report_metric(stage=Stage.TRAIN,
                                              epoch=epoch,
                                              reset=False)

            logging.info("Evaluating model on eval tasks")
            with torch.no_grad():
                for bidx, (support, target,
                           context) in enumerate(eval_task_iters):
                    for (s_inputs,
                         t_inputs), (s_targets,
                                     t_targets), (s_context, t_context) in zip(
                                         support, target, context):
                        task = t_context["task_id"][0]
                        # train() here only enables support-set memorization;
                        # gradients stay disabled via the enclosing no_grad().
                        model.train()
                        model.contextualize(s_context)
                        model(*s_inputs,
                              responses=s_targets)  # model remembers responses
                        model.eval()
                        t_pred = model(*t_inputs)
                        t_loss = model.get_loss(t_pred, t_targets,
                                                t_context).item()

                        metric_reporter.add_batch_stats(task,
                                                        t_loss,
                                                        s_inputs,
                                                        t_predictions=t_pred,
                                                        t_targets=t_targets)

            # NOTE(review): relies on `epoch` leaking out of the loop above.
            metric_reporter.report_metric(stage=Stage.EVAL,
                                          epoch=epoch,
                                          reset=False)

        best_model_path = os.path.join(train_config.modules_save_dir,
                                       "model.pt")
        torch.save(model.state_dict(), best_model_path)

        return model, None
Ejemplo n.º 6
0
    def train(
        self,
        train_iter: BatchIterator,
        eval_iter: BatchIterator,
        model: Model,
        metric_reporter: MetricReporter,
        train_config: PyTextConfig,
        optimizers: List[torch.optim.Optimizer],
        scheduler=None,
        rank: int = 0,
    ) -> Tuple[torch.nn.Module, Any]:
        """Train and evaluate ``model``, checkpointing the best epoch to disk.

        Per epoch: advance the dialogue-length schedule if configured, run a
        training pass, run an evaluation pass under ``torch.no_grad``, step
        the LR scheduler, and (rank 0 only, when checkpointing is enabled)
        save a ``ModelState`` for the best eval metric seen so far. Supports
        early stopping via ``config.early_stop_after``.

        Args:
            train_iter (BatchIterator): batch iterator of training data
            eval_iter (BatchIterator): batch iterator of evaluation data
            model (Model): model to be trained
            metric_reporter (MetricReporter): metric computation and reporting
            train_config (PyTextConfig): training config
            optimizers (List[torch.optim.Optimizer]): expected to contain
                exactly one optimizer
            scheduler (Optional[torch.optim.lr_scheduler]): learning rate
                scheduler, default is None
            rank (int): distributed rank; only rank 0 saves checkpoints

        Returns:
            model (with the best checkpoint restored when one was saved)
            together with the best metric.
        """
        if cuda_utils.CUDA_ENABLED:
            model = model.cuda()
            if cuda_utils.DISTRIBUTED_WORLD_SIZE > 1:
                device_id = torch.cuda.current_device()
                model = DistributedModel(
                    module=model,
                    device_ids=[device_id],
                    output_device=device_id,
                    broadcast_buffers=False,
                )

        best_metric = None
        last_best_epoch = 0
        best_model_path = None
        scheduler = self._prepare_scheduler(train_iter, scheduler)

        def training_pre_batch_callback():
            optimizer_zero_grad(optimizers)

        def training_backprop(loss):
            loss.backward()
            if scheduler:
                scheduler.step_batch()

            if self.config.max_clip_norm is not None:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), self.config.max_clip_norm)
            else:
                grad_norm = None

            optimizer_step(optimizers)
            # grad_norm could be used to check grads sync in distributed training
            return grad_norm

        len_sched_ix = 0

        # Used since we need the infinite iterator (only created and called once)
        def batch_generator_for_epoch(it):
            n = len(it)
            while n > 0:
                yield next(it)
                n -= 1

        for epoch in range(self.config.start_epoch, self.config.epochs + 1):
            # Set the dialogue length in the fields, to be used by the postprocessor
            while self.config.length_schedule_per_epoch \
                    and len_sched_ix < len(self.config.length_schedule_per_epoch) \
                    and epoch >= self.config.length_schedule_per_epoch[len_sched_ix][0]:
                train_iter.max_n_turns = \
                    self.config.length_schedule_per_epoch[len_sched_ix][1]
                eval_iter.max_n_turns = \
                    self.config.length_schedule_per_epoch[len_sched_ix][1]
                len_sched_ix += 1

            LOG.info(f"\nRank {rank} worker: Starting epoch #{epoch}")
            model.train()
            lrs = (str(lr) for lr in learning_rates(optimizers))
            LOG.info(f"Learning rate(s): {', '.join(lrs)}")
            self._run_epoch(
                Stage.TRAIN,
                epoch,
                batch_generator_for_epoch(train_iter),
                model,
                metric_reporter,
                pre_batch=training_pre_batch_callback,
                backprop=training_backprop,
                rank=rank,
            )
            model.eval(Stage.EVAL)
            with torch.no_grad():
                eval_metric = self._run_epoch(
                    Stage.EVAL,
                    epoch,
                    batch_generator_for_epoch(eval_iter),
                    model,
                    metric_reporter,
                    rank=rank)
            # Step the learning rate scheduler(s)
            if scheduler:
                assert eval_metric is not None
                scheduler.step(
                    metrics=metric_reporter.get_model_select_metric(
                        eval_metric),
                    epoch=epoch,
                )

            # choose best model.
            if metric_reporter.compare_metric(eval_metric, best_metric):
                LOG.info(
                    f"Rank {rank} worker: Found a better model! Saving the model state for epoch #{epoch}."
                )
                last_best_epoch = epoch
                best_metric = eval_metric
                # Only rank = 0 trainer saves modules.
                if train_config.save_module_checkpoints and rank == 0:
                    best_model_path = os.path.join(
                        train_config.modules_save_dir, "best_model")
                    optimizer, = optimizers  # PyText only ever returns a single optimizer in this list
                    torch.save(
                        ModelState(
                            epoch=epoch,
                            parameters=model.state_dict(),
                            optimizer=optimizer.state_dict(),
                        ), best_model_path)

            if (self.config.early_stop_after > 0 and
                (epoch - last_best_epoch == self.config.early_stop_after)):
                LOG.info(
                    f"Rank {rank} worker: Eval metric hasn't changed for "
                    f"{self.config.early_stop_after} epochs. Stopping now.")
                break
            sys.stdout.flush()

        train_iter.close()
        eval_iter.close()
        # Bug fix: best_model_path stays None when checkpointing is disabled
        # (save_module_checkpoints False or rank != 0) or when no epoch ever
        # improved the metric; torch.load(None) would then crash.
        if best_model_path is not None:
            model.load_state_dict(torch.load(best_model_path).parameters)
        return model, best_metric
Ejemplo n.º 7
0
  def train(
      self,
      text_embedder,
      train_task_iters: Optional[BatchPreparationPipeline],
      eval_task_iters: BatchPreparationPipeline,
      model: Model,
      metric_reporter: MetaLearnMetricReporter,
      train_config: PyTextConfig,
      rank: int = 0,
    ) -> Tuple[torch.nn.Module, Any]:
    """Meta-train the model on the train tasks, then evaluate and save it.

    Per task, the model is adapted on the support set in train mode (it
    memorizes the support responses) and scored on the target set in eval
    mode. After training, the model is run over ``eval_task_iters`` under
    ``torch.no_grad`` and the final state dict is saved to
    ``train_config.modules_save_dir``.

    Expected batch layout (observed during debugging — confirm against the
    data handler): s_inputs/t_inputs are 6-tuples of token-id, length and
    (optional) embedding tensors; s_targets/t_targets pair response ids
    with their lengths; s_context/t_context are dicts with keys such as
    'task_id', 'domain_id', 'dlg_id', 'orig_text' and 'target_seq_lens'.

    Args:
        text_embedder: accepted for interface compatibility (e.g. its
            ``decode_ids_as_text`` is handy for debugging); not used by
            the training loop itself.
        train_task_iters: meta-train batches; falsy skips training.
        eval_task_iters: meta-eval batches.
        model: model to adapt, evaluate and save.
        metric_reporter: accumulates per-task loss/prediction stats.
        train_config: supplies ``modules_save_dir`` for the saved model.
        rank: unused; kept for trainer interface compatibility.

    Returns:
        The (possibly CUDA-moved) model and ``None`` as the best metric.
    """
    # Removed debug leftovers present in an earlier revision: an unused
    # alias of text_embedder.decode_ids_as_text, and a stray
    # next(train_task_iters) that silently discarded the first meta-batch.
    if cuda_utils.CUDA_ENABLED:
      model = model.cuda()
    best_model_path = None

    # Start outer loop (meta learner "epochs") #############################################
    if not train_task_iters:
      LOG.warning("Model does not need meta-training")
    else:
      for epoch in range(1, 2):  # single epoch
        # zip with range(100) caps meta-training at 100 meta-batches.
        for bidx, (support, target, context) in zip(range(100), train_task_iters):
          for (s_inputs, t_inputs), (s_targets, t_targets), (s_context, t_context) in zip(support, target, context):
            task = t_context['task_id'][0]

            # Adapt the model using the support set
            model.train()
            for step in range(1):
              model.contextualize(s_context)
              model(*s_inputs, responses=s_targets)  # model remembers responses

            # Evaluate the model using the target set
            model.eval()    # model now retrieves from examples seen so far
            model.contextualize(t_context)
            t_pred = model(*t_inputs)
            t_loss = model.get_loss(t_pred, t_targets, t_context).item()
            metric_reporter.add_batch_stats(task, t_loss, s_inputs,
                                            t_predictions=t_pred, t_targets=t_targets)

        metric_reporter.report_metric(stage=Stage.TRAIN, epoch=epoch, reset=False)

      logging.info("Evaluating model on eval tasks")
      with torch.no_grad():
        for bidx, (support, target, context) in enumerate(eval_task_iters):
          for (s_inputs, t_inputs), (s_targets, t_targets), (s_context, t_context) in zip(support, target, context):
            task = t_context["task_id"][0]
            # train() only enables support-set memorization; gradients stay
            # disabled via the enclosing no_grad().
            model.train()
            model.contextualize(s_context)
            model(*s_inputs, responses=s_targets)  # model remembers responses
            model.eval()
            t_pred = model(*t_inputs)
            t_loss = model.get_loss(t_pred, t_targets, t_context).item()

            metric_reporter.add_batch_stats(task, t_loss, s_inputs,
                                            t_predictions=t_pred, t_targets=t_targets)

      metric_reporter.report_metric(stage=Stage.EVAL, epoch=epoch, reset=False)

    best_model_path = os.path.join(
        train_config.modules_save_dir, "model.pt"
    )
    torch.save(model.state_dict(), best_model_path)

    return model, None
Ejemplo n.º 8
0
    def train(
            self,
            text_embedder,
            train_task_iters: Optional[BatchPreparationPipeline],    # Optional[X] is equivalent to Union[X, None].
            eval_task_iters: BatchPreparationPipeline,
            model: Model,
            metric_reporter: MetaLearnMetricReporter,
            train_config: PyTextConfig,
            rank: int = 0,
    ) -> Tuple[torch.nn.Module, Any]:
        """Meta-train ``model`` on the training tasks, evaluate it on the eval
        tasks, and save the final weights to ``train_config.modules_save_dir``.

        The training loop performs a MAML-style inner adaptation step on each
        task's support set: it computes per-step gradients and "fast" weights
        with ``torch.autograd.grad``. NOTE(review): the fast weights are
        computed but never applied, and ``meta_optim`` is never stepped — the
        outer (meta) update appears to be unfinished; see TODO below.

        Args:
            text_embedder: project text embedder (exposes
                ``decode_ids_as_text``); currently unused here.
            train_task_iters: iterator of ``(support, target, context)`` meta
                batches, or a falsy value to skip meta-training entirely.
            eval_task_iters: meta-batch iterator used for evaluation.
            model: the meta-learner; must expose ``representation.gptmode``.
            metric_reporter: accumulates per-task eval losses and reports
                metrics per stage.
            train_config: supplies ``modules_save_dir`` for the checkpoint.
            rank: distributed worker rank (unused in this implementation).

        Returns:
            ``(model, None)`` — no separate best-snapshot object is tracked;
            the final state dict is saved to ``model.pt`` unconditionally.
        """
        if cuda_utils.CUDA_ENABLED:
            model = model.cuda()

        meta_lr = 0.001   # outer-loop (meta-optimizer) learning rate
        update_lr = 0.01  # inner-loop (task-adaptation) learning rate

        # GPT-2 uses the newer AdamW optimizer; the original GPT model keeps
        # the OpenAIAdam variant.
        from pytorch_transformers import AdamW
        if model.representation.gptmode == 'gpt2':
            meta_optim = AdamW(model.parameters(), lr=meta_lr)
        else:
            meta_optim = OpenAIAdam(model.parameters(), lr=meta_lr)

        # Keep `epoch` defined even when the training loop below runs zero
        # iterations, so the EVAL report_metric call cannot raise NameError.
        epoch = 0

        # Start outer loop (meta learner "epochs") ############################
        if not train_task_iters:
            LOG.warning("Model does not need meta-training")
        else:
            logging.info("Training model on train tasks")
            for epoch in range(1, 2):  # single epoch
                # Cap meta-training at 100 meta-batches (tasks).
                for bidx, (support, target, context) in zip(range(100), train_task_iters):
                    for enum_i, ((s_inputs, t_inputs),
                                 (s_targets, t_targets),
                                 (s_context, t_context)) in enumerate(
                                     zip(support, target, context)):
                        # All instances within one element share a task/domain.
                        s_domain = s_context['domain_id'][0]
                        LOG.debug("b_idx %s enum_i %s s_domain: %s",
                                  bidx, enum_i, s_domain)

                        # Adapt the model using the support set (inner loop).
                        model.train()
                        # Each `support_ins` is one instance's model inputs
                        # (input ids, mc token ids, lm labels, mc labels,
                        # token type ids — per the GPT double-head interface).
                        for support_ins in zip(*s_inputs):
                            if model.representation.gptmode == "gpt2":
                                lm_loss, mc_loss, _, _, _ = model(*support_ins)
                            else:
                                lm_loss, mc_loss = model(*support_ins)
                            # Weighted sum of language-modeling and
                            # multiple-choice losses.
                            loss = lm_loss * 2 + mc_loss * 1
                            grad = torch.autograd.grad(loss, model.parameters())
                            fast_weights = [
                                p - update_lr * g
                                for g, p in zip(grad, model.parameters())
                            ]
                            # TODO(review): `fast_weights` is never used and
                            # `meta_optim` is never stepped — the MAML outer
                            # update (evaluate the target set under
                            # fast_weights, accumulate the query loss, then
                            # meta_optim.step()) is still missing.

                metric_reporter.report_metric(stage=Stage.TRAIN, epoch=epoch,
                                              reset=False)

            logging.info("Evaluating model on eval tasks")
            with torch.no_grad():
                for bidx, (support, target, context) in enumerate(eval_task_iters):
                    for (s_inputs, t_inputs), (s_targets, t_targets), \
                            (s_context, t_context) in zip(support, target, context):
                        task = t_context["task_id"][0]
                        # The model memorizes the support responses in train
                        # mode, then retrieves from them in eval mode.
                        model.train()
                        model.contextualize(s_context)
                        model(*s_inputs, responses=s_targets)  # model remembers responses
                        model.eval()
                        t_pred = model(*t_inputs)
                        t_loss = model.get_loss(t_pred, t_targets, t_context).item()

                        metric_reporter.add_batch_stats(
                            task, t_loss, s_inputs,
                            t_predictions=t_pred, t_targets=t_targets)

            metric_reporter.report_metric(stage=Stage.EVAL, epoch=epoch,
                                          reset=False)

        best_model_path = os.path.join(
            train_config.modules_save_dir, "model.pt"
        )
        torch.save(model.state_dict(), best_model_path)

        return model, None