Example #1
    def train(self,
              train_iter: Iterator,
              eval_iter: Iterator,
              model: Model,
              metric_reporter: MetricReporter,
              optimizer: torch.optim.Optimizer,
              pytext_config: PyTextConfig,
              scheduler=None,
              *args,
              **kwargs) -> Tuple[torch.nn.Module, Any]:
        print("Num of workers for Hogwild Training is {}".format(
            self.num_workers))

        # Share memory of tensors for concurrent updates from multiple processes.
        if self.num_workers > 1:
            for param in model.parameters():
                param.share_memory_()

        return super().train(
            train_iter,
            eval_iter,
            model,
            metric_reporter,
            optimizer,
            pytext_config,
            scheduler,
        )
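A minimal sketch of what the share_memory_() loop above achieves (toy model, not PyText code): each parameter's storage is moved into shared memory, so worker processes later forked by the Hogwild trainer update the same tensors in place rather than private copies.

import torch.nn as nn

# Toy stand-in for the real `model`.
model = nn.Linear(4, 2)
for param in model.parameters():
    param.share_memory_()

# Every parameter now reports shared storage.
assert all(p.is_shared() for p in model.parameters())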
Example #2
    def train(self,
              train_iter: Iterator,
              eval_iter: Iterator,
              model: Model,
              metric_reporter: MetricReporter,
              optimizers: List[torch.optim.Optimizer],
              pytext_config: PyTextConfig,
              scheduler=None,
              *args,
              **kwargs):
        print("Num of workers for Hogwild Training is {}".format(
            self.num_workers))

        # Share memory of tensors for concurrent updates from multiple processes.
        if self.num_workers > 1:
            for param in model.parameters():
                param.share_memory_()

        processes = []
        for rank in range(1, self.num_workers):
            # Initialize the batches with a different random state for each worker.
            train_iter.batches.init_epoch()
            p = mp.Process(
                target=self.real_trainer.train,
                args=(
                    train_iter,
                    eval_iter,
                    model,
                    metric_reporter,
                    optimizers,
                    pytext_config,
                    scheduler,
                    None,  # training_result is only collected by rank 0
                    rank,
                ),
            )
            processes.append(p)
            p.start()

        training_result: List = Manager().list()  # Actual type is ListProxy.
        self.real_trainer.train(
            train_iter,
            eval_iter,
            model,
            metric_reporter,
            optimizers,
            pytext_config,
            scheduler,
            training_result,
            rank=0,
        )

        for p in processes:
            p.join()

        # Only the rank-0 worker writes to training_result.
        assert len(training_result) == 1
        return training_result[0]  # Contains best model and best metric.
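Putting the pattern above together, here is a self-contained Hogwild sketch with the same structure (names such as toy_train are illustrative, not PyText API): parameters in shared memory, one extra process per worker, and a Manager().list() proxy that only rank 0 writes to.

import torch
import torch.multiprocessing as mp
import torch.nn as nn


def toy_train(model, result, rank):
    # Every worker updates the *same* shared parameters in place (Hogwild).
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(100):
        x = torch.randn(8, 4)
        loss = model(x).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if rank == 0:
        result.append(loss.item())  # only rank 0 reports, as in the trainer above


if __name__ == "__main__":
    model = nn.Linear(4, 1)
    for p in model.parameters():
        p.share_memory_()

    result = mp.Manager().list()
    workers = [
        mp.Process(target=toy_train, args=(model, result, rank))
        for rank in range(1, 4)
    ]
    for w in workers:
        w.start()
    toy_train(model, result, rank=0)  # rank 0 trains in the main process
    for w in workers:
        w.join()
    assert len(result) == 1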
Example #3
 def apply_masks(self, model: Model, masks: List[torch.Tensor]):
     """
     Apply the given masks to zero out the learnable weights in the model.
     """
     learnableparams = [p for p in model.parameters() if p.requires_grad]
     assert len(learnableparams) == len(masks)
     for m, w in zip(masks, learnableparams):
         assert m.size() == w.size()
         w.data *= m.clone()
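A standalone illustration of the same masking idea (toy model, not the PyText sparsifier): multiply each learnable parameter by its 0/1 mask in place and confirm the pruned entry is zeroed.

import torch
import torch.nn as nn

model = nn.Linear(3, 2)
learnable = [p for p in model.parameters() if p.requires_grad]
masks = [torch.ones_like(p) for p in learnable]
masks[0][0, 0] = 0.0  # prune a single weight

with torch.no_grad():
    for m, w in zip(masks, learnable):
        w.mul_(m)  # equivalent to `w.data *= m.clone()` above

assert model.weight[0, 0].item() == 0.0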
Example #4
 def apply_masks(self, model: Model, masks: List[torch.Tensor]):
     """
     Apply the given masks to zero out the learnable weights in the model.
     """
     learnableparams = [p for p in model.parameters() if p.requires_grad]
     assert len(learnableparams) == len(masks)
     for m, w in zip(masks, learnableparams):
         if len(m.size()):  # skip 0-dim placeholder masks
             assert m.size() == w.size()
             w.data *= m.clone()
             # if accumulate_mask, remove a param permanently by also removing
             # its gradient
             if self.accumulate_mask:
                 w.grad.data *= m.clone()
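An illustrative sketch of why the accumulate_mask branch also masks the gradient (toy model, assumed setup): zeroing both the weight and its gradient keeps a pruned entry at zero through the next optimizer step.

import torch
import torch.nn as nn

model = nn.Linear(3, 2)
mask = torch.ones_like(model.weight)
mask[0, 0] = 0.0  # permanently prune one weight

model(torch.randn(5, 3)).sum().backward()

with torch.no_grad():
    model.weight.mul_(mask)
    model.weight.grad.mul_(mask)  # the pruned weight receives no update

torch.optim.SGD(model.parameters(), lr=0.1).step()
assert model.weight[0, 0].item() == 0.0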
Example #5
 def get_current_sparsity(self, model: Model) -> float:
     trainable_params = sum(p.numel() for p in model.parameters()
                            if p.requires_grad)
     nonzero_params = sum(p.nonzero().size(0) for p in model.parameters()
                          if p.requires_grad)
     return (trainable_params - nonzero_params) / trainable_params
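A quick check of the formula above on a toy model (illustrative only): sparsity = (trainable - nonzero) / trainable.

import torch
import torch.nn as nn

model = nn.Linear(4, 4, bias=False)  # 16 trainable weights
with torch.no_grad():
    model.weight[:2].zero_()         # zero out 8 of them

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
nonzero = sum(p.nonzero().size(0) for p in model.parameters() if p.requires_grad)
print((trainable - nonzero) / trainable)  # 0.5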
Example #6
    def get_masks(self,
                  model: Model,
                  pre_masks: List[torch.Tensor] = None) -> List[torch.Tensor]:
        """
        Note: this function only returns the masks; it does not sparsify or
        modify the weights.

        Prune x% of the weights among the weights whose entry in pre_masks is "1".

        Args:
            model: Model
            pre_masks: list of FloatTensors where "1" means the weight is
             retained and "0" means the weight is pruned

        Return:
            masks: List[torch.Tensor], the intersection of the new masks and
            pre_masks, so an entry is "1" only if the weight is kept by both the
            new masking and pre_masks
        """
        learnableparams = [p for p in model.parameters() if p.requires_grad]
        if pre_masks:
            self._masks = pre_masks
        if self._masks is None:
            # retain everything if no pre_masks given
            self._masks = [torch.ones_like(p) for p in learnableparams]

        assert len(learnableparams) == len(self._masks)
        for m, w in zip(self._masks, learnableparams):
            if len(m.size()):  # skip 0-dim placeholder masks
                assert m.size() == w.size()

        if self.layerwise_pruning:
            masks = []
            for m, param in zip(self._masks, learnableparams):
                weights_abs = torch.abs(param.data).to(param.device)
                # absolute value of the weights selected by the existing masks
                weights_abs_masked_flat = torch.flatten(weights_abs[m.bool()])
                total_size = weights_abs_masked_flat.numel()
                if total_size > 0:
                    # use ceil() instead of floor() or int() because at least
                    # one element in the tensor is required to be selected
                    max_num_nonzeros = math.ceil(total_size *
                                                 (1 - self.sparsity))
                    # only prune among the weights selected by the existing masks
                    topkval = (torch.topk(
                        weights_abs_masked_flat,
                        max_num_nonzeros).values.min().item())
                    # intersection of the new mask and the pre-existing masks:
                    # mask == 1 means retained, mask == 0 means pruned
                    mask = (weights_abs >= topkval).float() * m
                else:
                    mask = param.new_empty(())
                masks.append(mask)
        else:
            # concatenated, flattened tensor of the learnableparams entries whose _masks value is True
            learnableparams_masked_flat = torch.cat(
                [
                    torch.flatten(p[m.bool()])
                    for m, p in zip(self._masks, learnableparams)
                ],
                dim=0,
            )
            # use ceil() instead of floor() or int() because at least one element
            # in the tensor is required to be selected
            max_num_nonzeros = math.ceil(learnableparams_masked_flat.numel() *
                                         (1 - self.sparsity))
            # global threshold: magnitude of the k-th largest weight among those selected by _masks
            topkval = (torch.topk(torch.abs(learnableparams_masked_flat),
                                  max_num_nonzeros).values.min().item())
            # intersection of the new mask and _masks:
            # mask == 1 means retained, mask == 0 means pruned
            masks = [(torch.abs(p.data) >= topkval).float() *
                     m if p.numel() > 0 else p.new_empty(())
                     for m, p in zip(self._masks, learnableparams)]

        if self.accumulate_mask:
            self._masks = masks

        return masks
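A compact magnitude-pruning sketch of the core idea in get_masks (single-tensor version, assumed rather than the PyText implementation): keep the top ceil((1 - sparsity) * n) weights by absolute value and mask out the rest.

import math

import torch


def magnitude_mask(weight: torch.Tensor, sparsity: float) -> torch.Tensor:
    flat = weight.abs().flatten()
    # ceil() so that at least one element is always retained
    k = math.ceil(flat.numel() * (1.0 - sparsity))
    threshold = torch.topk(flat, k).values.min()
    return (weight.abs() >= threshold).float()


w = torch.randn(4, 4)
mask = magnitude_mask(w, sparsity=0.75)
print(int(mask.sum()))  # 4 of the 16 weights retained (more only on exact ties)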
Example #7
 def get_sparsifiable_params(self, model: Model):
     sparsifiable_params = [
         p for p in model.parameters() if p.requires_grad
     ]
     return sparsifiable_params
Example #8
    def train(
            self,
            text_embedder,
            train_task_iters: Optional[BatchPreparationPipeline],    # Optional[X] is equivalent to Union[X, None].
            eval_task_iters: BatchPreparationPipeline,
            model: Model,
            metric_reporter: MetaLearnMetricReporter,
            train_config: PyTextConfig,
            rank: int = 0,
    ) -> Tuple[torch.nn.Module, Any]:

        if cuda_utils.CUDA_ENABLED:
            model = model.cuda()

        best_model_path = None
        meta_lr = 0.001
        update_lr = 0.01
        from pytorch_transformers import AdamW
        if model.representation.gptmode == 'gpt2':
            meta_optim = AdamW(model.parameters(), lr=meta_lr)
        else:
            meta_optim = OpenAIAdam(model.parameters(), lr=meta_lr)

        # Start outer loop (meta learner "epochs") #############################################
        if not train_task_iters:
            LOG.warning("Model does not need meta-training")
        else:
            logging.info("Training model on train tasks")
            for epoch in range(1, 2):  # single epoch
                for bidx, (support, target, context) in zip(range(100), train_task_iters): # 100 different tasks
                    # support.__len__() : task num
                    #class MetaDataHandler(DialogueDataHandler):
                    #    class Config(DialogueDataHandler.Config):
                    #        # Support set size per task, i.e. base-learner minibatch size
                    #        support_batch_size: int = 64  # 128
                    #        meta_batch_size: int = 4  # 2
                    losses_q = [0 for _ in range(len(support))]  # placeholder for per-task query losses

                    print("support.__len__() ", support.__len__())
                    for enum_i, ((s_inputs, t_inputs), (s_targets, t_targets), (s_context, t_context)) in enumerate(zip(support, target, context)): # task num
                        # same task
                        support_set = s_inputs
                        target_set = t_inputs
                        # all same domain
                        # support : (2)
                        # s_inputs : (6)
                        # s_inputs[0].shape : (128, 3, 38) # 3 means 3 consecutive sentence ## 'denver', 'no , the thunderstorm has drifted north .', 'that makes me mad ! why is that ?'
                        # s_inputs[1].shape : (128, 3, 38, 768) # I guess BertEmbedding ## Now None!!
                        # s_inputs[2].shape : (128, 2, 37) # 2 means the next consecutive sentence of s_inputs[0] ##  'no , the thunderstorm has drifted north .', 'that makes me mad ! why is that ?'
                        # s_inputs[3].shape : (128) # [3, 3, 3, 3, 3....]
                        # s_inputs[4].shape : (128, 3) # each length of sentences in s_inputs[0]
                        # s_inputs[5].shape : (128, 2) # each length of sentences in s_inputs[2]
                        # s_targets : (2)
                        # s_targets[0].shape : (128, 2, 34) ## 'no, the thunderstorm has drifted north .', 'you would like the storm ?'
                        # s_targets[1].shape : (128, 2) # each length of sentences in s_targets[0]
                        # type(s_context) : dict # keys : {'target_seq_lens', 'orig_text', 'dlg_len', 'dlg_id', 'domain_id', 'task_id', 'index'}
                        # s_context['target_seq_lens'].shape : (128, 2) # each length"+1" of sentences in s_targets[0]
                        # s_context['orig_text'].__len__() : 128
                        # s_context['orig_text'][0]'s original text == "turns": ["Hello how may I help you?", "Is there still supposed to be a thunderstorm today as there was originally?", "what location?", "Denver", "No, the thunderstorm has drifted north.", "That makes me mad! Why is that?", "You would like the storm?", "Yes! It really upsets me that there isn't going to be one now.", "I'm sorry, I will contact mother nature immediately!", "Why is there not going to be one?", "The radar say so."]
                        # s_context['dlg_len'] = 4
                        # s_context['dlg_id'] : (128) # '2d1d4ed2', '20debe73', ... ## "id"
                        # s_context['domain_id'] : (128) # 'WEATHER_CHECK', 'WEATHER_CHECK'... ## "domain"
                        # s_context['task_id'] : (128) # 'd941f2bb', '5f2bb1b2', ... ## "task_id"
                        # s_context['index'] : (128) # 25650, 25414, 25454, 25445, 25465, 25370, 25333, 25411, 25203, 25108, 25631, 25532, 25155, 25472, 25365, 25356, 25258, 25282, 25242, 25518, 25150, 25237, 25372

                        # t_inputs : (6)
                        # text_embedder.decode_ids_as_text(s_inputs[0][0][0].cpu().numpy()) = 'what is your order number ?'

                        # mldc/data/data_handler.py def _train_input_from_batch(self, batch):
                        # seq_input = getattr(batch, ModelInput.SEQ)  # seq_input (4) # (128, 5, 35), (128) n seqs, (128, 5) n words per seq, None
                        # target = getattr(batch, ModelOutput.TOK)  # (2) (128, 48), (128)
                        # teacher_forcing_input, teacher_forcing_lens = self._make_teacher_forcing(*target)
                        # return (# flatten the seq input into the list of parameters
                        #   seq_input[0],  # (128, 5, 35)
                        #   seq_input[3],  # None
                        #   teacher_forcing_input,
                        #   seq_input[1],  # n seqs
                        #   seq_input[2],  # n words per seq
                        #   teacher_forcing_lens,  # n words per output seq

                        diat = text_embedder.decode_ids_as_text
                        task = t_context['task_id'][0]
                        s_domain = s_context['domain_id'][0]
                        #t_domain = t_context['domain_id'][0]
                        print("b_idx", bidx, "enum_i", enum_i,"s_domain :", s_domain)
                        #print("t_domain :", s_domain)
                        #print("task :", task)
                        # text_embedder.decode_ids_as_text(s_inputs[0][0][0].cpu().numpy()) = 'what is your order number ?'
                        # inputs input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids

                        # TODO
                        num_instance = support_set[0].shape[0]
                        # Adapt the model using the support set
                        model.train()
                        #spt_input_ids, spt_mc_token_ids, spt_lm_labels, spt_mc_labels, spt_token_type_ids = support_set
                        #for s_idx, (sii, smti, sll, sml, stti) in enumerate(zip(spt_input_ids, spt_mc_token_ids,
                        #                                                        spt_lm_labels, spt_mc_labels,
                        #                                                        spt_token_type_ids)):
                        for s_idx, support_ins in enumerate(zip(*support_set)):
                            sii, smti, sll, sml, stti = support_ins
                            if model.representation.gptmode == "gpt2":
                                lm_loss, mc_loss, _, _, _ = model(*support_ins)
                            else:
                                lm_loss, mc_loss = model(*support_ins)
                            loss = (lm_loss * 2 + mc_loss * 1)
                            grad = torch.autograd.grad(loss, model.parameters())
                            fast_weights = list(map(lambda p: p[1] - update_lr * p[0], zip(grad, model.parameters())))

                            ## input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,

                        #task_num = s_inputs.shape[0] # batchsz
                        #for task_idx in range(task_num):
                        #  s_inputs_task = s_inputs[task_idx]

                        # Adapt the model using the support set
                        # model.train()
                        # for step in range(1):
                        #   #model.contextualize(s_context)
                        #   #model(*s_inputs, responses=s_targets)  # model remembers responses
                        #   lm_loss, mc_loss, _, _, _ = model(*s_inputs)

                        # # Evaluate the model using the target set
                        # model.eval()    # model now retrieves from examples seen so far
                        # model.contextualize(t_context)
                        # t_pred = model(*t_inputs)
                        # t_loss = model.get_loss(t_pred, t_targets, t_context).item()
                        # metric_reporter.add_batch_stats(task, t_loss, s_inputs,
                        #                                 t_predictions=t_pred, t_targets=t_targets)

                metric_reporter.report_metric(stage=Stage.TRAIN, epoch=epoch, reset=False)

            logging.info("Evaluating model on eval tasks")
            with torch.no_grad():
                for bidx, (support, target, context) in enumerate(eval_task_iters):
                    for (s_inputs, t_inputs), (s_targets, t_targets), (s_context, t_context) in zip(support, target, context):
                        task = t_context["task_id"][0]
                        model.train()
                        model.contextualize(s_context)
                        model(*s_inputs, responses=s_targets)  # model remembers responses
                        model.eval()
                        t_pred = model(*t_inputs)
                        t_loss = model.get_loss(t_pred, t_targets, t_context).item()

                        metric_reporter.add_batch_stats(task, t_loss, s_inputs,
                                                        t_predictions=t_pred, t_targets=t_targets)

            metric_reporter.report_metric(stage=Stage.EVAL, epoch=epoch, reset=False)

        best_model_path = os.path.join(
            train_config.modules_save_dir, "model.pt"
        )
        torch.save(model.state_dict(), best_model_path)

        return model, None
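A compact sketch of the inner-loop "fast weights" step used above (MAML-style toy model, illustrative only): take one gradient step on the support loss without modifying the original parameters, so they remain available for the outer meta-update.

import torch
import torch.nn as nn
import torch.nn.functional as F

update_lr = 0.01
model = nn.Linear(4, 1)
x_support, y_support = torch.randn(8, 4), torch.randn(8, 1)

loss = F.mse_loss(model(x_support), y_support)
grads = torch.autograd.grad(loss, model.parameters())
fast_weights = [p - update_lr * g for g, p in zip(grads, model.parameters())]

# fast_weights would then be evaluated functionally on the target (query) set,
# while model.parameters() stay untouched for the meta-optimizer.
print([fw.shape for fw in fast_weights])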