Example #1
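These snippets are excerpted from larger repositories, so their imports are omitted. A plausible import header for Example #1 is sketched below; the module paths for RAdam, WarmupLinearSchedule, DataParallelModel, and DataParallelCriterion are assumptions that vary by repo, and get_param_groups, set_seed, evaluate, and logger are repo-local helpers:

# Assumed imports for Example #1; exact module paths depend on the source repository.
import os

import torch
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, RandomSampler
from tqdm import tqdm, trange
from tensorboardX import SummaryWriter  # assumption: could also be torch.utils.tensorboard
from pytorch_transformers import WarmupLinearSchedule  # assumption: pre-1.0 transformers API
from radam import RAdam  # assumption: standalone RAdam implementation
from parallel import DataParallelModel, DataParallelCriterion  # assumption: balanced-load DataParallel helpers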
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = (len(train_dataloader) // args.gradient_accumulation_steps
                   * args.num_train_epochs)
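    # Warm up over the first 1% of total optimization steps.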
    args.warmup_steps = t_total // 100

    # Prepare optimizer and schedule (linear warmup and decay)
    optimizer_grouped_parameters = get_param_groups(args, model)
    optimizer = RAdam(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelModel(model)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
    args.logging_steps = len(train_dataloader)  # log (and save) once per epoch
    args.save_steps = args.logging_steps
    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
    set_seed(args)
    for epoch in train_iterator:
        args.current_epoch = epoch
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,
            }  # XLM and RoBERTa don't use segment_ids
            # 'labels' (batch[3]) is intentionally omitted: the loss is computed
            # by the external criterion below rather than inside the model
            outputs = model(**inputs)
            outputs = [out[0] for out in outputs]  # keep each replica's logits (first element of its output tuple)

            loss_fct = CrossEntropyLoss()
            loss_fct = DataParallelCriterion(loss_fct)
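            # DataParallelCriterion computes the loss per GPU against each replica's logits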

            loss = loss_fct(outputs, batch[3])

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value,
                                                 global_step)
                    tb_writer.add_scalar('lr',
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(model, 'module') else model
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
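A minimal sketch of how this function might be driven, assuming train_dataset, model, and tokenizer are already built by the surrounding repo; every argument value below is an illustrative assumption, not the repo's defaults:

# Hypothetical driver for train(); all values here are illustrative assumptions.
import argparse

args = argparse.Namespace(
    local_rank=-1, n_gpu=max(1, torch.cuda.device_count()),
    per_gpu_train_batch_size=8, gradient_accumulation_steps=1,
    num_train_epochs=3, max_steps=-1,
    learning_rate=2e-5, adam_epsilon=1e-8, max_grad_norm=1.0,
    fp16=False, model_type='bert', evaluate_during_training=False,
    output_dir='./out', seed=42,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
global_step, avg_loss = train(args, train_dataset, model, tokenizer)

Example #2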
class TrainOperator:
    def __init__(self):
        # source
        self.tok = sp.SentencePieceProcessor()
        self.tok.Load(config.tok_path)
        self.vocab = self.tok.GetPieceSize()
        self.pad = self.tok.piece_to_id('[PAD]')
        self.num_workers = multiprocessing.cpu_count()

        self.cuda = config.cuda and torch.cuda.is_available()

        # for data parallel
        if self.cuda:
            self.n_gpu = torch.cuda.device_count()
        else:
            self.n_gpu = 0

        # load loader
        self.train_loader = self._construct_loader('train')
        self.dev_loader = self._construct_loader('dev')
        print('* Train Operator is loaded')

    def setup_train(self, model_path=None):
        self.loss_weight = torch.FloatTensor([1 - config.alpha, config.alpha])
        if self.cuda:
            self.loss_weight = self.loss_weight.cuda()

        self.model = HierSumTransformer(self.vocab, config.emb_dim,
                                        config.d_model, config.N, config.heads,
                                        config.max_sent_len,
                                        config.max_doc_len)

        if model_path:
            self.model.load_state_dict(
                torch.load(model_path,
                           map_location=lambda storage, location: storage))
        else:
            for p in self.model.parameters():
                if p.dim() > 1:
                    torch.nn.init.xavier_uniform_(p)
        # Data Parallel
        if self.cuda:
            if self.n_gpu == 1:
                pass
            elif self.n_gpu > 1:
                self.model = torch.nn.DataParallel(self.model)
            self.model = self.model.cuda()

        #self.optim = torch.optim.Adam(self.model.parameters(), lr=config.lr, betas=(0.9, 0.98), eps=1e-9)
        self.optim = RAdam(self.model.parameters(),
                           lr=config.lr,
                           betas=(0.9, 0.98),
                           eps=1e-9)

        print('* Training model is prepared')

    def train(self):
        printInterval = 20
        init_loss = 1e5
        init_f1 = 0

        for n in range(config.n_epoch):
            loss_tr_total = 0
            for batch_id, batch in enumerate(self.train_loader):
                loss_tr = self._train_one_batch(batch)
                loss_tr_total += loss_tr
                if (batch_id + 1) % printInterval == 0 or batch_id == 0:
                    loss_tr = round(loss_tr, 4)
                    loss_eval, recall, pre, f1 = [
                        round(l, 4) for l in self._evaluate()
                    ]
                    print(
                        "| epoch: {} | batch: {}/{}| tr_loss: {} | val_loss: {} |"
                        .format(n + 1, batch_id + 1, len(self.train_loader),
                                round(loss_tr_total / (batch_id + 1), 4),
                                loss_eval))
                    print("| epoch: {} |Recall: {} | Precision: {} | F1: {} |".
                          format(n + 1, recall, pre, f1))
                    print("-" * 100)

                    if loss_eval < init_loss:
                        init_loss = loss_eval

                        if self.n_gpu <= 1:
                            torch.save(self.model.state_dict(),
                                       './resource/RNN_TR_HiSum_v2.0.pkl')
                        elif self.n_gpu > 1:
                            torch.save(self.model.module.state_dict(),
                                       './resource/RNN_TR_HiSum_v2.0.pkl')

                    # change model state to train
                    self.model.train()

    def _construct_loader(self, split):
        dataset = ExtSumDataset(config.data_path, self.tok, split)
        loader = DataLoader(
            dataset,
            batch_size=config.batch_size,
            num_workers=0,
        )
        return loader

    def _train_one_batch(self, batch):
        doc_id, doc_len, sent_len, label, doc_mask = batch
        doc_mask = doc_mask.unsqueeze(1)
        sent_mask = torch.stack([self._create_mask(sent) for sent in doc_id])

        if self.cuda:
            doc_id = doc_id.cuda()
            doc_len = doc_len.cuda()
            doc_mask = doc_mask.cuda()
            sent_mask = sent_mask.cuda()
            sent_len = sent_len.cuda()
            label = label.cuda()

        preds = self.model(doc_id, sent_mask, doc_mask, sent_len)
        loss = F.cross_entropy(preds.view(-1, preds.size(-1)),
                               label.reshape(-1),
                               ignore_index=config.ignore_index_ext,
                               weight=self.loss_weight)
        #loss = focal_loss(preds.view(-1, preds.size(-1)), label.reshape(-1), ignore_index=config.ignore_index_ext,  alpha = config.alpha, gamma = config.gamma)

        self.optim.zero_grad()  # without this, gradients accumulate across batches
        loss.backward()
        self.optim.step()
        return loss.item()

    def _evaluate(self):
        right = 0
        origin = 0
        found = 0
        total_loss = 0
        self.model.eval()

        for i, data in enumerate(self.dev_loader):
            doc_id, doc_len, sent_len, label, doc_mask = data
            doc_mask = doc_mask.unsqueeze(1)
            sent_mask = torch.stack(
                [self._create_mask(sent) for sent in doc_id])

            if self.cuda:
                doc_id = doc_id.cuda()
                doc_len = doc_len.cuda()
                doc_mask = doc_mask.cuda()
                sent_mask = sent_mask.cuda()
                sent_len = sent_len.cuda()
                label = label.cuda()

            with torch.no_grad():  # evaluation needs no gradients
                preds = self.model(doc_id, sent_mask, doc_mask, sent_len)
                loss = F.cross_entropy(preds.view(-1, preds.size(-1)),
                                       label.reshape(-1),
                                       ignore_index=config.ignore_index_ext,
                                       weight=self.loss_weight)
            #loss = focal_loss(preds.view(-1, preds.size(-1)), label.reshape(-1), ignore_index=config.ignore_index_ext, alpha = config.alpha, gamma = config.gamma)

            total_loss += loss.item()

            pred_label = [torch.argmax(p, 1).tolist() for p in preds]
            labels = label.tolist()

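            # right = true positives, origin = gold positives, found = predicted positives;
            # stop counting a document at the first padded (ignore_index) label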
            for p_tag, label in zip(pred_label, labels):
                for p, l in zip(p_tag, label):
                    if l == config.ignore_index_ext:
                        break
                    elif p == 1 and l == 1:
                        right += 1
                        origin += 1
                        found += 1
                    elif p == 0 and l == 1:
                        origin += 1
                    elif p == 1 and l == 0:
                        found += 1
                    else:
                        pass

        recall = (right / (origin + 1e-5))
        precision = (right / (found + 1e-5))
        f1 = (2 * precision * recall) / (precision + recall + 1e-5)
        return (round(total_loss / (i + 1), 4), round(recall, 4),
                round(precision, 4), round(f1, 4))

    def _create_mask(self, tok_ids):
        mask = (tok_ids != self.pad).unsqueeze(1)
        return mask
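A brief sketch of the intended usage, assuming a config module that provides the fields referenced above (tok_path, data_path, batch_size, lr, alpha, n_epoch, and the rest):

# Hypothetical usage; the config fields and the checkpoint path are assumptions.
op = TrainOperator()
op.setup_train()  # fresh HierSumTransformer with Xavier-initialized weights
# op.setup_train('./resource/RNN_TR_HiSum_v2.0.pkl')  # or warm-start from a saved state dict
op.train()  # trains for config.n_epoch epochs, checkpointing whenever dev loss improves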
Example #3
class GPT2LM(BaseModel):
    def __init__(self, param):
        args = param.args
        net = Network(param)
        self.optimizer = RAdam(net.get_parameters_by_name(), lr=args.lr)
        optimizerList = {"optimizer": self.optimizer}
        checkpoint_manager = CheckpointManager(args.name, args.model_dir, \
            args.checkpoint_steps, args.checkpoint_max_to_keep, "min")
        super().__init__(param, net, optimizerList, checkpoint_manager)

        self.create_summary()

    def create_summary(self):
        args = self.param.args
        self.summaryHelper = SummaryHelper("%s/%s_%s" % \
          (args.log_dir, args.name, time.strftime("%H%M%S", time.localtime())), \
          args)
        self.trainSummary = self.summaryHelper.addGroup(\
         scalar=["loss", "word_loss", "perplexity"],\
         prefix="train")

        scalarlist = ["word_loss", "perplexity_avg_on_batch"]
        tensorlist = []
        textlist = []
        emblist = []
        for i in self.args.show_sample:
            textlist.append("show_str%d" % i)
        self.devSummary = self.summaryHelper.addGroup(\
         scalar=scalarlist,\
         tensor=tensorlist,\
         text=textlist,\
         embedding=emblist,\
         prefix="dev")
        self.testSummary = self.summaryHelper.addGroup(\
         scalar=scalarlist,\
         tensor=tensorlist,\
         text=textlist,\
         embedding=emblist,\
         prefix="test")

    def _preprocess_batch(self, data):
        incoming = Storage()
        incoming.data = data = Storage(data)
        data.batch_size = data.sent.shape[0]
        data.sent = cuda(torch.LongTensor(data.sent))  # batch_size * length
        data.sent_attnmask = zeros(*data.sent.shape)
        for i, length in enumerate(data.sent_length):
            data.sent_attnmask[i, :length] = 1
        return incoming

    def get_next_batch(self, dm, key, restart=True):
        data = dm.get_next_batch(key)
        if data is None:
            if restart:
                dm.restart(key)
                return self.get_next_batch(dm, key, False)
            else:
                return None
        return self._preprocess_batch(data)

    def get_batches(self, dm, key):
        batches = list(
            dm.get_batches(key, batch_size=self.args.batch_size,
                           shuffle=False))
        return len(batches), (self._preprocess_batch(data) for data in batches)

    def get_select_batch(self, dm, key, i):
        data = dm.get_batch(key, i)
        if data is None:
            return None
        return self._preprocess_batch(data)

    def train(self, batch_num):
        args = self.param.args
        dm = self.param.volatile.dm
        datakey = 'train'

        for i in range(batch_num):
            self.now_batch += 1
            incoming = self.get_next_batch(dm, datakey)
            incoming.args = Storage()

            if i % args.batch_num_per_gradient == 0:
                self.zero_grad()  # zero at the start of each accumulation window so gradients actually accumulate
            self.net.forward(incoming)

            loss = incoming.result.loss
            self.trainSummary(self.now_batch, storage_to_list(incoming.result))
            logging.info("batch %d : gen loss=%f", self.now_batch,
                         loss.detach().cpu().numpy())

            loss.backward()

            if (i + 1) % args.batch_num_per_gradient == 0:
                nn.utils.clip_grad_norm_(self.net.parameters(), args.grad_clip)
                self.optimizer.step()

    def evaluate(self, key):
        args = self.param.args
        dm = self.param.volatile.dm

        dm.restart(key, args.batch_size, shuffle=False)

        result_arr = []
        while True:
            incoming = self.get_next_batch(dm, key, restart=False)
            if incoming is None:
                break
            incoming.args = Storage()

            with torch.no_grad():
                self.net.forward(incoming)
            result_arr.append(incoming.result)

        detail_arr = Storage()
        for i in args.show_sample:
            index = [i * args.batch_size + j for j in range(args.batch_size)]
            incoming = self.get_select_batch(dm, key, index)
            incoming.args = Storage()
            with torch.no_grad():
                self.net.detail_forward(incoming)
            detail_arr["show_str%d" % i] = incoming.result.show_str

        detail_arr.update(
            {key: get_mean(result_arr, key)
             for key in result_arr[0]})
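        # report perplexity as exp of the averaged per-word loss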
        detail_arr.perplexity_avg_on_batch = np.exp(detail_arr.word_loss)
        return detail_arr

    def train_process(self):
        args = self.param.args
        dm = self.param.volatile.dm

        while self.now_epoch < args.epochs:
            self.now_epoch += 1
            self.updateOtherWeights()

            dm.restart('train', args.batch_size)
            self.net.train()
            self.train(args.batch_per_epoch)

            self.net.eval()
            devloss_detail = self.evaluate("dev")
            self.devSummary(self.now_batch, devloss_detail)
            logging.info("epoch %d, evaluate dev", self.now_epoch)

            testloss_detail = self.evaluate("test")
            self.testSummary(self.now_batch, testloss_detail)
            logging.info("epoch %d, evaluate test", self.now_epoch)

            self.save_checkpoint(value=devloss_detail["word_loss"].tolist())

    def test(self, key):
        args = self.param.args
        dm = self.param.volatile.dm

        metric1 = dm.get_teacher_forcing_metric()
        batch_num, batches = self.get_batches(dm, key)
        logging.info("eval teacher-forcing")
        for incoming in tqdm.tqdm(batches, total=batch_num):
            incoming.args = Storage()
            with torch.no_grad():
                self.net.forward(incoming)
                gen_log_prob = nn.functional.log_softmax(incoming.gen.w, -1)
            data = incoming.data
            data.sent_allvocabs = LongTensor(incoming.data.sent_allvocabs)
            data.sent_length = incoming.data.sent_length
            data.gen_log_prob = gen_log_prob
            metric1.forward(data)
        res = metric1.close()

        metric2 = dm.get_inference_metric()
        batch_num, batches = self.get_batches(dm, key)
        logging.info("eval free-run")
        for incoming in tqdm.tqdm(batches, total=batch_num):
            incoming.args = Storage()
            with torch.no_grad():
                self.net.detail_forward(incoming)
            data = incoming.data
            data.gen = incoming.gen.w_o.detach().cpu().numpy()
            metric2.forward(data)
        res.update(metric2.close())

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
        filename = args.out_dir + "/%s_%s.txt" % (args.name, key)

        with open(filename, 'w') as f:
            logging.info("%s Test Result:", key)
            for key, value in res.items():
                if isinstance(value, float) or isinstance(value, str):
                    logging.info("\t{}:\t{}".format(key, value))
                    f.write("{}:\t{}\n".format(key, value))
            for i in range(len(res['gen'])):
                f.write("gen:\t%s\n" % " ".join(res['gen'][i]))
            f.flush()
        logging.info("result output to %s.", filename)
        return {
            key: val
            for key, val in res.items() if isinstance(val, (str, int, float))
        }

    def test_process(self):
        logging.info("Test Start.")
        self.net.eval()
        self.test("train")
        self.test("dev")
        test_res = self.test("test")
        logging.info("Test Finish.")
        return test_res
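The class leans on its repository's scaffolding (BaseModel, Network, Storage, SummaryHelper, and a param object carrying args and volatile.dm); a hedged sketch of the intended call order, where the mode flag is an assumption:

# Hypothetical driver; `param` construction and the `mode` flag are assumptions.
model = GPT2LM(param)  # builds the network, RAdam optimizer, checkpoints, and summaries
if param.args.mode == 'train':
    model.train_process()  # per epoch: train batches, then evaluate on dev and test
else:
    model.test_process()   # teacher-forcing and free-run metrics on train/dev/test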
Example #4
class RRR(object):

    def __init__(self,model,args,dataset,network,clipgrad=10000):
        self.args=args
        self.dataset = dataset
        self.model=model
        self.nepochs=args.train.nepochs
        self.sbatch=args.train.batch_size

        self.lrs = [args.train.lr for _ in range(args.experiment.ntasks)]
        if self.args.experiment.fscil:
            self.lrs[1:] = [self.args.experiment.lr_multiplier * lr for lr in self.lrs[1:]]


        # self.lrs = [item[1] for item in args.lrs]
        self.lrs_exp = [args.saliency.lr for _ in range(args.experiment.ntasks)]
        self.lr_min=[lr/1000. for lr in self.lrs]
        self.lr_factor=args.train.lr_factor
        self.lr_patience=args.train.lr_patience
        self.clipgrad=clipgrad
        self.checkpoint=args.path.checkpoint
        self.device=args.device.name

        # self.args.train.schedule = [20, 30,40]
        self.args.train.schedule = [20, 40, 60]

        self.criterion=torch.nn.CrossEntropyLoss().to(device=self.args.device.name)

        if self.args.saliency.loss == "l1":
            self.sal_loss = torch.nn.L1Loss().to(device=self.args.device.name)
        elif self.args.saliency.loss == "l2":
            self.sal_loss = torch.nn.MSELoss().to(device=self.args.device.name)
        else:
            raise NotImplementedError

        if self.args.train.l1_reg:
            self.l1_reg = torch.nn.L1Loss(reduction='sum')

        self.get_optimizer(task_id=0)
        self.get_optimizer_explanations(task_id=0)

        self.network=network
        self.inputsize=args.inputsize
        self.taskcla=args.taskcla

        self.memory_loaders = {}
        self.test_loader = {}


        self.memory_paths = []
        self.saliency_loaders = None

        if self.args.experiment.raw_memory_only or self.args.experiment.xai_memory:
            self.use_memory = True
        else:
            self.use_memory = False

        # XAI
        if self.args.saliency.method == 'gc':
            print ("Using GradCAM to obtain saliency maps")
            from approaches.explanations import GradCAM as Explain
        elif self.args.saliency.method == 'smooth':
            print ("Using SmoothGrad to obtain saliency maps")
            from approaches.explanations import SmoothGrad as Explain
        elif self.args.saliency.method == 'bp':
            print ("Using BackPropagation to obtain saliency maps")
            from approaches.explanations import BackPropagation as Explain
        elif self.args.saliency.method == 'gbp':
            print ("Using Guided BackPropagation to obtain saliency maps")
            from approaches.explanations import GuidedBackPropagation as Explain
        elif self.args.saliency.method == 'deconv':
            print ("Using Deconvnet to obtain saliency maps")
            from approaches.explanations import Deconvnet as Explain
        else:
            # without this, Explain would be unbound for unknown methods
            raise NotImplementedError(
                "Unknown saliency method: {}".format(self.args.saliency.method))

        self.explainer = Explain(self.args)


    def get_optimizer(self,task_id, lr=None):
        if lr is None: lr=self.lrs[task_id]

        if (self.args.train.optimizer=="radam"):
            self.optimizer = RAdam(self.model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=0)
        elif(self.args.train.optimizer=="adam"):
            self.optimizer= torch.optim.Adam(self.model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08,
                               weight_decay=0.0, amsgrad=False)
        elif(self.args.train.optimizer=="sgd"):
            self.optimizer= torch.optim.SGD(self.model.parameters(), lr=lr, momentum=0.9, weight_decay=0.001)
            self.scheduler_opt = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                                        patience=self.lr_patience,
                                                                        factor=self.lr_factor / 10,
                                                                        min_lr=self.lr_min[task_id], verbose=True)


    def adjust_learning_rate(self, epoch):
        if epoch in self.args.train.schedule:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] *= self.args.train.gamma
            print("Reducing learning rate to ", param_group['lr'])


    def get_optimizer_explanations(self, task_id, lr=None):
        if lr is None: lr=self.lrs_exp[task_id]

        if self.args.train.optimizer=="sgd":
            self.optimizer_explanations = torch.optim.SGD(self.model.parameters(), lr=lr, weight_decay=self.args.train.wd)
            self.scheduler_exp_opt = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer_explanations,
                                                                         patience=self.lr_patience,
                                                                          factor=self.lr_factor/10,
                                                                          min_lr=self.lr_min[task_id], verbose=True)
        elif(self.args.train.optimizer=="adam"):
            self.optimizer_explanations= torch.optim.Adam(self.model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08,
                               weight_decay=0.0, amsgrad=False)

        elif (self.args.train.optimizer=="radam"):
            self.optimizer_explanations = RAdam(self.model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=0)


    def train(self,task_id,performance):
        # if task_id > 0 and not self.args.architecture.multi_head:
        #     self.model.increment_classes(task_id)
        #     self.model  = self.model.to(self.device)

        print('*'*100)
        data_loader = self.dataset.get(task_id=task_id)
        self.test_loader[task_id] = data_loader['test']

        print(' '*85, 'Run #{:2d} - Dataset {:2d} ({:s})'.format(self.args.seed+1, task_id+1,data_loader['name']))

        print('*'*100)

        best_loss=np.inf
        if self.args.train.nepochs > 1:
            best_model=utils.get_model(self.model)
        lr = self.lrs[task_id]
        lr_exp = self.lrs_exp[task_id]
        patience=self.lr_patience
        best_acc = 0

        self.get_optimizer(task_id=task_id)
        self.get_optimizer_explanations(task_id=task_id)

        # Loop epochs

        for e in range(self.nepochs):
            self.epoch = e

            # Train
            clock0=time.time()
            self.train_epoch(data_loader['train'], task_id)
            clock1=time.time()

            # Valid
            if self.args.train.pc_valid > 0:
                valid_res = self.eval(data_loader['valid'], task_id, set_name='valid')
                utils.report_valid(valid_res)
            else:
                valid_res = self.eval(data_loader['test'], task_id, set_name='test')
                print ("Epoch {}/{} | Test acc {}: {:.2f}".format(e+1, self.nepochs, task_id, valid_res['acc']))

            if self.args.wandb.log:
                wandb.log({"Best Acc on Task {}".format(task_id): valid_res['acc']})

            if (self.args.train.optimizer == "sgd"):  # match the attribute path used in get_optimizer
                self.scheduler_opt.step(valid_res['loss'])
                self.scheduler_exp_opt.step(valid_res['loss'])

            is_best = valid_res['acc'] > best_acc
            if is_best:
                best_model = utils.get_model(self.model)
            best_acc = max(valid_res['acc'], best_acc)

        # Restore best validation model
        if self.args.train.nepochs > 1:
            self.model.load_state_dict(deepcopy(best_model))

        self.save_model(task_id, deepcopy(self.model.state_dict()))

        if task_id == 1 and self.args.experiment.xai_memory:
            self.compute_memory()

        if self.use_memory and task_id < self.args.experiment.ntasks:
            self.update_memory(task_id)


    def train_epoch(self,train_loader,task_id):
        self.adjust_learning_rate(self.epoch)

        self.model.train()

        if task_id > 0 and self.args.experiment.xai_memory:
            for idx, (data, target, sal, tt, _) in enumerate(self.saliency_loaders):

                x = data.to(device=self.device, dtype=torch.float)
                s = sal.to(device=self.device, dtype=torch.float)

                explanations, self.model, _, _ = self.explainer(x, self.model, task_id)

                self.saliency_size = explanations.size()

                # To make predicted explanations (Bx7x7) same as ground truth ones (Bx1x7x7)
                sal_loss = self.sal_loss(explanations.view_as(s), s)
                sal_loss *= self.args.saliency.regularizer

                if self.args.wandb.log:
                    wandb.log({"Saliency loss": sal_loss.item()})

                try:
                    sal_loss.requires_grad = True
                except RuntimeError:
                    # requires_grad can only be set on leaf tensors; skip this batch otherwise
                    continue

                self.optimizer_explanations.zero_grad()
                sal_loss.backward(retain_graph=True)
                self.optimizer_explanations.step()

        # Loop batches
        for batch_idx, (x, y, tt) in enumerate(train_loader):

            images = x.to(device=self.device, dtype=torch.float)
            targets = y.to(device=self.device, dtype=torch.long)
            tt = tt.to(device=self.device, dtype=torch.long)
            # Forward
            if self.args.architecture.multi_head:
                output=self.model.forward(images, tt)
            else:
                output = self.model.forward(images)

            loss=self.criterion(output,targets)

            # L1 regularize
            if self.args.train.l1_reg:
                reg_loss = self.l1_regularizer()
                factor = self.args.train.l1_reg_factor
                loss += factor * reg_loss

            loss *= self.args.train.task_loss_reg


            # Backward
            self.optimizer.zero_grad()
            loss.backward(retain_graph=True)

            # Apply step
            # torch.nn.utils.clip_grad_norm_(self.model.parameters(),self.clipgrad)
            self.optimizer.step()


    def eval(self,data_loader, task_id, set_name='valid'):
        total_loss=0
        total_acc=0
        total_num=0
        self.model.eval()
        res={}
        old_tasks_loss, sal_loss = 0, 0

        # Loop batches
        with torch.no_grad():

            for batch_idx, (x, y, tt) in enumerate(data_loader):
                # Fetch x and y labels
                images=x.to(device=self.device, dtype=torch.float)
                targets=y.to(device=self.device, dtype=torch.long)
                tt=tt.to(device=self.device, dtype=torch.long)

                # Forward
                if self.args.architecture.multi_head:
                    output = self.model.forward(images, tt)
                else:
                    output = self.model.forward(images)

                loss = self.criterion(output,targets)
                _, pred=output.max(1)
                hits = (pred==targets).float()

                # Log
                total_loss += loss
                total_acc += hits.sum().item()
                total_num += targets.size(0)

        res['loss'], res['acc'] = total_loss/(batch_idx+1), 100*total_acc/total_num
        res['size'] = self.loader_size(data_loader)

        return res


    def test(self, model, test_id, model_id=None):
        total_loss=0
        total_acc=0
        total_num=0
        model.eval()
        res={}

        # Loop batches
        with torch.no_grad():

            for batch_idx, (x, y, tt) in enumerate(self.test_loader[test_id]):
                # Fetch x and y labels
                images=x.to(device=self.device, dtype=torch.float32)
                targets=y.to(device=self.device, dtype=torch.long)

                # Forward
            if self.args.architecture.multi_head:
                output = model.forward(images, test_id)  # use the model argument, not self.model
            else:
                output = model.forward(images)

                _, pred = output.max(1)

                loss=self.criterion(output,targets)
                hits=(pred==targets).float()

                # Log
                total_loss+=loss
                total_acc+=hits.sum().item()
                total_num+=targets.size(0)

        res['loss'], res['acc'] = total_loss/(batch_idx+1), 100*total_acc/total_num


        return res['loss'], res['acc']


    def update_memory(self, task_id):
        start = time.time()
        # Get memory set for each task seen so far with the updated samples per class and return new spc
        self.dataset.get_memory_sets(task_id)
        midway = time.time()
        print('[Storing memory time = {:.1f} min ]'.format((midway - start) / (60)))


        if self.args.experiment.xai_memory:
            self.update_saliencies(task_id)
            self.saliency_loaders = self.dataset.generate_evidence_loaders(task_id)
            # Be careful if you comment out the sanity check below. It extremely slows down the training
            # self.check_saliency_loaders(task_id)
        print('[Storing saliency time = {:.1f} min ]'.format((time.time() - midway) / (60)))

    def update_saliencies(self, task_id):
        save_images = False
        ims = [] if save_images else None
        images, preds = [], []

        # Generate saliency for images from the last seen task
        memory_path = os.path.join(self.args.path.checkpoint, 'memory', 'mem_{}.pth'.format(task_id))
        memory_set = torch.load(memory_path)
        num_samples = len(memory_set)
        single_loader = torch.utils.data.DataLoader(memory_set, batch_size=1,
                                                    num_workers=self.args.device.workers,
                                                    shuffle=False)

        saliencies, predictions = [], []
        for idx, (img, y, tt) in enumerate(single_loader):

            img = img.to(self.args.device.name)
            sal, self.model, _, _ = self.explainer(img, self.model, task_id)
            if self.args.architecture.multi_head:
                output = self.model.forward(img, task_id)
            else:
                output = self.model.forward(img)

            _, pred = output.max(1)

            saliencies.append(sal)
            predictions.append(pred)

        sal_path = os.path.join(self.args.path.checkpoint, 'memory', 'sal_{}.pth'.format(task_id))
        pred_path = os.path.join(self.args.path.checkpoint, 'memory', 'pred_{}.pth'.format(task_id))
        torch.save(saliencies, sal_path)
        torch.save(predictions, pred_path)

        if not self.args.experiment.fscil:
            # Reduce previous saliencies
            for t in range(task_id):
                # Read the stored saliency file
                sal_path = os.path.join(self.args.path.checkpoint, 'memory', 'sal_{}.pth'.format(t))
                saliencies = torch.load(sal_path)
                before = len(saliencies)

                pred_path = os.path.join(self.args.path.checkpoint, 'memory', 'pred_{}.pth'.format(t))
                predictions = torch.load(pred_path)

                # Extract the required number of samples and save them again
                saliencies = saliencies[:num_samples]
                after = len(saliencies)
                torch.save(saliencies, sal_path)
                print ("Reduced saliencies for task {} from {} to {}".format(t, before, after))

                predictions = predictions[:num_samples]
                torch.save(predictions, pred_path)


    def l1_regularizer(self):
        reg_loss = 0
        for param in self.model.parameters():
            target = torch.zeros_like(param)
            reg_loss += self.l1_reg(param, target)
        return reg_loss


    def save_model(self,t,best_model):
        torch.save({'model_state_dict': best_model,
                    }, os.path.join(self.checkpoint, 'model_run_id_{}_task_id_{}.pth.tar'.format(self.args.seed,t)))


    def loader_size(self, data_loader):
        return len(data_loader.dataset)


    def load_model(self, task_id):
        net=self.network.Net(self.args).to(device=self.args.device.name)
        # net = self.network._CustomDataParallel(net, self.args.device.name_ids)
        if self.args.device.multi:
            net = torch.nn.DataParallel(net)
        checkpoint=torch.load(os.path.join(self.checkpoint, 'model_run_id_{}_task_id_{}.pth.tar'.format(self.args.seed,task_id)))
        net.load_state_dict(checkpoint['model_state_dict'])
        net = net.to(device=self.args.device.name)
        # net = self.network._CustomDataParallel(net, self.args.device.name_ids)
        return net

    def load_singlehead_model(self,current_model_id):
        return self.model

    def load_multihead_model(self, test_id, current_model_id):
        # Load a previous model
        old_model=self.network.Net(self.args)
        if self.args.device.multi:
            old_model = torch.nn.DataParallel(old_model)
        checkpoint=torch.load(os.path.join(self.checkpoint, 'model_run_id_{}_task_id_{}.pth.tar'.format(self.args.seed,test_id)))
        old_model.load_state_dict(checkpoint['model_state_dict'])

        # Load a current model
        current_model=self.network.Net(self.args)
        if self.args.device.multi:
            current_model = torch.nn.DataParallel(current_model)
        checkpoint=torch.load(os.path.join(self.checkpoint, 'model_run_id_{}_task_id_{}.pth.tar'.format(self.args.seed,current_model_id)))
        current_model.load_state_dict(checkpoint['model_state_dict'])

        # Change the current_model head with the old head
        if self.args.device.multi:
            old_head=deepcopy(old_model.module.head.state_dict())
            current_model.module.head.load_state_dict(old_head)
        else:
            old_head = deepcopy(old_model.head.state_dict())
            current_model.head.load_state_dict(old_head)

        current_model=current_model.to(self.args.device.name)
        return current_model


    def load_checkpoint(self, task_id):
        print("Loading checkpoint for task {} ...".format(task_id))

        # Load a previous model
        net=self.network.Net(self.args)
        net = net.to(device=self.args.device.name)

        checkpoint=torch.load(os.path.join(self.checkpoint, 'model_run_id_{}_task_id_{}.pth.tar'.format(self.args.seed,task_id)))
        net.load_state_dict(checkpoint['model_state_dict'])
        net=net.to(device=self.args.device.name)

        return net


    def check_saliency_loaders(self, task_id):

        path = os.path.join(self.args.path.checkpoint, 'memory')
        for idx, (images, targets, saliencies, tt, preds) in enumerate(self.saliency_loaders):
            # loop over batch
            for i in range(len(images)):

                fig = plt.figure(dpi=100)
                # fig.subplots_adjust(hspace=0.01, wspace=0.01)

                img = images[i].unsqueeze(0)  # [1, 3, 224, 224]
                saliency = saliencies[i].unsqueeze(0)  # shape: [1, 1, 7, 7]
                pred = preds[i]  # shape: [1, 7, 7]
                target = targets[i] #
                task = tt[i]
                # saliency = saliency.unsqueeze(0)  # shape:
                image_size = img.shape[2:] # [32x32]
                saliency = F.interpolate(saliency, size=image_size, mode="bilinear", align_corners=False)

                B, C, H, W = saliency.shape
                saliency = saliency.view(B, -1)
                saliency_max = saliency.max(dim=1, keepdim=True)[0]
                saliency_max[torch.where(saliency_max == 0)] = 1.  # prevent divide by 0
                saliency -= saliency.min(dim=1, keepdim=True)[0]
                saliency /= saliency_max
                saliency = saliency.view(B, C, H, W)
                saliency = saliency.squeeze(0).squeeze(0)
                saliency = saliency.detach().cpu().numpy()

                img = img.squeeze(0)
                img = img.cpu().numpy().transpose((1, 2, 0))  # (224, 224, 3)
                mean = np.array(self.dataset.mean)
                std = np.array(self.dataset.std)
                img = std * img + mean
                img = np.clip(img, 0, 1)

                # plt.subplot(1, 10, idx + 1)
                plt.axis('off')
                result = 'Correct' if pred.item() == target else "Wrong"
                plt.title('{} | Pred:{}, Truth:{}, Task:{}'.format(result, pred.item(), target, task), fontsize=9)

                plt.imshow(img)
                plt.savefig(os.path.join(path, 'Img-batch-{}-ID-{}-task-{}.png'.format(idx, i, task_id)))
                plt.imshow(saliency, cmap='jet', alpha=0.5)
                plt.colorbar(fraction=0.046, pad=0.04)

                fig.tight_layout()
                plt.savefig(os.path.join(path, 'Sal-batch-{}-ID-{}-task-{}.png'.format(idx, i, task_id)))
                plt.close()


    def visualize(self, images, saliencies, task_id, preds, y, path):

        # fig = plt.figure(figsize=(10, 2), dpi=500)
        # fig.subplots_adjust(hspace=0.01, wspace=0.01)

        for idx in range(len(saliencies)):
            saliency, img = saliencies[idx], images[idx]

            # img = img.unsqueeze(0)  # [1, 3, 224, 224]
            saliency = saliency.unsqueeze(0)  # shape: [1, 7, 7]
            # saliency = saliency.unsqueeze(0)  # shape: [1, 1, 7, 7]

            image_size = img.shape[2:]
            # print (saliency.size(), img.size(), image_size)
            saliency = F.interpolate(saliency, size=image_size, mode="bilinear", align_corners=False)

            B, C, H, W = saliency.shape
            saliency = saliency.view(B, -1)
            saliency_max = saliency.max(dim=1, keepdim=True)[0]
            saliency_max[torch.where(saliency_max == 0)] = 1.  # prevent divide by 0
            saliency -= saliency.min(dim=1, keepdim=True)[0]
            saliency /= saliency_max
            saliency = saliency.view(B, C, H, W)
            saliency = saliency.squeeze(0).squeeze(0)
            saliency = saliency.detach().cpu().numpy()

            img = img.squeeze(0)
            img = img.cpu().numpy().transpose((1, 2, 0))  # (224, 224, 3)
            mean = np.array(self.dataset.mean)
            std = np.array(self.dataset.std)
            img = std * img + mean
            img = np.clip(img, 0, 1)

            fig = plt.figure(dpi=500)
            # plt.subplot(1,1, 1)
            plt.axis('off')

            result = 'Correct' if preds[idx].item() == y[idx].item() else "Wrong"
            plt.title('{}-Pred:{},Truth:{}'.format(result, preds[idx].item(), y[idx].item()), fontsize=7)
            plt.imshow(img)
            plt.imshow(saliency, cmap='jet', alpha=0.5)
            plt.colorbar(fraction=0.046, pad=0.04)
            fig.tight_layout()

            print('saving to: {}'.format(os.path.join(path, 'sal-task-{}-mem-{}.png'.format(task_id, idx))))
            plt.savefig(os.path.join(path, 'sal-task-{}-mem-{}.png'.format(task_id, idx)))
            plt.close()

    def compute_memory(self):
        ssize = 1
        print ("***"*200)
        print ("saliency_size", self.saliency_size)
        for s in self.saliency_size:
            ssize *= s
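        # 4 bytes per stored float32 element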
        saliency_memory = 4 * self.args.experiment.memory_budget * ssize

        ncha, size, _ = self.args.inputsize
        image_size = ncha * size * size

        samples_memory = 4 * self.args.experiment.memory_budget * image_size
        count = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

        print('Num parameters in the entire model    = %s ' % (utils.human_format(count)))
        architecture_memory = 4 * count

        print("-------------------------->  Saliency memory size: (%sB)" % utils.human_format(saliency_memory))
        print("-------------------------->  Episodic memory size: (%sB)" % utils.human_format(samples_memory))
        print("------------------------------------------------------------------------------")
        print("                             TOTAL:  %sB" % utils.human_format(
            architecture_memory+samples_memory+saliency_memory))
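As a quick sanity check on the byte accounting above (4 bytes per float32 element), here is a worked example with illustrative numbers, assuming a memory budget of 2000 samples, 1x7x7 saliency maps, and 3x32x32 inputs:

# Worked example of compute_memory's arithmetic; all sizes here are illustrative assumptions.
budget = 2000
saliency_bytes = 4 * budget * (1 * 7 * 7)    # 392,000 B ~= 392 KB of saliency maps
sample_bytes = 4 * budget * (3 * 32 * 32)    # 24,576,000 B ~= 24.6 MB of raw images
print(saliency_bytes, sample_bytes)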