def main(config):
    setup_logging(os.getcwd())
    logger = logging.getLogger('test')

    use_gpu = config['n_gpu'] > 0 and torch.cuda.is_available()
    device = torch.device('cuda:0' if use_gpu else 'cpu')

    datamanager = DataManger(config['data'], phase='test')

    # build model and load trained weights
    model = Baseline(
        num_classes=datamanager.datasource.get_num_classes('train'),
        is_training=False)
    model = model.eval()

    logger.info('Loading checkpoint: {} ...'.format(config['resume']))
    checkpoint = torch.load(config['resume'], map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])

    if config['extract']:
        logger.info('Extract feature from query set...')
        query_feature, query_label = feature_extractor(
            model, datamanager.get_dataloader('query'), device)

        logger.info('Extract feature from gallery set...')
        gallery_feature, gallery_label = feature_extractor(
            model, datamanager.get_dataloader('gallery'), device)

        gallery_embeddings = (gallery_feature, gallery_label)
        query_embeddings = (query_feature, query_label)

        # the key is spelled 'ouput_dir' to match the existing config files
        os.makedirs(config['testing']['ouput_dir'], exist_ok=True)

        with open(os.path.join(config['testing']['ouput_dir'],
                               'gallery_embeddings.pt'), 'wb') as f:
            torch.save(gallery_embeddings, f)

        with open(os.path.join(config['testing']['ouput_dir'],
                               'query_embeddings.pt'), 'wb') as f:
            torch.save(query_embeddings, f)

    # reload embeddings from disk; if 'extract' is False, this assumes a
    # previous run already saved them
    gallery_feature, gallery_label = torch.load(
        os.path.join(config['testing']['ouput_dir'], 'gallery_embeddings.pt'),
        map_location='cpu')
    query_feature, query_label = torch.load(
        os.path.join(config['testing']['ouput_dir'], 'query_embeddings.pt'),
        map_location='cpu')

    # query-to-gallery distance matrix and retrieval metrics
    distance = compute_distance_matrix(query_feature, gallery_feature)

    top1 = top_k(distance, output=gallery_label, target=query_label, k=1)
    top5 = top_k(distance, output=gallery_label, target=query_label, k=5)
    top10 = top_k(distance, output=gallery_label, target=query_label, k=10)
    m_ap = mAP(distance, output=gallery_label, target=query_label, k='all')

    logger.info(
        'Datasets: {}, without spatial-temporal: top1: {}, top5: {}, top10: {}, mAP: {}'.format(
            datamanager.datasource.get_name_dataset(), top1, top5, top10, m_ap))
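
# Hypothetical entry point, not part of the original file: a minimal sketch of
# how main() might be invoked, assuming the config is a YAML file containing
# the keys used above ('n_gpu', 'data', 'resume', 'extract', 'testing').
# The --config flag and YAML loading are assumptions for illustration.
if __name__ == '__main__':
    import argparse
    import yaml

    arg_parser = argparse.ArgumentParser(description='Test a re-ID baseline model')
    arg_parser.add_argument('--config', default='config.yml',
                            help='path to a YAML config file (assumed layout)')
    cli_args = arg_parser.parse_args()
    with open(cli_args.config) as f:
        main(yaml.safe_load(f))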
parser = Flags()
parser.set_arguments()
FG = parser.parse_args()

c_code, axis, z_dim = FG.c_code, FG.axis, FG.z_dim

device = torch.device(FG.devices[0])
torch.cuda.set_device(FG.devices[0])

# load one trained classifier per fold
nets = []
for i in range(FG.fold):
    parser.configure('cur_fold', i)
    parser.configure('ckpt_dir')
    FG = parser.load()

    net = Baseline(FG.ckpt_dir, len(FG.labels))
    net.to(device)
    net.load(epoch=None, optimizer=None, is_best=True)
    net.eval()
    nets += [net]

# G = Generator(FG)
G = torch.nn.DataParallel(Generator(z_dim, c_code, axis))
# state_dict = torch.load(os.path.join('BiGAN-info-c4-f', 'G.pth'), 'cpu')
state_dict = torch.load(os.path.join('157-G8', 'G.pth'), map_location='cpu')
G.load_state_dict(state_dict)
G.to(device)
G.eval()

# slice shapes depend on the slicing axis
if axis == 1:
    ns = ((160, 160), (112, 112))
elif axis == 0:
    ns = ((192, 160), (144, 112))
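
# Illustrative sketch (not part of the original file): one plausible way to
# combine the per-fold classifiers loaded above, by averaging their softmax
# outputs across folds. It assumes each `net` is callable on a batch tensor
# and returns class logits; the function name and averaging scheme are
# assumptions, not code from this repository.
def ensemble_predict(nets, batch):
    with torch.no_grad():
        fold_probs = [torch.softmax(net(batch), dim=1) for net in nets]
    # mean over folds -> (batch, num_classes)
    return torch.stack(fold_probs).mean(dim=0)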
def main(args):
    # Set up logging
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=False)
    log = util.get_logger(args.save_dir, args.name)
    log.info(f'Args: {dumps(vars(args), indent=4, sort_keys=True)}')
    device, gpu_ids = util.get_available_devices()
    args.batch_size *= max(1, len(gpu_ids))

    # Get embeddings
    log.info('Loading embeddings...')
    word_vectors = util.torch_from_json(args.word_emb_file)
    char_vectors = util.torch_from_json(args.char_emb_file)

    # Build every requested model for the ensemble
    log.info('Building model...')
    nbr_model = 0
    if args.load_path_baseline:
        model_baseline = Baseline(word_vectors=word_vectors, hidden_size=100)
        model_baseline = nn.DataParallel(model_baseline, gpu_ids)
        log.info(f'Loading checkpoint from {args.load_path_baseline}...')
        model_baseline = util.load_model(model_baseline, args.load_path_baseline,
                                         gpu_ids, return_step=False)
        model_baseline = model_baseline.to(device)
        model_baseline.eval()
        nll_meter_baseline = util.AverageMeter()
        nbr_model += 1
        save_prob_baseline_start = []
        save_prob_baseline_end = []

    if args.load_path_bidaf:
        model_bidaf = BiDAF(word_vectors=word_vectors,
                            char_vectors=char_vectors,
                            char_emb_dim=args.char_emb_dim,
                            hidden_size=args.hidden_size)
        model_bidaf = nn.DataParallel(model_bidaf, gpu_ids)
        log.info(f'Loading checkpoint from {args.load_path_bidaf}...')
        model_bidaf = util.load_model(model_bidaf, args.load_path_bidaf,
                                      gpu_ids, return_step=False)
        model_bidaf = model_bidaf.to(device)
        model_bidaf.eval()
        nll_meter_bidaf = util.AverageMeter()
        nbr_model += 1
        save_prob_bidaf_start = []
        save_prob_bidaf_end = []

    if args.load_path_bidaf_fusion:
        model_bidaf_fu = BiDAF_fus(word_vectors=word_vectors,
                                   char_vectors=char_vectors,
                                   char_emb_dim=args.char_emb_dim,
                                   hidden_size=args.hidden_size)
        model_bidaf_fu = nn.DataParallel(model_bidaf_fu, gpu_ids)
        log.info(f'Loading checkpoint from {args.load_path_bidaf_fusion}...')
        model_bidaf_fu = util.load_model(model_bidaf_fu, args.load_path_bidaf_fusion,
                                         gpu_ids, return_step=False)
        model_bidaf_fu = model_bidaf_fu.to(device)
        model_bidaf_fu.eval()
        nll_meter_bidaf_fu = util.AverageMeter()
        nbr_model += 1
        save_prob_bidaf_fu_start = []
        save_prob_bidaf_fu_end = []

    if args.load_path_qanet:
        model_qanet = QANet(word_vectors=word_vectors,
                            char_vectors=char_vectors,
                            char_emb_dim=args.char_emb_dim,
                            hidden_size=args.hidden_size,
                            n_heads=args.n_heads,
                            n_conv_emb_enc=args.n_conv_emb,
                            n_conv_mod_enc=args.n_conv_mod,
                            n_emb_enc_blocks=args.n_emb_blocks,
                            n_mod_enc_blocks=args.n_mod_blocks,
                            divisor_dim_kqv=args.divisor_dim_kqv)
        model_qanet = nn.DataParallel(model_qanet, gpu_ids)
        log.info(f'Loading checkpoint from {args.load_path_qanet}...')
        model_qanet = util.load_model(model_qanet, args.load_path_qanet,
                                      gpu_ids, return_step=False)
        model_qanet = model_qanet.to(device)
        model_qanet.eval()
        nll_meter_qanet = util.AverageMeter()
        nbr_model += 1
        save_prob_qanet_start = []
        save_prob_qanet_end = []

    if args.load_path_qanet_old:
        model_qanet_old = QANet_old(word_vectors=word_vectors,
                                    char_vectors=char_vectors,
                                    device=device,
                                    char_emb_dim=args.char_emb_dim,
                                    hidden_size=args.hidden_size,
                                    n_heads=args.n_heads,
                                    n_conv_emb_enc=args.n_conv_emb,
                                    n_conv_mod_enc=args.n_conv_mod,
                                    n_emb_enc_blocks=args.n_emb_blocks,
                                    n_mod_enc_blocks=args.n_mod_blocks)
        model_qanet_old = nn.DataParallel(model_qanet_old, gpu_ids)
        log.info(f'Loading checkpoint from {args.load_path_qanet_old}...')
        model_qanet_old = util.load_model(model_qanet_old, args.load_path_qanet_old,
                                          gpu_ids, return_step=False)
        model_qanet_old = model_qanet_old.to(device)
        model_qanet_old.eval()
        nll_meter_qanet_old = util.AverageMeter()
        nbr_model += 1
        save_prob_qanet_old_start = []
        save_prob_qanet_old_end = []

    if args.load_path_qanet_inde:
        model_qanet_inde = QANet_independant_encoder(
            word_vectors=word_vectors,
            char_vectors=char_vectors,
            char_emb_dim=args.char_emb_dim,
            hidden_size=args.hidden_size,
            n_heads=args.n_heads,
            n_conv_emb_enc=args.n_conv_emb,
            n_conv_mod_enc=args.n_conv_mod,
            n_emb_enc_blocks=args.n_emb_blocks,
            n_mod_enc_blocks=args.n_mod_blocks,
            divisor_dim_kqv=args.divisor_dim_kqv)
        model_qanet_inde = nn.DataParallel(model_qanet_inde, gpu_ids)
        log.info(f'Loading checkpoint from {args.load_path_qanet_inde}...')
        model_qanet_inde = util.load_model(model_qanet_inde, args.load_path_qanet_inde,
                                           gpu_ids, return_step=False)
        model_qanet_inde = model_qanet_inde.to(device)
        model_qanet_inde.eval()
        nll_meter_qanet_inde = util.AverageMeter()
        nbr_model += 1
        save_prob_qanet_inde_start = []
        save_prob_qanet_inde_end = []

    if args.load_path_qanet_s_e:
        model_qanet_s_e = QANet_S_E(word_vectors=word_vectors,
                                    char_vectors=char_vectors,
                                    char_emb_dim=args.char_emb_dim,
                                    hidden_size=args.hidden_size,
                                    n_heads=args.n_heads,
                                    n_conv_emb_enc=args.n_conv_emb,
                                    n_conv_mod_enc=args.n_conv_mod,
                                    n_emb_enc_blocks=args.n_emb_blocks,
                                    n_mod_enc_blocks=args.n_mod_blocks,
                                    divisor_dim_kqv=args.divisor_dim_kqv)
        model_qanet_s_e = nn.DataParallel(model_qanet_s_e, gpu_ids)
        log.info(f'Loading checkpoint from {args.load_path_qanet_s_e}...')
        model_qanet_s_e = util.load_model(model_qanet_s_e, args.load_path_qanet_s_e,
                                          gpu_ids, return_step=False)
        model_qanet_s_e = model_qanet_s_e.to(device)
        model_qanet_s_e.eval()
        nll_meter_qanet_s_e = util.AverageMeter()
        nbr_model += 1
        save_prob_qanet_s_e_start = []
        save_prob_qanet_s_e_end = []

    # Get data loader
    log.info('Building dataset...')
    record_file = vars(args)[f'{args.split}_record_file']
    dataset = SQuAD(record_file, args.use_squad_v2)
    data_loader = data.DataLoader(dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  collate_fn=collate_fn)

    # Evaluate
    log.info(f'Evaluating on {args.split} split...')
    pred_dict = {}  # Predictions for TensorBoard
    sub_dict = {}   # Predictions for submission
    eval_file = vars(args)[f'{args.split}_eval_file']
    with open(eval_file, 'r') as fh:
        gold_dict = json_load(fh)
    with torch.no_grad(), \
            tqdm(total=len(dataset)) as progress_bar:
        for cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids in data_loader:
            # Setup for forward
            cw_idxs = cw_idxs.to(device)
            qw_idxs = qw_idxs.to(device)
            cc_idxs = cc_idxs.to(device)
            qc_idxs = qc_idxs.to(device)
            batch_size = cw_idxs.size(0)
            y1, y2 = y1.to(device), y2.to(device)

            # Per-model probabilities, averaged into the ensemble prediction below
            l_p1, l_p2 = [], []

            # Forward
            if args.load_path_baseline:
                log_p1_baseline, log_p2_baseline = model_baseline(cw_idxs, cc_idxs)
                loss_baseline = (F.nll_loss(log_p1_baseline, y1)
                                 + F.nll_loss(log_p2_baseline, y2))
                nll_meter_baseline.update(loss_baseline.item(), batch_size)
                l_p1 += [log_p1_baseline.exp()]
                l_p2 += [log_p2_baseline.exp()]
                if args.save_probabilities:
                    save_prob_baseline_start += [log_p1_baseline.exp().detach().cpu().numpy()]
                    save_prob_baseline_end += [log_p2_baseline.exp().detach().cpu().numpy()]

            if args.load_path_qanet:
                log_p1_qanet, log_p2_qanet = model_qanet(cw_idxs, cc_idxs,
                                                         qw_idxs, qc_idxs)
                loss_qanet = (F.nll_loss(log_p1_qanet, y1)
                              + F.nll_loss(log_p2_qanet, y2))
                nll_meter_qanet.update(loss_qanet.item(), batch_size)
                l_p1 += [log_p1_qanet.exp()]
                l_p2 += [log_p2_qanet.exp()]
                if args.save_probabilities:
                    save_prob_qanet_start += [log_p1_qanet.exp().detach().cpu().numpy()]
                    save_prob_qanet_end += [log_p2_qanet.exp().detach().cpu().numpy()]

            if args.load_path_qanet_old:
                log_p1_qanet_old, log_p2_qanet_old = model_qanet_old(cw_idxs, cc_idxs,
                                                                     qw_idxs, qc_idxs)
                loss_qanet_old = (F.nll_loss(log_p1_qanet_old, y1)
                                  + F.nll_loss(log_p2_qanet_old, y2))
                nll_meter_qanet_old.update(loss_qanet_old.item(), batch_size)
                l_p1 += [log_p1_qanet_old.exp()]
                l_p2 += [log_p2_qanet_old.exp()]
                if args.save_probabilities:
                    save_prob_qanet_old_start += [log_p1_qanet_old.exp().detach().cpu().numpy()]
                    save_prob_qanet_old_end += [log_p2_qanet_old.exp().detach().cpu().numpy()]

            if args.load_path_qanet_inde:
                log_p1_qanet_inde, log_p2_qanet_inde = model_qanet_inde(cw_idxs, cc_idxs,
                                                                        qw_idxs, qc_idxs)
                loss_qanet_inde = (F.nll_loss(log_p1_qanet_inde, y1)
                                   + F.nll_loss(log_p2_qanet_inde, y2))
                nll_meter_qanet_inde.update(loss_qanet_inde.item(), batch_size)
                l_p1 += [log_p1_qanet_inde.exp()]
                l_p2 += [log_p2_qanet_inde.exp()]
                if args.save_probabilities:
                    save_prob_qanet_inde_start += [log_p1_qanet_inde.exp().detach().cpu().numpy()]
                    save_prob_qanet_inde_end += [log_p2_qanet_inde.exp().detach().cpu().numpy()]

            if args.load_path_qanet_s_e:
                log_p1_qanet_s_e, log_p2_qanet_s_e = model_qanet_s_e(cw_idxs, cc_idxs,
                                                                     qw_idxs, qc_idxs)
                loss_qanet_s_e = (F.nll_loss(log_p1_qanet_s_e, y1)
                                  + F.nll_loss(log_p2_qanet_s_e, y2))
                nll_meter_qanet_s_e.update(loss_qanet_s_e.item(), batch_size)
                l_p1 += [log_p1_qanet_s_e.exp()]
                l_p2 += [log_p2_qanet_s_e.exp()]
                if args.save_probabilities:
                    save_prob_qanet_s_e_start += [log_p1_qanet_s_e.exp().detach().cpu().numpy()]
                    save_prob_qanet_s_e_end += [log_p2_qanet_s_e.exp().detach().cpu().numpy()]

            if args.load_path_bidaf:
                log_p1_bidaf, log_p2_bidaf = model_bidaf(cw_idxs, cc_idxs,
                                                         qw_idxs, qc_idxs)
                loss_bidaf = (F.nll_loss(log_p1_bidaf, y1)
                              + F.nll_loss(log_p2_bidaf, y2))
                nll_meter_bidaf.update(loss_bidaf.item(), batch_size)
                l_p1 += [log_p1_bidaf.exp()]
                l_p2 += [log_p2_bidaf.exp()]
                if args.save_probabilities:
                    save_prob_bidaf_start += [log_p1_bidaf.exp().detach().cpu().numpy()]
                    save_prob_bidaf_end += [log_p2_bidaf.exp().detach().cpu().numpy()]

            if args.load_path_bidaf_fusion:
                log_p1_bidaf_fu, log_p2_bidaf_fu = model_bidaf_fu(cw_idxs, cc_idxs,
                                                                  qw_idxs, qc_idxs)
                loss_bidaf_fu = (F.nll_loss(log_p1_bidaf_fu, y1)
                                 + F.nll_loss(log_p2_bidaf_fu, y2))
                nll_meter_bidaf_fu.update(loss_bidaf_fu.item(), batch_size)
                l_p1 += [log_p1_bidaf_fu.exp()]
                l_p2 += [log_p2_bidaf_fu.exp()]
                if args.save_probabilities:
                    save_prob_bidaf_fu_start += [log_p1_bidaf_fu.exp().detach().cpu().numpy()]
                    save_prob_bidaf_fu_end += [log_p2_bidaf_fu.exp().detach().cpu().numpy()]

            # Ensemble: average the start/end probabilities of all loaded models
            p1, p2 = l_p1[0], l_p2[0]
            for i in range(1, nbr_model):
                p1 += l_p1[i]
                p2 += l_p2[i]
            p1 /= nbr_model
            p2 /= nbr_model

            starts, ends = util.discretize(p1, p2, args.max_ans_len,
                                           args.use_squad_v2)

            # Log info
            progress_bar.update(batch_size)
            if args.split != 'test':
                # No labels for the test set, so NLL would be invalid
                if args.load_path_qanet:
                    progress_bar.set_postfix(NLL=nll_meter_qanet.avg)
                elif args.load_path_bidaf:
                    progress_bar.set_postfix(NLL=nll_meter_bidaf.avg)
                elif args.load_path_bidaf_fusion:
                    progress_bar.set_postfix(NLL=nll_meter_bidaf_fu.avg)
                elif args.load_path_qanet_old:
                    progress_bar.set_postfix(NLL=nll_meter_qanet_old.avg)
                elif args.load_path_qanet_inde:
                    progress_bar.set_postfix(NLL=nll_meter_qanet_inde.avg)
                elif args.load_path_qanet_s_e:
                    progress_bar.set_postfix(NLL=nll_meter_qanet_s_e.avg)
                else:
                    progress_bar.set_postfix(NLL=nll_meter_baseline.avg)

            idx2pred, uuid2pred = util.convert_tokens(gold_dict,
                                                      ids.tolist(),
                                                      starts.tolist(),
                                                      ends.tolist(),
                                                      args.use_squad_v2)
            pred_dict.update(idx2pred)
            sub_dict.update(uuid2pred)

    if args.save_probabilities:
        # Note: every branch writes to the same probs_start/probs_end files,
        # so if several models were loaded, only the last branch's
        # probabilities survive on disk.
        if args.load_path_baseline:
            with open(args.save_dir + "/probs_start", "wb") as fp:
                pickle.dump(save_prob_baseline_start, fp)
            with open(args.save_dir + "/probs_end", "wb") as fp:
                pickle.dump(save_prob_baseline_end, fp)
        if args.load_path_bidaf:
            with open(args.save_dir + "/probs_start", "wb") as fp:
                pickle.dump(save_prob_bidaf_start, fp)
            with open(args.save_dir + "/probs_end", "wb") as fp:
                pickle.dump(save_prob_bidaf_end, fp)
        if args.load_path_bidaf_fusion:
            with open(args.save_dir + "/probs_start", "wb") as fp:
                pickle.dump(save_prob_bidaf_fu_start, fp)
            with open(args.save_dir + "/probs_end", "wb") as fp:
                pickle.dump(save_prob_bidaf_fu_end, fp)
        if args.load_path_qanet:
            with open(args.save_dir + "/probs_start", "wb") as fp:
                pickle.dump(save_prob_qanet_start, fp)
            with open(args.save_dir + "/probs_end", "wb") as fp:
                pickle.dump(save_prob_qanet_end, fp)
        if args.load_path_qanet_old:
            with open(args.save_dir + "/probs_start", "wb") as fp:
                pickle.dump(save_prob_qanet_old_start, fp)
            with open(args.save_dir + "/probs_end", "wb") as fp:
                pickle.dump(save_prob_qanet_old_end, fp)
        if args.load_path_qanet_inde:
            with open(args.save_dir + "/probs_start", "wb") as fp:
                pickle.dump(save_prob_qanet_inde_start, fp)
            with open(args.save_dir + "/probs_end", "wb") as fp:
                pickle.dump(save_prob_qanet_inde_end, fp)
        if args.load_path_qanet_s_e:
            with open(args.save_dir + "/probs_start", "wb") as fp:
                pickle.dump(save_prob_qanet_s_e_start, fp)
            with open(args.save_dir + "/probs_end", "wb") as fp:
                pickle.dump(save_prob_qanet_s_e_end, fp)

    # Log results (except for test set, since it does not come with labels)
    if args.split != 'test':
        results = util.eval_dicts(gold_dict, pred_dict, args.use_squad_v2)
        if args.load_path_qanet:
            meter_avg = nll_meter_qanet.avg
        elif args.load_path_bidaf:
            meter_avg = nll_meter_bidaf.avg
        elif args.load_path_bidaf_fusion:
            meter_avg = nll_meter_bidaf_fu.avg
        elif args.load_path_qanet_inde:
            meter_avg = nll_meter_qanet_inde.avg
        elif args.load_path_qanet_s_e:
            meter_avg = nll_meter_qanet_s_e.avg
        elif args.load_path_qanet_old:
            meter_avg = nll_meter_qanet_old.avg
        else:
            meter_avg = nll_meter_baseline.avg
        results_list = [('NLL', meter_avg),
                        ('F1', results['F1']),
                        ('EM', results['EM'])]
        if args.use_squad_v2:
            results_list.append(('AvNA', results['AvNA']))
        results = OrderedDict(results_list)

        # Log to console
        results_str = ', '.join(f'{k}: {v:05.2f}' for k, v in results.items())
        log.info(f'{args.split.title()} {results_str}')

        # Log to TensorBoard
        tbx = SummaryWriter(args.save_dir)
        util.visualize(tbx,
                       pred_dict=pred_dict,
                       eval_path=eval_file,
                       step=0,
                       split=args.split,
                       num_visuals=args.num_visuals)

    # Write submission file
    sub_path = join(args.save_dir, args.split + '_' + args.sub_file)
    log.info(f'Writing submission file to {sub_path}...')
    with open(sub_path, 'w', newline='', encoding='utf-8') as csv_fh:
        csv_writer = csv.writer(csv_fh, delimiter=',')
        csv_writer.writerow(['Id', 'Predicted'])
        for uuid in sorted(sub_dict):
            csv_writer.writerow([uuid, sub_dict[uuid]])
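
# Assumed entry point: this evaluation script follows the layout of the
# CS224n SQuAD starter code, which drives main() with an argument helper such
# as get_test_args(). The exact helper name is an assumption, not confirmed
# by this file.
if __name__ == '__main__':
    main(get_test_args())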
class Trainer(BaseTrainer):
    def __init__(self, config):
        super(Trainer, self).__init__(config)
        self.datamanager = DataManger(config["data"])

        # model
        self.model = Baseline(
            num_classes=self.datamanager.datasource.get_num_classes("train")
        )

        # print model summary
        summary(
            self.model,
            input_size=(3, 256, 128),
            batch_size=config["data"]["batch_size"],
            device="cpu",
        )

        # losses
        cfg_losses = config["losses"]
        self.criterion = Softmax_Triplet_loss(
            num_class=self.datamanager.datasource.get_num_classes("train"),
            margin=cfg_losses["margin"],
            epsilon=cfg_losses["epsilon"],
            use_gpu=self.use_gpu,
        )
        self.center_loss = CenterLoss(
            num_classes=self.datamanager.datasource.get_num_classes("train"),
            feature_dim=2048,
            use_gpu=self.use_gpu,
        )

        # optimizers
        cfg_optimizer = config["optimizer"]
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=cfg_optimizer["lr"],
            weight_decay=cfg_optimizer["weight_decay"],
        )
        self.optimizer_centerloss = torch.optim.SGD(
            self.center_loss.parameters(), lr=0.5
        )

        # learning rate scheduler
        cfg_lr_scheduler = config["lr_scheduler"]
        self.lr_scheduler = WarmupMultiStepLR(
            self.optimizer,
            milestones=cfg_lr_scheduler["steps"],
            gamma=cfg_lr_scheduler["gamma"],
            warmup_factor=cfg_lr_scheduler["factor"],
            warmup_iters=cfg_lr_scheduler["iters"],
            warmup_method=cfg_lr_scheduler["method"],
        )

        # track metrics
        self.train_metrics = MetricTracker("loss", "accuracy")
        self.valid_metrics = MetricTracker("loss", "accuracy")

        # best accuracy so far, used by _save_checkpoint
        self.best_accuracy = None

        # send model to device
        self.model.to(self.device)

        # gradient scaler for mixed-precision training
        self.scaler = GradScaler()

        # resume model from last checkpoint
        if config["resume"] != "":
            self._resume_checkpoint(config["resume"])

    def train(self):
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            # note: this overwrites the training result, so only validation
            # metrics appear in the console log below
            result = self._valid_epoch(epoch)

            # add scalars to tensorboard
            self.writer.add_scalars(
                "Loss",
                {
                    "Train": self.train_metrics.avg("loss"),
                    "Val": self.valid_metrics.avg("loss"),
                },
                global_step=epoch,
            )
            self.writer.add_scalars(
                "Accuracy",
                {
                    "Train": self.train_metrics.avg("accuracy"),
                    "Val": self.valid_metrics.avg("accuracy"),
                },
                global_step=epoch,
            )

            # log result to console
            log = {"epoch": epoch}
            log.update(result)
            for key, value in log.items():
                self.logger.info("    {:15s}: {}".format(str(key), value))

            # save model, tracking the best validation accuracy
            if (
                self.best_accuracy is None
                or self.best_accuracy < self.valid_metrics.avg("accuracy")
            ):
                self.best_accuracy = self.valid_metrics.avg("accuracy")
                self._save_checkpoint(epoch, save_best=True)
            else:
                self._save_checkpoint(epoch, save_best=False)

            # save logs
            self._save_logs(epoch)

    def _train_epoch(self, epoch):
        """Training step"""
        self.model.train()
        self.train_metrics.reset()
        with tqdm(total=len(self.datamanager.get_dataloader("train"))) as epoch_pbar:
            epoch_pbar.set_description(f"Epoch {epoch}")
            for batch_idx, (data, labels, _) in enumerate(
                self.datamanager.get_dataloader("train")
            ):
                # push data to device
                data, labels = data.to(self.device), labels.to(self.device)

                # zero gradients
                self.optimizer.zero_grad()
                self.optimizer_centerloss.zero_grad()

                with autocast():
                    # forward batch
                    score, feat = self.model(data)

                    # calculate loss and accuracy
                    loss = (
                        self.criterion(score, feat, labels)
                        + self.center_loss(feat, labels) * self.config["losses"]["beta"]
                    )
                    _, preds = torch.max(score.data, dim=1)

                # backward parameters
                # loss.backward()
                self.scaler.scale(loss).backward()

                # the center-loss term was multiplied by beta above, so divide
                # its parameter gradients by beta to undo that scaling
                for param in self.center_loss.parameters():
                    param.grad.data *= 1.0 / self.config["losses"]["beta"]

                # optimize
                # self.optimizer.step()
                self.scaler.step(self.optimizer)
                self.optimizer_centerloss.step()
                self.scaler.update()

                # update loss and accuracy in MetricTracker
                self.train_metrics.update("loss", loss.item())
                self.train_metrics.update(
                    "accuracy",
                    torch.sum(preds == labels.data).double().item() / data.size(0),
                )

                # update progress bar
                epoch_pbar.set_postfix(
                    {
                        "train_loss": self.train_metrics.avg("loss"),
                        "train_acc": self.train_metrics.avg("accuracy"),
                    }
                )
                epoch_pbar.update(1)
        return self.train_metrics.result()

    def _valid_epoch(self, epoch):
        """Validation step"""
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            with tqdm(total=len(self.datamanager.get_dataloader("val"))) as epoch_pbar:
                epoch_pbar.set_description(f"Epoch {epoch}")
                for batch_idx, (data, labels, _) in enumerate(
                    self.datamanager.get_dataloader("val")
                ):
                    # push data to device
                    data, labels = data.to(self.device), labels.to(self.device)

                    with autocast():
                        # forward batch
                        score, feat = self.model(data)

                        # calculate loss and accuracy
                        loss = (
                            self.criterion(score, feat, labels)
                            + self.center_loss(feat, labels)
                            * self.config["losses"]["beta"]
                        )
                        _, preds = torch.max(score.data, dim=1)

                    # update loss and accuracy in MetricTracker
                    self.valid_metrics.update("loss", loss.item())
                    self.valid_metrics.update(
                        "accuracy",
                        torch.sum(preds == labels.data).double().item() / data.size(0),
                    )

                    # update progress bar
                    epoch_pbar.set_postfix(
                        {
                            "val_loss": self.valid_metrics.avg("loss"),
                            "val_acc": self.valid_metrics.avg("accuracy"),
                        }
                    )
                    epoch_pbar.update(1)
        return self.valid_metrics.result()

    def _save_checkpoint(self, epoch, save_best=True):
        """Save model to file"""
        state = {
            "epoch": epoch,
            "state_dict": self.model.state_dict(),
            "center_loss": self.center_loss.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "optimizer_centerloss": self.optimizer_centerloss.state_dict(),
            "lr_scheduler": self.lr_scheduler.state_dict(),
            "best_accuracy": self.best_accuracy,
        }
        filename = os.path.join(self.checkpoint_dir, "model_last.pth")
        self.logger.info("Saving last model: model_last.pth ...")
        torch.save(state, filename)
        if save_best:
            filename = os.path.join(self.checkpoint_dir, "model_best.pth")
            self.logger.info("Saving current best: model_best.pth ...")
            torch.save(state, filename)

    def _resume_checkpoint(self, resume_path):
        """Load model from checkpoint"""
        if not os.path.exists(resume_path):
            raise FileNotFoundError("Resume path does not exist!")
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path, map_location=self.map_location)
        self.start_epoch = checkpoint["epoch"] + 1
        self.model.load_state_dict(checkpoint["state_dict"])
        self.center_loss.load_state_dict(checkpoint["center_loss"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.optimizer_centerloss.load_state_dict(checkpoint["optimizer_centerloss"])
        self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        self.best_accuracy = checkpoint["best_accuracy"]
        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)
        )

    def _save_logs(self, epoch):
        """Save logs from Google Colab to Google Drive"""
        if os.path.isdir(self.logs_dir_saved):
            shutil.rmtree(self.logs_dir_saved)
        destination = shutil.copytree(self.logs_dir, self.logs_dir_saved)
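
# Illustrative usage, not part of the original file: constructing the trainer
# from a parsed config and running training. The config path and JSON format
# are assumptions; the real project presumably supplies its own config with
# the 'data', 'losses', 'optimizer', 'lr_scheduler', and 'resume' keys used above.
if __name__ == '__main__':
    import json

    with open('config.json') as f:  # assumed config path
        config = json.load(f)
    trainer = Trainer(config)
    trainer.train()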