def _get_optim(self):
    # self.optimizer = optim.SGD(self.parameters(), lr=self.config['lr'], momentum=self.config['momentum'],
    #                            weight_decay=self.config['weight_decay'])
    self.optimizer = RAdam(self.parameters(),
                           lr=self.config['lr'],
                           weight_decay=self.config['weight_decay'])
    # Parse the comma-separated step list, e.g. '30,60,90' -> [30, 60, 90]
    milestones = [int(i) for i in self.config['lr_steps'].split(',')]
    self.scheduler = optim.lr_scheduler.MultiStepLR(
        self.optimizer, milestones=milestones, gamma=self.config['lr_gamma'])
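# A minimal, self-contained sketch of the RAdam + MultiStepLR pattern used in
# _get_optim above. The toy model, the lr/weight-decay values, and the
# '30,60,90' milestone string are assumptions for illustration; `from radam
# import RAdam` assumes the common standalone RAdam package (PyTorch >= 1.10
# also ships torch.optim.RAdam with a compatible interface).
import torch
from torch import nn, optim
from radam import RAdam

model = nn.Linear(10, 2)
optimizer = RAdam(model.parameters(), lr=1e-3, weight_decay=1e-4)
milestones = [int(i) for i in '30,60,90'.split(',')]  # mirrors config['lr_steps']
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

for epoch in range(100):
    # ... forward, loss.backward(), optimizer.step() for one epoch ...
    scheduler.step()  # lr: 1e-3 -> 1e-4 after epoch 30 -> 1e-5 after 60 -> 1e-6 after 90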
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
    args.warmup_steps = t_total // 100

    # Prepare optimizer and schedule (linear warmup and decay)
    optimizer_grouped_parameters = get_param_groups(args, model)
    optimizer = RAdam(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelModel(model)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps *
                (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    args.logging_steps = len(train_dataloader) // 1
    args.save_steps = args.logging_steps

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
    set_seed(args)
    for _ in train_iterator:
        args.current_epoch = _
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,
                # XLM and RoBERTa don't use segment_ids
                # 'labels': batch[3]
            }
            outputs = model(**inputs)
            outputs = [outputs[i][0] for i in range(len(outputs))]
            loss_fct = CrossEntropyLoss()
            loss_fct = DataParallelCriterion(loss_fct)
            loss = loss_fct(outputs, batch[3])

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:
                        # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss',
                                         (tr_loss - logging_loss) / args.logging_steps,
                                         global_step)
                    logging_loss = tr_loss

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir,
                                              'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(model, 'module') else model
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
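# WarmupLinearSchedule comes from the old pytorch-transformers package. If that
# dependency is unavailable, the same shape -- linear warmup to the base lr over
# warmup_steps, then linear decay to zero at t_total -- can be sketched with a
# plain LambdaLR. The toy model and step counts below are assumptions for
# illustration.
import torch
from torch import nn
from radam import RAdam  # assumes the standalone RAdam package

warmup_steps, t_total = 100, 10000
model = nn.Linear(10, 2)
optimizer = RAdam(model.parameters(), lr=5e-5, eps=1e-8)

def lr_lambda(step):
    if step < warmup_steps:
        return step / max(1, warmup_steps)  # warmup: factor 0 -> 1
    return max(0.0, (t_total - step) / max(1, t_total - warmup_steps))  # decay: 1 -> 0

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# call scheduler.step() once per optimizer.step(), as in the training loop above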
class GPT2LM(BaseModel):
    def __init__(self, param):
        args = param.args
        net = Network(param)
        self.optimizer = RAdam(net.get_parameters_by_name(), lr=args.lr)
        optimizerList = {"optimizer": self.optimizer}
        checkpoint_manager = CheckpointManager(args.name, args.model_dir,
                                               args.checkpoint_steps,
                                               args.checkpoint_max_to_keep, "min")
        super().__init__(param, net, optimizerList, checkpoint_manager)
        self.create_summary()

    def create_summary(self):
        args = self.param.args
        self.summaryHelper = SummaryHelper("%s/%s_%s" %
                                           (args.log_dir, args.name,
                                            time.strftime("%H%M%S", time.localtime())),
                                           args)
        self.trainSummary = self.summaryHelper.addGroup(
            scalar=["loss", "word_loss", "perplexity"],
            prefix="train")

        scalarlist = ["word_loss", "perplexity_avg_on_batch"]
        tensorlist = []
        textlist = []
        emblist = []
        for i in self.args.show_sample:
            textlist.append("show_str%d" % i)
        self.devSummary = self.summaryHelper.addGroup(
            scalar=scalarlist,
            tensor=tensorlist,
            text=textlist,
            embedding=emblist,
            prefix="dev")
        self.testSummary = self.summaryHelper.addGroup(
            scalar=scalarlist,
            tensor=tensorlist,
            text=textlist,
            embedding=emblist,
            prefix="test")

    def _preprocess_batch(self, data):
        incoming = Storage()
        incoming.data = data = Storage(data)
        data.batch_size = data.sent.shape[0]
        data.sent = cuda(torch.LongTensor(data.sent))  # length * batch_size
        data.sent_attnmask = zeros(*data.sent.shape)
        for i, length in enumerate(data.sent_length):
            data.sent_attnmask[i, :length] = 1
        return incoming

    def get_next_batch(self, dm, key, restart=True):
        data = dm.get_next_batch(key)
        if data is None:
            if restart:
                dm.restart(key)
                return self.get_next_batch(dm, key, False)
            else:
                return None
        return self._preprocess_batch(data)

    def get_batches(self, dm, key):
        batches = list(dm.get_batches(key, batch_size=self.args.batch_size, shuffle=False))
        return len(batches), (self._preprocess_batch(data) for data in batches)

    def get_select_batch(self, dm, key, i):
        data = dm.get_batch(key, i)
        if data is None:
            return None
        return self._preprocess_batch(data)

    def train(self, batch_num):
        args = self.param.args
        dm = self.param.volatile.dm
        datakey = 'train'

        for i in range(batch_num):
            self.now_batch += 1
            incoming = self.get_next_batch(dm, datakey)
            incoming.args = Storage()

            if (i + 1) % args.batch_num_per_gradient == 0:
                self.zero_grad()
            self.net.forward(incoming)

            loss = incoming.result.loss
            self.trainSummary(self.now_batch, storage_to_list(incoming.result))
            logging.info("batch %d : gen loss=%f", self.now_batch,
                         loss.detach().cpu().numpy())

            loss.backward()

            if (i + 1) % args.batch_num_per_gradient == 0:
                nn.utils.clip_grad_norm_(self.net.parameters(), args.grad_clip)
                self.optimizer.step()

    def evaluate(self, key):
        args = self.param.args
        dm = self.param.volatile.dm

        dm.restart(key, args.batch_size, shuffle=False)
        result_arr = []
        while True:
            incoming = self.get_next_batch(dm, key, restart=False)
            if incoming is None:
                break
            incoming.args = Storage()
            with torch.no_grad():
                self.net.forward(incoming)
            result_arr.append(incoming.result)

        detail_arr = Storage()
        for i in args.show_sample:
            index = [i * args.batch_size + j for j in range(args.batch_size)]
            incoming = self.get_select_batch(dm, key, index)
            incoming.args = Storage()
            with torch.no_grad():
                self.net.detail_forward(incoming)
            detail_arr["show_str%d" % i] = incoming.result.show_str

        detail_arr.update({key: get_mean(result_arr, key) for key in result_arr[0]})
        detail_arr.perplexity_avg_on_batch = np.exp(detail_arr.word_loss)
        return detail_arr

    def train_process(self):
        args = self.param.args
        dm = self.param.volatile.dm

        while self.now_epoch < args.epochs:
            self.now_epoch += 1
            self.updateOtherWeights()

            dm.restart('train', args.batch_size)
            self.net.train()
            self.train(args.batch_per_epoch)

            self.net.eval()
            devloss_detail = self.evaluate("dev")
            self.devSummary(self.now_batch, devloss_detail)
            logging.info("epoch %d, evaluate dev", self.now_epoch)

            testloss_detail = self.evaluate("test")
            self.testSummary(self.now_batch, testloss_detail)
            logging.info("epoch %d, evaluate test", self.now_epoch)

            self.save_checkpoint(value=devloss_detail["word_loss"].tolist())

    def test(self, key):
        args = self.param.args
        dm = self.param.volatile.dm

        metric1 = dm.get_teacher_forcing_metric()
        batch_num, batches = self.get_batches(dm, key)
        logging.info("eval teacher-forcing")
        for incoming in tqdm.tqdm(batches, total=batch_num):
            incoming.args = Storage()
            with torch.no_grad():
                self.net.forward(incoming)
                gen_log_prob = nn.functional.log_softmax(incoming.gen.w, -1)
            data = incoming.data
            data.sent_allvocabs = LongTensor(incoming.data.sent_allvocabs)
            data.sent_length = incoming.data.sent_length
            data.gen_log_prob = gen_log_prob
            metric1.forward(data)
        res = metric1.close()

        metric2 = dm.get_inference_metric()
        batch_num, batches = self.get_batches(dm, key)
        logging.info("eval free-run")
        for incoming in tqdm.tqdm(batches, total=batch_num):
            incoming.args = Storage()
            with torch.no_grad():
                self.net.detail_forward(incoming)
            data = incoming.data
            data.gen = incoming.gen.w_o.detach().cpu().numpy()
            metric2.forward(data)
        res.update(metric2.close())

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
        filename = args.out_dir + "/%s_%s.txt" % (args.name, key)
        with open(filename, 'w') as f:
            logging.info("%s Test Result:", key)
            for key, value in res.items():
                if isinstance(value, float) or isinstance(value, str):
                    logging.info("\t{}:\t{}".format(key, value))
                    f.write("{}:\t{}\n".format(key, value))
            for i in range(len(res['gen'])):
                f.write("gen:\t%s\n" % " ".join(res['gen'][i]))
            f.flush()
        logging.info("result output to %s.", filename)
        return {key: val for key, val in res.items()
                if isinstance(val, (str, int, float))}

    def test_process(self):
        logging.info("Test Start.")
        self.net.eval()
        self.test("train")
        self.test("dev")
        test_res = self.test("test")
        logging.info("Test Finish.")
        return test_res
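# The perplexity reported by GPT2LM.evaluate above is just exp of the mean
# per-token cross-entropy ("word_loss"). A minimal standalone illustration with
# assumed toy logits and targets:
import torch
import torch.nn.functional as F

logits = torch.randn(7, 10000)                # 7 tokens over a 10000-word vocab
targets = torch.randint(0, 10000, (7,))
word_loss = F.cross_entropy(logits, targets)  # mean negative log-likelihood per token
perplexity = torch.exp(word_loss)             # same quantity as np.exp(word_loss)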
class TrainOperator:
    def __init__(self):
        # source
        self.tok = sp.SentencePieceProcessor()
        self.tok.Load(config.tok_path)
        self.vocab = self.tok.GetPieceSize()
        self.pad = self.tok.piece_to_id('[PAD]')
        self.num_workers = multiprocessing.cpu_count()
        self.cuda = config.cuda and torch.cuda.is_available()

        # for data parallel
        if self.cuda:
            self.n_gpu = torch.cuda.device_count()
        else:
            self.n_gpu = 0

        # load loader
        self.train_loader = self._construct_loader('train')
        self.dev_loader = self._construct_loader('dev')
        print('* Train Operator is loaded')

    def setup_train(self, model_path=None):
        self.loss_weight = torch.FloatTensor([1 - config.alpha, config.alpha])
        if self.cuda:
            self.loss_weight = self.loss_weight.cuda()

        self.model = HierSumTransformer(self.vocab, config.emb_dim, config.d_model,
                                        config.N, config.heads,
                                        config.max_sent_len, config.max_doc_len)
        if model_path:
            self.model.load_state_dict(
                torch.load(model_path, map_location=lambda storage, location: storage))
        else:
            for p in self.model.parameters():
                if p.dim() > 1:
                    torch.nn.init.xavier_uniform_(p)

        # Data Parallel
        if self.cuda:
            if self.n_gpu > 1:
                self.model = torch.nn.DataParallel(self.model)
            self.model = self.model.cuda()

        # self.optim = torch.optim.Adam(self.model.parameters(), lr=config.lr, betas=(0.9, 0.98), eps=1e-9)
        self.optim = RAdam(self.model.parameters(), lr=config.lr,
                           betas=(0.9, 0.98), eps=1e-9)
        print('* Training model is prepared')

    def train(self):
        printInterval = 20
        init_loss = 1e5
        init_f1 = 0
        for n in range(config.n_epoch):
            loss_tr_total = 0
            for batch_id, batch in enumerate(self.train_loader):
                loss_tr = self._train_one_batch(batch)
                loss_tr_total += loss_tr
                if (batch_id + 1) % printInterval == 0 or batch_id == 0:
                    loss_tr = round(loss_tr, 4)
                    loss_eval, recall, pre, f1 = [round(l, 4) for l in self._evaluate()]
                    print("| epoch: {} | batch: {}/{} | tr_loss: {} | val_loss: {} |"
                          .format(n + 1, batch_id + 1, len(self.train_loader),
                                  round(loss_tr_total / (batch_id + 1), 4), loss_eval))
                    print("| epoch: {} | Recall: {} | Precision: {} | F1: {} |"
                          .format(n + 1, recall, pre, f1))
                    print("-" * 100)
                    if loss_eval < init_loss:
                        init_loss = loss_eval
                        if self.n_gpu <= 1:
                            torch.save(self.model.state_dict(),
                                       './resource/RNN_TR_HiSum_v2.0.pkl')
                        elif self.n_gpu > 1:
                            torch.save(self.model.module.state_dict(),
                                       './resource/RNN_TR_HiSum_v2.0.pkl')
                    # change model state back to train after evaluation
                    self.model.train()

    def _construct_loader(self, type):
        dataset = ExtSumDataset(config.data_path, self.tok, type)
        loader = DataLoader(dataset,
                            batch_size=config.batch_size,
                            num_workers=0)
        return loader

    def _train_one_batch(self, batch):
        doc_id, doc_len, sent_len, label, doc_mask = batch
        doc_mask = doc_mask.unsqueeze(1)
        sent_mask = torch.stack([self._create_mask(sent) for sent in doc_id])
        if self.cuda:
            doc_id = doc_id.cuda()
            doc_len = doc_len.cuda()
            doc_mask = doc_mask.cuda()
            sent_mask = sent_mask.cuda()
            sent_len = sent_len.cuda()
            label = label.cuda()

        preds = self.model(doc_id, sent_mask, doc_mask, sent_len)
        loss = F.cross_entropy(preds.view(-1, preds.size(-1)),
                               label.reshape(-1),
                               ignore_index=config.ignore_index_ext,
                               weight=self.loss_weight)
        # loss = focal_loss(preds.view(-1, preds.size(-1)), label.reshape(-1),
        #                   ignore_index=config.ignore_index_ext,
        #                   alpha=config.alpha, gamma=config.gamma)
        self.optim.zero_grad()  # clear stale gradients; omitting this accumulates gradients across batches
        loss.backward()
        self.optim.step()
        return loss.tolist()

    def _evaluate(self):
        right = 0
        origin = 0
        found = 0
        total_loss = 0
        self.model.eval()
        for i, data in enumerate(self.dev_loader):
            doc_id, doc_len, sent_len, label, doc_mask = data
            doc_mask = doc_mask.unsqueeze(1)
            sent_mask = torch.stack([self._create_mask(sent) for sent in doc_id])
            if self.cuda:
                doc_id = doc_id.cuda()
                doc_len = doc_len.cuda()
                doc_mask = doc_mask.cuda()
                sent_mask = sent_mask.cuda()
                sent_len = sent_len.cuda()
                label = label.cuda()

            preds = self.model(doc_id, sent_mask, doc_mask, sent_len)
            loss = F.cross_entropy(preds.view(-1, preds.size(-1)),
                                   label.reshape(-1),
                                   ignore_index=config.ignore_index_ext,
                                   weight=self.loss_weight)
            # loss = focal_loss(preds.view(-1, preds.size(-1)), label.reshape(-1),
            #                   ignore_index=config.ignore_index_ext,
            #                   alpha=config.alpha, gamma=config.gamma)
            total_loss += loss.tolist()

            pred_label = [torch.argmax(p, 1).tolist() for p in preds]
            labels = label.tolist()
            for p_tag, label in zip(pred_label, labels):
                for p, l in zip(p_tag, label):
                    if l == config.ignore_index_ext:
                        break
                    elif p == 1 and l == 1:
                        right += 1
                        origin += 1
                        found += 1
                    elif p == 0 and l == 1:
                        origin += 1
                    elif p == 1 and l == 0:
                        found += 1
                    else:
                        pass

        recall = right / (origin + 1e-5)
        precision = right / (found + 1e-5)
        f1 = (2 * precision * recall) / (precision + recall + 1e-5)
        return (round(total_loss / (i + 1), 4), round(recall, 4),
                round(precision, 4), round(f1, 4))

    def _create_mask(self, tok_ids):
        mask = (tok_ids != self.pad).unsqueeze(1)
        return mask
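# The counting in TrainOperator._evaluate above is a micro-averaged
# precision/recall/F1 over binary sentence labels that stops at the first
# padded position. A standalone sketch of the same bookkeeping; IGNORE stands
# in for config.ignore_index_ext, and the epsilons guard empty counts exactly
# as above.
def prf1(pred_docs, label_docs, IGNORE=-100):
    right = origin = found = 0  # true positives, gold positives, predicted positives
    for p_tags, l_tags in zip(pred_docs, label_docs):
        for p, l in zip(p_tags, l_tags):
            if l == IGNORE:  # padding marks the end of the document
                break
            right += int(p == 1 and l == 1)
            origin += int(l == 1)
            found += int(p == 1)
    recall = right / (origin + 1e-5)
    precision = right / (found + 1e-5)
    f1 = 2 * precision * recall / (precision + recall + 1e-5)
    return precision, recall, f1

# prf1([[1, 0, 1]], [[1, 1, 0]]) -> (~0.5, ~0.5, ~0.5)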
class RRR(object):
    def __init__(self, model, args, dataset, network, clipgrad=10000):
        self.args = args
        self.dataset = dataset
        self.model = model
        self.nepochs = args.train.nepochs
        self.sbatch = args.train.batch_size
        self.lrs = [args.train.lr for _ in range(args.experiment.ntasks)]
        if self.args.experiment.fscil:
            self.lrs[1:] = [self.args.experiment.lr_multiplier * lr for lr in self.lrs[1:]]
        # self.lrs = [item[1] for item in args.lrs]
        self.lrs_exp = [args.saliency.lr for _ in range(args.experiment.ntasks)]
        self.lr_min = [lr / 1000. for lr in self.lrs]
        self.lr_factor = args.train.lr_factor
        self.lr_patience = args.train.lr_patience
        self.clipgrad = clipgrad
        self.checkpoint = args.path.checkpoint
        self.device = args.device.name

        # self.args.train.schedule = [20, 30, 40]
        self.args.train.schedule = [20, 40, 60]

        self.criterion = torch.nn.CrossEntropyLoss().to(device=self.args.device.name)
        if self.args.saliency.loss == "l1":
            self.sal_loss = torch.nn.L1Loss().to(device=self.args.device.name)
        elif self.args.saliency.loss == "l2":
            self.sal_loss = torch.nn.MSELoss().to(device=self.args.device.name)
        else:
            raise NotImplementedError
        if self.args.train.l1_reg:
            self.l1_reg = torch.nn.L1Loss(reduction='sum')

        self.get_optimizer(task_id=0)
        self.get_optimizer_explanations(task_id=0)

        self.network = network
        self.inputsize = args.inputsize
        self.taskcla = args.taskcla

        self.memory_loaders = {}
        self.test_loader = {}
        self.memory_paths = []
        self.saliency_loaders = None

        if self.args.experiment.raw_memory_only or self.args.experiment.xai_memory:
            self.use_memory = True
        else:
            self.use_memory = False

        # XAI
        if self.args.saliency.method == 'gc':
            print("Using GradCAM to obtain saliency maps")
            from approaches.explanations import GradCAM as Explain
        elif self.args.saliency.method == 'smooth':
            print("Using SmoothGrad to obtain saliency maps")
            from approaches.explanations import SmoothGrad as Explain
        elif self.args.saliency.method == 'bp':
            print("Using BackPropagation to obtain saliency maps")
            from approaches.explanations import BackPropagation as Explain
        elif self.args.saliency.method == 'gbp':
            print("Using Guided BackPropagation to obtain saliency maps")
            from approaches.explanations import GuidedBackPropagation as Explain
        elif self.args.saliency.method == 'deconv':
            from approaches.explanations import Deconvnet as Explain
        self.explainer = Explain(self.args)

    def get_optimizer(self, task_id, lr=None):
        if lr is None:
            lr = self.lrs[task_id]
        if self.args.train.optimizer == "radam":
            self.optimizer = RAdam(self.model.parameters(), lr=lr,
                                   betas=(0.9, 0.999), weight_decay=0)
        elif self.args.train.optimizer == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr,
                                              betas=(0.9, 0.999), eps=1e-08,
                                              weight_decay=0.0, amsgrad=False)
        elif self.args.train.optimizer == "sgd":
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=lr,
                                             momentum=0.9, weight_decay=0.001)
        self.scheduler_opt = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, patience=self.lr_patience,
            factor=self.lr_factor / 10, min_lr=self.lr_min[task_id], verbose=True)

    def adjust_learning_rate(self, epoch):
        if epoch in self.args.train.schedule:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] *= self.args.train.gamma
                print("Reducing learning rate to ", param_group['lr'])

    def get_optimizer_explanations(self, task_id, lr=None):
        if lr is None:
            lr = self.lrs_exp[task_id]
        if self.args.train.optimizer == "sgd":
            self.optimizer_explanations = torch.optim.SGD(
                self.model.parameters(), lr=lr, weight_decay=self.args.train.wd)
            self.scheduler_exp_opt = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer_explanations, patience=self.lr_patience,
                factor=self.lr_factor / 10, min_lr=self.lr_min[task_id], verbose=True)
        elif self.args.train.optimizer == "adam":
            self.optimizer_explanations = torch.optim.Adam(
                self.model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08,
                weight_decay=0.0, amsgrad=False)
        elif self.args.train.optimizer == "radam":
            self.optimizer_explanations = RAdam(
                self.model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=0)

    def train(self, task_id, performance):
        # if task_id > 0 and not self.args.architecture.multi_head:
        #     self.model.increment_classes(task_id)
        #     self.model = self.model.to(self.device)
        print('*' * 100)
        data_loader = self.dataset.get(task_id=task_id)
        self.test_loader[task_id] = data_loader['test']
        print(' ' * 85, 'Run #{:2d} - Dataset {:2d} ({:s})'.format(
            self.args.seed + 1, task_id + 1, data_loader['name']))
        print('*' * 100)

        best_loss = np.inf
        if self.args.train.nepochs > 1:
            best_model = utils.get_model(self.model)
        lr = self.lrs[task_id]
        lr_exp = self.lrs_exp[task_id]
        patience = self.lr_patience
        best_acc = 0

        self.get_optimizer(task_id=task_id)
        self.get_optimizer_explanations(task_id=task_id)

        # Loop epochs
        for e in range(self.nepochs):
            self.epoch = e
            # Train
            clock0 = time.time()
            self.train_epoch(data_loader['train'], task_id)
            clock1 = time.time()

            # Valid
            if self.args.train.pc_valid > 0:
                valid_res = self.eval(data_loader['valid'], task_id, set_name='valid')
                utils.report_valid(valid_res)
            else:
                valid_res = self.eval(data_loader['test'], task_id, set_name='test')
                print("Epoch {}/{} | Test acc {}: {:.2f}".format(
                    e + 1, self.nepochs, task_id, valid_res['acc']))

            if self.args.wandb.log:
                wandb.log({"Best Acc on Task {}".format(task_id): valid_res['acc']})

            if self.args.train.optimizer == "sgd":
                self.scheduler_opt.step(valid_res['loss'])
                self.scheduler_exp_opt.step(valid_res['loss'])

            is_best = valid_res['acc'] > best_acc
            if is_best:
                best_model = utils.get_model(self.model)
            best_acc = max(valid_res['acc'], best_acc)

        # Restore best validation model
        if self.args.train.nepochs > 1:
            self.model.load_state_dict(deepcopy(best_model))
        self.save_model(task_id, deepcopy(self.model.state_dict()))

        if task_id == 1 and self.args.experiment.xai_memory:
            self.compute_memory()

        if self.use_memory and task_id < self.args.experiment.ntasks:
            self.update_memory(task_id)

    def train_epoch(self, train_loader, task_id):
        self.adjust_learning_rate(self.epoch)
        self.model.train()

        if task_id > 0 and self.args.experiment.xai_memory:
            for idx, (data, target, sal, tt, _) in enumerate(self.saliency_loaders):
                x = data.to(device=self.device, dtype=torch.float)
                s = sal.to(device=self.device, dtype=torch.float)
                explanations, self.model, _, _ = self.explainer(x, self.model, task_id)
                self.saliency_size = explanations.size()
                # To make predicted explanations (Bx7x7) same as ground truth ones (Bx1x7x7)
                sal_loss = self.sal_loss(explanations.view_as(s), s)
                sal_loss *= self.args.saliency.regularizer
                if self.args.wandb.log:
                    wandb.log({"Saliency loss": sal_loss.item()})
                try:
                    sal_loss.requires_grad = True
                except:
                    continue
                self.optimizer_explanations.zero_grad()
                sal_loss.backward(retain_graph=True)
                self.optimizer_explanations.step()

        # Loop batches
        for batch_idx, (x, y, tt) in enumerate(train_loader):
            images = x.to(device=self.device, dtype=torch.float)
            targets = y.to(device=self.device, dtype=torch.long)
            tt = tt.to(device=self.device, dtype=torch.long)

            # Forward
            if self.args.architecture.multi_head:
                output = self.model.forward(images, tt)
            else:
                output = self.model.forward(images)
            loss = self.criterion(output, targets)

            # L1 regularize
            if self.args.train.l1_reg:
                reg_loss = self.l1_regularizer()
                factor = self.args.train.l1_reg_factor
                loss += factor * reg_loss

            loss *= self.args.train.task_loss_reg

            # Backward
            self.optimizer.zero_grad()
            loss.backward(retain_graph=True)

            # Apply step
            # torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad)
            self.optimizer.step()

    def eval(self, data_loader, task_id, set_name='valid'):
        total_loss = 0
        total_acc = 0
        total_num = 0
        self.model.eval()
        res = {}
        old_tasks_loss, sal_loss = 0, 0

        # Loop batches
        with torch.no_grad():
            for batch_idx, (x, y, tt) in enumerate(data_loader):
                # Fetch x and y labels
                images = x.to(device=self.device, dtype=torch.float)
                targets = y.to(device=self.device, dtype=torch.long)
                tt = tt.to(device=self.device, dtype=torch.long)

                # Forward
                if self.args.architecture.multi_head:
                    output = self.model.forward(images, tt)
                else:
                    output = self.model.forward(images)
                loss = self.criterion(output, targets)

                _, pred = output.max(1)
                hits = (pred == targets).float()

                # Log
                total_loss += loss
                total_acc += hits.sum().item()
                total_num += targets.size(0)

        res['loss'], res['acc'] = total_loss / (batch_idx + 1), 100 * total_acc / total_num
        res['size'] = self.loader_size(data_loader)
        return res

    def test(self, model, test_id, model_id=None):
        total_loss = 0
        total_acc = 0
        total_num = 0
        model.eval()
        res = {}

        # Loop batches
        with torch.no_grad():
            for batch_idx, (x, y, tt) in enumerate(self.test_loader[test_id]):
                # Fetch x and y labels
                images = x.to(device=self.device, dtype=torch.float32)
                targets = y.to(device=self.device, dtype=torch.long)

                # Forward
                # output = model.forward(images, test_id)
                if self.args.architecture.multi_head:
                    output = self.model.forward(images, test_id)
                else:
                    output = self.model.forward(images)
                _, pred = output.max(1)
                loss = self.criterion(output, targets)
                hits = (pred == targets).float()

                # Log
                total_loss += loss
                total_acc += hits.sum().item()
                total_num += targets.size(0)

        res['loss'], res['acc'] = total_loss / (batch_idx + 1), 100 * total_acc / total_num
        return res['loss'], res['acc']

    def update_memory(self, task_id):
        start = time.time()
        # Get memory set for each task seen so far with the updated samples per class and return new spc
        self.dataset.get_memory_sets(task_id)
        midway = time.time()
        print('[Storing memory time = {:.1f} min ]'.format((midway - start) / 60))

        if self.args.experiment.xai_memory:
            self.update_saliencies(task_id)
            self.saliency_loaders = self.dataset.generate_evidence_loaders(task_id)
            # The sanity check below extremely slows down training; enable it with care.
            # self.check_saliency_loaders(task_id)
            print('[Storing saliency time = {:.1f} min ]'.format((time.time() - midway) / 60))

    def update_saliencies(self, task_id):
        save_images = False
        ims = [] if save_images else None
        images, preds = [], []

        # Generate saliency for images from the last seen task
        memory_path = os.path.join(self.args.path.checkpoint, 'memory',
                                   'mem_{}.pth'.format(task_id))
        memory_set = torch.load(memory_path)
        num_samples = len(memory_set)
        single_loader = torch.utils.data.DataLoader(
            memory_set, batch_size=1, num_workers=self.args.device.workers, shuffle=False)

        saliencies, predictions = [], []
        for idx, (img, y, tt) in enumerate(single_loader):
            img = img.to(self.args.device.name)
            sal, self.model, _, _ = self.explainer(img, self.model, task_id)
            if self.args.architecture.multi_head:
                output = self.model.forward(img, task_id)
            else:
                output = self.model.forward(img)
            _, pred = output.max(1)
            saliencies.append(sal)
            predictions.append(pred)

        sal_path = os.path.join(self.args.path.checkpoint, 'memory',
                                'sal_{}.pth'.format(task_id))
        pred_path = os.path.join(self.args.path.checkpoint, 'memory',
                                 'pred_{}.pth'.format(task_id))
        torch.save(saliencies, sal_path)
        torch.save(predictions, pred_path)

        if not self.args.experiment.fscil:
            # Reduce previous saliencies
            for t in range(task_id):
                # Read the stored saliency file
                sal_path = os.path.join(self.args.path.checkpoint, 'memory',
                                        'sal_{}.pth'.format(t))
                saliencies = torch.load(sal_path)
                before = len(saliencies)
                pred_path = os.path.join(self.args.path.checkpoint, 'memory',
                                         'pred_{}.pth'.format(t))
                predictions = torch.load(pred_path)

                # Extract the required number of samples and save them again
                saliencies = saliencies[:num_samples]
                after = len(saliencies)
                torch.save(saliencies, sal_path)
                print("Reduced saliencies for task {} from {} to {}".format(t, before, after))

                predictions = predictions[:num_samples]
                torch.save(predictions, pred_path)

    def l1_regularizer(self):
        reg_loss = 0
        for param in self.model.parameters():
            target = torch.zeros_like(param)
            reg_loss += self.l1_reg(param, target)
        return reg_loss

    def save_model(self, t, best_model):
        torch.save({'model_state_dict': best_model},
                   os.path.join(self.checkpoint,
                                'model_run_id_{}_task_id_{}.pth.tar'.format(self.args.seed, t)))

    def loader_size(self, data_loader):
        return data_loader.dataset.__len__()

    def load_model(self, task_id):
        net = self.network.Net(self.args).to(device=self.args.device.name)
        # net = self.network._CustomDataParallel(net, self.args.device.name_ids)
        if self.args.device.multi:
            net = torch.nn.DataParallel(net)
        checkpoint = torch.load(os.path.join(
            self.checkpoint,
            'model_run_id_{}_task_id_{}.pth.tar'.format(self.args.seed, task_id)))
        net.load_state_dict(checkpoint['model_state_dict'])
        net = net.to(device=self.args.device.name)
        # net = self.network._CustomDataParallel(net, self.args.device.name_ids)
        return net

    def load_singlehead_model(self, current_model_id):
        return self.model

    def load_multihead_model(self, test_id, current_model_id):
        # Load a previous model
        old_model = self.network.Net(self.args)
        if self.args.device.multi:
            old_model = torch.nn.DataParallel(old_model)
        checkpoint = torch.load(os.path.join(
            self.checkpoint,
            'model_run_id_{}_task_id_{}.pth.tar'.format(self.args.seed, test_id)))
        old_model.load_state_dict(checkpoint['model_state_dict'])

        # Load a current model
        current_model = self.network.Net(self.args)
        if self.args.device.multi:
            current_model = torch.nn.DataParallel(current_model)
        checkpoint = torch.load(os.path.join(
            self.checkpoint,
            'model_run_id_{}_task_id_{}.pth.tar'.format(self.args.seed, current_model_id)))
        current_model.load_state_dict(checkpoint['model_state_dict'])

        # Change the current_model head with the old head
        if self.args.device.multi:
            old_head = deepcopy(old_model.module.head.state_dict())
            current_model.module.head.load_state_dict(old_head)
        else:
            old_head = deepcopy(old_model.head.state_dict())
            current_model.head.load_state_dict(old_head)
        current_model = current_model.to(self.args.device.name)
        return current_model

    def load_checkpoint(self, task_id):
        print("Loading checkpoint for task {} ...".format(task_id))
        # Load a previous model
        net = self.network.Net(self.args)
        net = net.to(device=self.args.device.name)
        checkpoint = torch.load(os.path.join(
            self.checkpoint,
            'model_run_id_{}_task_id_{}.pth.tar'.format(self.args.seed, task_id)))
        net.load_state_dict(checkpoint['model_state_dict'])
        net = net.to(device=self.args.device.name)
        return net

    def check_saliency_loaders(self, task_id):
        path = os.path.join(self.args.path.checkpoint, 'memory')
        for idx, (images, targets, saliencies, tt, preds) in enumerate(self.saliency_loaders):
            # loop over batch
            for i in range(len(images)):
                fig = plt.figure(dpi=100)
                # fig.subplots_adjust(hspace=0.01, wspace=0.01)
                img = images[i].unsqueeze(0)  # [1, 3, 224, 224]
                saliency = saliencies[i].unsqueeze(0)  # shape: [1, 1, 7, 7]
                pred = preds[i]  # shape: [1, 7, 7]
                target = targets[i]
                task = tt[i]
                # saliency = saliency.unsqueeze(0)
                image_size = img.shape[2:]  # [32x32]
                saliency = F.interpolate(saliency, size=image_size,
                                         mode="bilinear", align_corners=False)
                B, C, H, W = saliency.shape
                saliency = saliency.view(B, -1)
                saliency_max = saliency.max(dim=1, keepdim=True)[0]
                saliency_max[torch.where(saliency_max == 0)] = 1.  # prevent divide by 0
                saliency -= saliency.min(dim=1, keepdim=True)[0]
                saliency /= saliency_max
                saliency = saliency.view(B, C, H, W)
                saliency = saliency.squeeze(0).squeeze(0)
                saliency = saliency.detach().cpu().numpy()

                img = img.squeeze(0)
                img = img.cpu().numpy().transpose((1, 2, 0))  # (224, 224, 3)
                mean = np.array(self.dataset.mean)
                std = np.array(self.dataset.std)
                img = std * img + mean
                img = np.clip(img, 0, 1)

                # plt.subplot(1, 10, idx + 1)
                plt.axis('off')
                result = 'Correct' if pred.item() == target else "Wrong"
                plt.title('{} | Pred:{}, Truth:{}, Task:{}'.format(
                    result, pred.item(), target, task), fontsize=9)
                plt.imshow(img)
                plt.savefig(os.path.join(path, 'Img-batch-{}-ID-{}-task-{}.png'.format(idx, i, task_id)))
                plt.imshow(saliency, cmap='jet', alpha=0.5)
                plt.colorbar(fraction=0.046, pad=0.04)
                fig.tight_layout()
                plt.savefig(os.path.join(path, 'Sal-batch-{}-ID-{}-task-{}.png'.format(idx, i, task_id)))
                plt.close()

    def visualize(self, images, saliencies, task_id, preds, y, path):
        # fig = plt.figure(figsize=(10, 2), dpi=500)
        # fig.subplots_adjust(hspace=0.01, wspace=0.01)
        for idx in range(len(saliencies)):
            saliency, img = saliencies[idx], images[idx]
            # img = img.unsqueeze(0)  # [1, 3, 224, 224]
            saliency = saliency.unsqueeze(0)  # shape: [1, 7, 7]
            # saliency = saliency.unsqueeze(0)  # shape: [1, 1, 7, 7]
            image_size = img.shape[2:]
            # print(saliency.size(), img.size(), image_size)
            saliency = F.interpolate(saliency, size=image_size,
                                     mode="bilinear", align_corners=False)
            B, C, H, W = saliency.shape
            saliency = saliency.view(B, -1)
            saliency_max = saliency.max(dim=1, keepdim=True)[0]
            saliency_max[torch.where(saliency_max == 0)] = 1.  # prevent divide by 0
            saliency -= saliency.min(dim=1, keepdim=True)[0]
            saliency /= saliency_max
            saliency = saliency.view(B, C, H, W)
            saliency = saliency.squeeze(0).squeeze(0)
            saliency = saliency.detach().cpu().numpy()

            img = img.squeeze(0)
            img = img.cpu().numpy().transpose((1, 2, 0))  # (224, 224, 3)
            mean = np.array(self.dataset.mean)
            std = np.array(self.dataset.std)
            img = std * img + mean
            img = np.clip(img, 0, 1)

            fig = plt.figure(dpi=500)
            # plt.subplot(1, 1, 1)
            plt.axis('off')
            result = 'Correct' if preds[idx].item() == y[idx].item() else "Wrong"
            plt.title('{}-Pred:{},Truth:{}'.format(result, preds[idx].item(), y[idx].item()),
                      fontsize=7)
            plt.imshow(img)
            plt.imshow(saliency, cmap='jet', alpha=0.5)
            plt.colorbar(fraction=0.046, pad=0.04)
            fig.tight_layout()
            print('saving to: {}'.format(os.path.join(path, 'sal-task-{}-mem-{}.png'.format(task_id, idx))))
            plt.savefig(os.path.join(path, 'sal-task-{}-mem-{}.png'.format(task_id, idx)))
            plt.close()

    def compute_memory(self):
        ssize = 1
        print("***" * 200)
        print("saliency_size", self.saliency_size)
        for s in self.saliency_size:
            ssize *= s
        saliency_memory = 4 * self.args.experiment.memory_budget * ssize

        ncha, size, _ = self.args.inputsize
        image_size = ncha * size * size
        samples_memory = 4 * self.args.experiment.memory_budget * image_size

        count = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print('Num parameters in the entire model = %s ' % (utils.human_format(count)))
        architecture_memory = 4 * count

        print("--------------------------> Saliency memory size: (%sB)" % utils.human_format(saliency_memory))
        print("--------------------------> Episodic memory size: (%sB)" % utils.human_format(samples_memory))
        print("------------------------------------------------------------------------------")
        print("                                TOTAL: %sB" % utils.human_format(
            architecture_memory + samples_memory + saliency_memory))
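# RRR.compute_memory above prices every stored float at 4 bytes (float32). A
# standalone sketch of the same accounting; the model, memory budget, and
# tensor sizes are assumed toy values.
import torch
from torch import nn

model = nn.Linear(3 * 32 * 32, 10)
memory_budget = 100          # assumed number of stored exemplars
saliency_elems = 1 * 7 * 7   # one 7x7 saliency map per exemplar
image_elems = 3 * 32 * 32    # one CIFAR-sized image per exemplar

saliency_memory = 4 * memory_budget * saliency_elems
samples_memory = 4 * memory_budget * image_elems
architecture_memory = 4 * sum(p.numel() for p in model.parameters() if p.requires_grad)

print("total bytes:", saliency_memory + samples_memory + architecture_memory)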