def main(args):
    """Train a re-ID model with soft labels distilled from a ResNet-50 teacher.

    Loads per-sample soft targets from ``soft_label/soft_label_resnet50.txt``,
    optimizes a combined (hard CE + KL + soft) loss, and periodically runs
    CMC/mAP evaluation and checkpoints the model.
    """
    path = os.path.join(os.getcwd(), 'soft_label', 'soft_label_resnet50.txt')
    if not os.path.isfile(path):
        # BUG FIX: the original only printed a warning and kept going, which
        # made getTrainLoader fail later with a confusing error. Fail fast.
        raise FileNotFoundError('soft label file does not exist: ' + path)
    train_loader = getTrainLoader(args, path)
    # Only the validation side of the standard loader is used here; the
    # train loader above carries the extra soft-target column.
    _, val_loader, num_query, num_classes, train_size = make_data_loader(args)
    model = build_model(args, num_classes)
    optimizer = make_optimizer(args, model)
    scheduler = WarmupMultiStepLR(optimizer, [30, 55], 0.1, 0.01, 5, "linear")
    loss_func = make_loss(args)
    model.to(device)
    for epoch in range(args.Epochs):
        model.train()
        running_loss = 0.0
        running_klloss = 0.0
        running_softloss = 0.0
        running_corrects = 0.0
        for index, data in enumerate(tqdm(train_loader)):
            img, target, soft_target = data
            img = img.cuda()
            target = target.cuda()
            soft_target = soft_target.cuda()
            score, _ = model(img)
            preds = torch.max(score.data, 1)[1]
            loss, klloss, softloss = loss_func(score, target, soft_target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            running_klloss += klloss.item()
            running_softloss += softloss.item()
            running_corrects += float(torch.sum(preds == target.data))
        scheduler.step()
        # Per-sample averages over the whole epoch.
        epoch_loss = running_loss / train_size
        epoch_klloss = running_klloss / train_size
        epoch_softloss = running_softloss / train_size
        epoch_acc = running_corrects / train_size
        print(
            "Epoch {} Loss : {:.4f} KLLoss:{:.8f} SoftLoss:{:.4f} Acc:{:.4f}"
            .format(epoch, epoch_loss, epoch_klloss, epoch_softloss,
                    epoch_acc))
        # Periodic evaluation + checkpoint.
        if (epoch + 1) % args.n_save == 0:
            evaluator = Evaluator(model, val_loader, num_query)
            cmc, mAP = evaluator.run()
            print('---------------------------')
            print("CMC Curve:")
            for r in [1, 5, 10]:
                print("Rank-{} : {:.1%}".format(r, cmc[r - 1]))
            print("mAP : {:.1%}".format(mAP))
            print('---------------------------')
            save_model(args, model, optimizer, epoch)
def main(args):
    """Supervised baseline training loop with periodic evaluation."""
    # Mirror stdout into a timestamped log file.
    sys.stdout = Logger(
        os.path.join(args.log_path, args.log_description,
                     'log' + time.strftime(".%m_%d_%H:%M:%S") + '.txt'))

    train_loader, val_loader, num_query, num_classes, train_size = \
        make_data_loader(args)

    model = build_model(args, num_classes)
    print(model)
    optimizer = make_optimizer(args, model)
    scheduler = WarmupMultiStepLR(optimizer, [30, 55], 0.1, 0.01, 5, "linear")
    loss_func = make_loss(args)
    model.to(device)

    for epoch in range(args.Epochs):
        model.train()
        loss_sum = 0.0
        correct_sum = 0.0
        for batch in tqdm(train_loader):
            img, target = batch
            img, target = img.cuda(), target.cuda()
            score, _ = model(img)
            preds = torch.max(score.data, 1)[1]
            loss = loss_func(score, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()
            correct_sum += float(torch.sum(preds == target.data))
        scheduler.step()

        print("Epoch {} Loss : {:.6f} Acc:{:.4f}".format(
            epoch, loss_sum / train_size, correct_sum / train_size))

        # Every n_save epochs: evaluate (CMC/mAP) and save a checkpoint.
        if (epoch + 1) % args.n_save == 0:
            cmc, mAP = Evaluator(model, val_loader, num_query).run()
            print('---------------------------')
            print("CMC Curve:")
            for r in [1, 5, 10]:
                print("Rank-{} : {:.1%}".format(r, cmc[r - 1]))
            print("mAP : {:.1%}".format(mAP))
            print('---------------------------')
            save_model(args, model, optimizer, epoch)
def main(args):
    """Train an image-text matching network; track best top-1 text-to-image accuracy.

    Documentation-only pass: code tokens are unchanged.
    """
    # transform
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])
    cap_transform = None
    # data
    train_loader = data_config(args.image_dir, args.anno_dir, args.batch_size,
                               'train', 100, train_transform,
                               cap_transform=cap_transform)
    test_loader = data_config(args.image_dir, args.anno_dir, 64, 'test', 100,
                              test_transform)
    unique_image = get_image_unique(args.image_dir, args.anno_dir, 64, 'test',
                                    100, test_transform)
    # loss
    compute_loss = Loss(args)
    # NOTE(review): the DataParallel wrapper is discarded here; only the
    # .cuda() side effect on the loss module persists — confirm intended.
    nn.DataParallel(compute_loss).cuda()
    # network
    network, optimizer = network_config(args, 'train',
                                        compute_loss.parameters(),
                                        args.resume, args.model_path)
    # lr_scheduler
    scheduler = WarmupMultiStepLR(optimizer, (20, 25, 35), 0.1, 0.01, 10,
                                  'linear')
    ac_t2i_top1_best = 0.0
    best_epoch = 0
    for epoch in range(args.num_epoches - args.start_epoch):
        network.train()
        # train for one epoch
        train_loss, train_time, image_precision, text_precision = train(
            args.start_epoch + epoch, train_loader, network, optimizer,
            compute_loss, args)
        # evaluate on validation set
        # NOTE(review): is_best is never set to True, so checkpoints are
        # always saved with is_best=False — verify against save_checkpoint.
        is_best = False
        print('Train done for epoch-{}'.format(args.start_epoch + epoch))
        logging.info(
            'Epoch: [{}|{}], train_time: {:.3f}, train_loss: {:.3f}'.format(
                args.start_epoch + epoch, args.num_epoches, train_time,
                train_loss))
        logging.info('image_precision: {:.3f}, text_precision: {:.3f}'.format(
            image_precision, text_precision))
        scheduler.step()
        for param in optimizer.param_groups:
            print('lr:{}'.format(param['lr']))
        # NOTE(review): 'epoch >= 0' is always true, so testing runs every
        # epoch — presumably a leftover gate for delayed evaluation.
        if epoch >= 0:
            # Retrieval accuracy in both directions (image->text, text->image).
            ac_top1_i2t, ac_top5_i2t, ac_top10_i2t, ac_top1_t2i, \
                ac_top5_t2i, ac_top10_t2i, test_time = test(
                    test_loader, network, args, unique_image)
            state = {'network': network.state_dict(),
                     'optimizer': optimizer.state_dict(),
                     'W': compute_loss.W,
                     'epoch': args.start_epoch + epoch}
            # Keep only checkpoints that improve top-1 text-to-image accuracy.
            if ac_top1_t2i > ac_t2i_top1_best:
                best_epoch = epoch
                ac_t2i_top1_best = ac_top1_t2i
                save_checkpoint(state, epoch, args.checkpoint_dir, is_best)
            logging.info('epoch:{}'.format(epoch))
            logging.info(
                'top1_t2i: {:.3f}, top5_t2i: {:.3f}, top10_t2i: {:.3f}, top1_i2t: {:.3f}, top5_i2t: {:.3f}, top10_i2t: {:.3f}'.format(
                    ac_top1_t2i, ac_top5_t2i, ac_top10_t2i, ac_top1_i2t,
                    ac_top5_i2t, ac_top10_i2t))
    logging.info('Best epoch:{}'.format(best_epoch))
    logging.info('Train done')
    logging.info(args.checkpoint_dir)
    logging.info(args.log_dir)
class BaseModel(object):
    """Camera-aware re-ID trainer with two classifier heads (c1/c2).

    Builds a content encoder plus two optimizers (a normal one and a "fix"
    one), and trains with an epoch-indexed three-phase curriculum in
    :meth:`two_classifier`:

    * ``normal_c1_c2``  — epoch < 80, or 110 <= epoch < 170
    * ``reverse_c1_c2`` — 80 <= epoch < 110, or 170 <= epoch < 210
    * ``fix_c1_c2``     — epoch >= 210 (adds target-domain consistency and,
      from epoch 220, pseudo-label classification loss)
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self._init_models()
        self._init_optimizers()
        print('---------- Networks initialized -------------')
        print_network(self.Content_Encoder)
        print('-----------------------------------------------')

    def _init_models(self):
        # -----------------Content_Encoder-------------------
        self.Content_Encoder = Baseline(self.cfg.DATASETS.NUM_CLASSES_S, 1,
                                        self.cfg.MODEL.PRETRAIN_PATH,
                                        'bnneck', 'after',
                                        self.cfg.MODEL.NAME, 'imagenet')
        # -----------------Criterion----------------- #
        # BUG FIX: self.xent is used unconditionally in two_classifier(), but
        # its initialization was commented out, which would raise
        # AttributeError at the first training step.
        self.xent = CrossEntropyLabelSmooth(
            num_classes=self.cfg.DATASETS.NUM_CLASSES_S).cuda()
        self.triplet = TripletLoss(0.3)
        self.Smooth_L1_loss = torch.nn.SmoothL1Loss(reduction='mean').cuda()
        # --------------------Cuda------------------- #
        self.Content_Encoder = torch.nn.DataParallel(
            self.Content_Encoder).cuda()

    def _init_optimizers(self):
        # Two optimizer/scheduler pairs: one for the normal phases and one
        # ("fix") for the reverse/fix phases.
        self.Content_optimizer = make_optimizer(self.cfg,
                                                self.Content_Encoder)
        self.Content_optimizer_fix = make_optimizer(self.cfg,
                                                    self.Content_Encoder,
                                                    fix=True)
        self.scheduler = WarmupMultiStepLR(self.Content_optimizer, (30, 55),
                                           0.1, 1.0 / 3, 500, "linear")
        self.scheduler_fix = WarmupMultiStepLR(self.Content_optimizer_fix,
                                               (30, 55), 0.1, 1.0 / 3, 500,
                                               "linear")
        self.schedulers = []
        self.optimizers = []

    def reset_model_status(self):
        # Put the encoder (and its classifier heads) into training mode.
        self.Content_Encoder.train()

    def two_classifier(self, epoch, train_loader_s, train_loader_t, writer,
                       logger, rand_src_1, rand_src_2, print_freq=1):
        """Run one epoch of the two-classifier curriculum (see class docstring).

        rand_src_1/rand_src_2 are camera-id subsets used to split each source
        batch between the two classifier heads.
        """
        self.reset_model_status()
        self.epoch = epoch
        self.scheduler.step(epoch)
        self.scheduler_fix.step(epoch)
        target_iter = iter(train_loader_t)
        batch_time = AverageMeter()
        data_time = AverageMeter()
        end = time.time()
        # Phase selection by epoch; the ranges cover all epoch >= 0.
        if (epoch < 80) or (110 <= epoch < 170):
            mode = 'normal_c1_c2'
        elif (80 <= epoch < 110) or (170 <= epoch < 210):
            mode = 'reverse_c1_c2'
        elif 210 <= epoch:
            mode = 'fix_c1_c2'
        for i, inputs in enumerate(train_loader_s):
            data_time.update(time.time() - end)
            try:
                inputs_target = next(target_iter)
            except StopIteration:
                # BUG FIX: narrowed from a bare `except:` (which also caught
                # KeyboardInterrupt etc.); restart the exhausted target loader.
                target_iter = iter(train_loader_t)
                inputs_target = next(target_iter)
            img_s, pid_s, camid_s = self._parse_data(inputs)
            img_t, pid_t, camid_t = self._parse_data(inputs_target)
            content_code_s, content_feat_s = self.Content_Encoder(img_s)
            pid_s_12 = np.asarray(pid_s.cpu())
            camid_s = np.asarray(camid_s.cpu())
            # Select source samples whose camera id belongs to rand_src_1.
            idx = []
            for c_id in rand_src_1:
                if len(np.where(c_id == camid_s)[0]) == 0:
                    continue
                else:
                    idx.append(np.where(c_id == camid_s)[0])
            if idx == [] or len(idx[0]) == 1:
                # Degenerate camera split: fall back to the whole batch.
                idx = [np.asarray([a])
                       for a in range(self.cfg.SOLVER.IMS_PER_BATCH)]
            idx = np.concatenate(idx)
            pid_1 = torch.tensor(pid_s_12[idx]).cuda()
            feat_1 = content_feat_s[idx]
            # Same selection for the second camera subset.
            idx = []
            for c_id in rand_src_2:
                if len(np.where(c_id == camid_s)[0]) == 0:
                    continue
                else:
                    idx.append(np.where(c_id == camid_s)[0])
            if idx == [] or len(idx[0]) == 1:
                idx = [np.asarray([a])
                       for a in range(self.cfg.SOLVER.IMS_PER_BATCH)]
            idx = np.concatenate(idx)
            pid_2 = torch.tensor(pid_s_12[idx]).cuda()
            feat_2 = content_feat_s[idx]
            if mode == 'normal_c1_c2':
                # Each camera subset goes to its own head.
                class_1 = self.Content_Encoder(feat_1, mode='c1')
                class_2 = self.Content_Encoder(feat_2, mode='c2')
                ID_loss_1 = self.xent(class_1, pid_1)
                ID_loss_2 = self.xent(class_2, pid_2)
                ID_tri_loss = self.triplet(content_feat_s, pid_s)
                total_loss = ID_loss_1 + ID_loss_2 + ID_tri_loss[0]
                self.Content_optimizer.zero_grad()
                total_loss.backward()
                self.Content_optimizer.step()
                batch_time.update(time.time() - end)
                end = time.time()
                if (i + 1) % print_freq == 0:
                    logger.info('Epoch: [{}][{}/{}]\t'
                                'Time {:.3f} ({:.3f})\t'
                                'Data {:.3f} ({:.3f})\t'
                                'ID_loss: {:.3f} ID_loss_1: {:.3f} ID_loss_2: {:.3f} tri_loss: {:.3f} '
                                .format(epoch, i + 1, len(train_loader_s),
                                        batch_time.val, batch_time.avg,
                                        data_time.val, data_time.avg,
                                        total_loss.item(), ID_loss_1.item(),
                                        ID_loss_2.item(),
                                        ID_tri_loss[0].item()))
            elif mode == 'reverse_c1_c2':
                # Swap the heads so each must also classify the other subset.
                class_1 = self.Content_Encoder(feat_1, mode='c2')
                class_2 = self.Content_Encoder(feat_2, mode='c1')
                ID_loss_1 = self.xent(class_1, pid_1)
                ID_loss_2 = self.xent(class_2, pid_2)
                ID_tri_loss = self.triplet(content_feat_s, pid_s)
                total_loss = ID_loss_1 + ID_loss_2 + ID_tri_loss[0]
                self.Content_optimizer_fix.zero_grad()
                total_loss.backward()
                self.Content_optimizer_fix.step()
                batch_time.update(time.time() - end)
                end = time.time()
                if (i + 1) % print_freq == 0:
                    logger.info('Epoch: [{}][{}/{}]\t'
                                'Time {:.3f} ({:.3f})\t'
                                'Data {:.3f} ({:.3f})\t'
                                'ID_loss: {:.3f} ID_loss_1: {:.3f} ID_loss_2: {:.3f} tri_loss: {:.3f}'
                                .format(epoch, i + 1, len(train_loader_s),
                                        batch_time.val, batch_time.avg,
                                        data_time.val, data_time.avg,
                                        total_loss.item(), ID_loss_1.item(),
                                        ID_loss_2.item(),
                                        ID_tri_loss[0].item()))
            elif mode == 'fix_c1_c2':
                class_1 = self.Content_Encoder(feat_1, mode='c2')
                class_2 = self.Content_Encoder(feat_2, mode='c1')
                ID_loss_1 = self.xent(class_1, pid_1)
                ID_loss_2 = self.xent(class_2, pid_2)
                # Target-domain consistency between the two heads.
                content_code_t, content_feat_t = self.Content_Encoder(img_t)
                tar_class_1 = self.Content_Encoder(content_feat_t, mode='c1')
                tar_class_2 = self.Content_Encoder(content_feat_t, mode='c2')
                tar_L1_loss = self.Smooth_L1_loss(tar_class_1, tar_class_2)
                ID_tri_loss = self.triplet(content_feat_s, pid_s)
                # Pseudo labels: target samples where both heads agree and are
                # confident (mean score > 0.8).
                arg_c1 = torch.argmax(tar_class_1, dim=1)
                arg_c2 = torch.argmax(tar_class_2, dim=1)
                arg_idx = []
                fake_id = []
                for i_dx, data in enumerate(arg_c1):
                    if (data == arg_c2[i_dx]) and (
                            ((tar_class_1[i_dx][data] +
                              tar_class_2[i_dx][arg_c2[i_dx]]) / 2) > 0.8):
                        arg_idx.append(i_dx)
                        fake_id.append(data)
                if 210 <= epoch < 220:
                    # Warm-up window: compute the pseudo-label loss only for
                    # logging; it is not yet part of the total loss.
                    if arg_idx != []:
                        ID_loss_fake = self.xent(
                            content_code_t[arg_idx],
                            torch.tensor(fake_id).cuda())
                        total_loss = (ID_loss_1 + ID_loss_2 +
                                      0.5 * tar_L1_loss + ID_tri_loss[0])
                    else:
                        ID_loss_fake = torch.tensor([0])
                        total_loss = (ID_loss_1 + ID_loss_2 +
                                      0.5 * tar_L1_loss + ID_tri_loss[0])
                if 220 <= epoch:
                    # From epoch 220 on, the pseudo-label loss joins the total.
                    if arg_idx != []:
                        ID_loss_fake = self.xent(
                            content_code_t[arg_idx],
                            torch.tensor(fake_id).cuda())
                        total_loss = (ID_loss_1 + ID_loss_2 +
                                      0.08 * ID_loss_fake + ID_tri_loss[0] +
                                      0.5 * tar_L1_loss)
                    else:
                        ID_loss_fake = torch.tensor([0])
                        total_loss = (ID_loss_1 + ID_loss_2 +
                                      ID_tri_loss[0] + 0.5 * tar_L1_loss)
                self.Content_optimizer_fix.zero_grad()
                total_loss.backward()
                self.Content_optimizer_fix.step()
                batch_time.update(time.time() - end)
                end = time.time()
                if (i + 1) % print_freq == 0:
                    logger.info('Epoch: [{}][{}/{}]\t'
                                'Time {:.3f} ({:.3f})\t'
                                'Data {:.3f} ({:.3f})\t'
                                'ID_loss: {:.3f} ID_loss_1: {:.3f} ID_loss_2: {:.3f} tar_L1_loss: {:.3f} tri_loss: {:.3f} ID_loss_fake: {:.6f}'
                                .format(epoch, i + 1, len(train_loader_s),
                                        batch_time.val, batch_time.avg,
                                        data_time.val, data_time.avg,
                                        total_loss.item(), ID_loss_1.item(),
                                        ID_loss_2.item(), tar_L1_loss.item(),
                                        ID_tri_loss[0].item(),
                                        ID_loss_fake.item()))

    def _parse_data(self, inputs):
        """Unpack a loader batch and move all three tensors to the GPU."""
        imgs, pids, camids = inputs
        inputs = imgs.cuda()
        targets = pids.cuda()
        camids = camids.cuda()
        return inputs, targets, camids
def main():
    """End-to-end VeRi training with GMS-based positive/negative mining.

    Loads pickled GMS feature matrices, builds triplets (anchor, GMS-nearest
    positive, random negative) per batch, and trains with periodic
    CMC/mAP evaluation and checkpointing.
    """
    torch.backends.cudnn.deterministic = True
    cudnn.benchmark = True
    config_file = 'configs/baseline_veri_r101_a.yml'
    if config_file != "":
        cfg.merge_from_file(config_file)
    cfg.freeze()
    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)
    logger = setup_logger("reid_baseline", output_dir, if_train=True)
    logger.info("Saving model in the path :{}".format(cfg.OUTPUT_DIR))
    logger.info(config_file)
    if config_file != "":
        logger.info("Loaded configuration file {}".format(config_file))
        with open(config_file, 'r') as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))
    os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID

    # Load every pickled GMS matrix; keys are the 3-char camera/track prefix
    # (or the full stem for featureMatrix.pkl).
    path = 'D:/Python_SMU/Veri/verigms/gms/'
    pkl = {}
    entries = os.listdir(path)
    for name in entries:
        if name == 'featureMatrix.pkl':
            s = name[0:13]
        else:
            s = name[0:3]
        # BUG FIX: the original called `f.close` without parentheses, so the
        # file handles were never closed; use a context manager instead.
        with open(path + name, 'rb') as f:
            pkl[s] = pickle.load(f)

    # NOTE: b and c are loaded but unused below; kept for parity with the
    # original behavior (the files must exist and parse).
    with open('cids.pkl', 'rb') as handle:
        b = pickle.load(handle)
    with open('index.pkl', 'rb') as handle:
        c = pickle.load(handle)

    train_transforms, val_transforms, dataset, train_set, val_set = \
        make_dataset(cfg, pkl_file='index.pkl')
    num_workers = cfg.DATALOADER.NUM_WORKERS
    num_classes = dataset.num_train_pids

    # Map each 3-char folder id to its training pid.
    # (Removed a dead `pid = 0` / `pid += 1` counter from the original: pid is
    # rebound by the loop unpacking on every iteration.)
    pidx = {}
    for img_path, pid, _, _ in dataset.train:
        fname = img_path.split('\\')[-1]
        folder = fname[1:4]
        pidx[folder] = pid

    if 'triplet' in cfg.DATALOADER.SAMPLER:
        train_loader = DataLoader(train_set,
                                  batch_size=cfg.SOLVER.IMS_PER_BATCH,
                                  sampler=RandomIdentitySampler(
                                      dataset.train,
                                      cfg.SOLVER.IMS_PER_BATCH,
                                      cfg.DATALOADER.NUM_INSTANCE),
                                  num_workers=num_workers,
                                  pin_memory=True,
                                  collate_fn=train_collate_fn)
    elif cfg.DATALOADER.SAMPLER == 'softmax':
        print('using softmax sampler')
        train_loader = DataLoader(train_set,
                                  batch_size=cfg.SOLVER.IMS_PER_BATCH,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  pin_memory=True,
                                  collate_fn=train_collate_fn)
    else:
        print('unsupported sampler! expected softmax or triplet but got {}'.
              format(cfg.SAMPLER))
    print("train loader loaded successfully")

    val_loader = DataLoader(val_set,
                            batch_size=cfg.TEST.IMS_PER_BATCH,
                            shuffle=False,
                            num_workers=num_workers,
                            pin_memory=True,
                            collate_fn=train_collate_fn)
    print("val loader loaded successfully")

    if cfg.MODEL.PRETRAIN_CHOICE == 'finetune':
        model = make_model(cfg, num_class=576)
        model.load_param_finetune(cfg.MODEL.PRETRAIN_PATH)
        print('Loading pretrained model for finetuning......')
    else:
        model = make_model(cfg, num_class=num_classes)

    loss_func, center_criterion = make_loss(cfg, num_classes=num_classes)
    optimizer, optimizer_center = make_optimizer(cfg, model, center_criterion)
    scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS,
                                  cfg.SOLVER.GAMMA,
                                  cfg.SOLVER.WARMUP_FACTOR,
                                  cfg.SOLVER.WARMUP_EPOCHS,
                                  cfg.SOLVER.WARMUP_METHOD)
    print("model,optimizer, loss, scheduler loaded successfully")

    height, width = cfg.INPUT.SIZE_TRAIN
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    device = "cuda"
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info('start training')
    if device:
        if torch.cuda.device_count() > 1:
            print('Using {} GPUs for training'.format(
                torch.cuda.device_count()))
            model = nn.DataParallel(model)
        model.to(device)

    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    evaluator = R1_mAP_eval(len(dataset.query),
                            max_rank=50,
                            feat_norm=cfg.TEST.FEAT_NORM)
    model.base._freeze_stages()
    logger.info('Freezing the stages number:{}'.format(cfg.MODEL.FROZEN))
    data_index = search(pkl)
    print("Ready for training")

    for epoch in range(1, epochs + 1):
        start_time = time.time()
        loss_meter.reset()
        acc_meter.reset()
        evaluator.reset()
        scheduler.step()
        model.train()
        for n_iter, (img, label, index, pid, cid) in enumerate(train_loader):
            optimizer.zero_grad()
            optimizer_center.zero_grad()
            # Triplet batch: [anchors | GMS-nearest positives | negatives].
            trainX, trainY = torch.zeros(
                (train_loader.batch_size * 3, 3, height, width),
                dtype=torch.float32), torch.zeros(
                    (train_loader.batch_size * 3), dtype=torch.int64)
            for i in range(train_loader.batch_size):
                labelx = label[i]
                indexx = index[i]
                cidx = pid[i]
                if indexx > len(pkl[labelx]) - 1:
                    indexx = len(pkl[labelx]) - 1
                # Positive: image with the smallest nonzero GMS distance.
                a = pkl[labelx][indexx]
                minpos = np.argmin(ma.masked_where(a == 0, a))
                pos_dic = train_set[data_index[cidx][1] + minpos]
                # Negative: a random different identity with an existing
                # train folder.
                neg_label = int(labelx)
                while True:
                    neg_label = random.choice(range(1, 770))
                    # BUG FIX: the original used `is not` on ints, which
                    # compares object identity and is unreliable outside the
                    # small-int cache; use value inequality.
                    if neg_label != int(labelx) and os.path.isdir(
                            os.path.join('D:/datasets/veri-split/train',
                                         strint(neg_label))):
                        break
                negative_label = strint(neg_label)
                neg_cid = pidx[negative_label]
                neg_index = random.choice(range(0, len(pkl[negative_label])))
                neg_dic = train_set[data_index[neg_cid][1] + neg_index]
                trainX[i] = img[i]
                trainX[i + train_loader.batch_size] = pos_dic[0]
                trainX[i + (train_loader.batch_size * 2)] = neg_dic[0]
                trainY[i] = cidx
                trainY[i + train_loader.batch_size] = pos_dic[3]
                trainY[i + (train_loader.batch_size * 2)] = neg_dic[3]
            trainX = trainX.cuda()
            trainY = trainY.cuda()
            score, feat = model(trainX, trainY)
            loss = loss_func(score, feat, trainY)
            loss.backward()
            optimizer.step()
            if 'center' in cfg.MODEL.METRIC_LOSS_TYPE:
                # Undo the loss weighting on the center-loss gradients before
                # the center optimizer step.
                for param in center_criterion.parameters():
                    param.grad.data *= (1. / cfg.SOLVER.CENTER_LOSS_WEIGHT)
                optimizer_center.step()
            acc = (score.max(1)[1] == trainY).float().mean()
            loss_meter.update(loss.item(), img.shape[0])
            acc_meter.update(acc, 1)
            if (n_iter + 1) % log_period == 0:
                logger.info(
                    "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                    .format(epoch, (n_iter + 1), len(train_loader),
                            loss_meter.avg, acc_meter.avg,
                            scheduler.get_lr()[0]))
        end_time = time.time()
        time_per_batch = (end_time - start_time) / (n_iter + 1)
        logger.info(
            "Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]"
            .format(epoch, time_per_batch,
                    train_loader.batch_size / time_per_batch))
        if epoch % checkpoint_period == 0:
            torch.save(
                model.state_dict(),
                os.path.join(cfg.OUTPUT_DIR,
                             cfg.MODEL.NAME + '_{}.pth'.format(epoch)))
        if epoch % eval_period == 0:
            model.eval()
            for n_iter, (img, vid, camid, _, _) in enumerate(val_loader):
                with torch.no_grad():
                    img = img.to(device)
                    feat = model(img)
                    evaluator.update((feat, vid, camid))
            cmc, mAP, _, _, _, _, _ = evaluator.compute()
            logger.info("Validation Results - Epoch: {}".format(epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(
                    r, cmc[r - 1]))
# NOTE(review): this section appears to be a fragment — model, device, logger,
# num_query, epochs, scheduler, train_loader, optimizer, optimizer_center and
# loss_func are not defined in this chunk; presumably it belongs inside a
# training entry point defined elsewhere.
model = nn.DataParallel(model)
model.to(device)
loss_meter = AverageMeter()
acc_meter = AverageMeter()
evaluator = R1_mAP_eval(num_query, max_rank=50, feat_norm='yes')
# NOTE(review): after the nn.DataParallel wrap, attributes live on
# model.module; `model.base` looks like it would raise AttributeError here —
# confirm against the DataParallel usage elsewhere in the project.
model.base._freeze_stages()
logger.info('Freezing the stages number:{}'.format(-1))
# train
for epoch in range(1, epochs + 1):
    start_time = time.time()
    loss_meter.reset()
    acc_meter.reset()
    evaluator.reset()
    scheduler.step()
    model.train()
    for n_iter, (img, vid) in enumerate(train_loader):
        optimizer.zero_grad()
        optimizer_center.zero_grad()
        img = img.to(device)
        target = vid.to(device)
        # Model returns features; the loss function produces both the loss
        # and the classification scores used for the accuracy meter.
        feat = model(img, target)
        loss, score = loss_func(feat, target)
        loss.backward()
        optimizer.step()
        acc = (score.max(1)[1] == target).float().mean()
        loss_meter.update(loss.item(), img.shape[0])
        acc_meter.update(acc, 1)
def train(args):
    """Full training entry point: data, model, optimizers, optional AMP,
    TensorBoard logging, periodic evaluation and checkpointing.

    Returns nothing; progress and results are printed/logged, and checkpoints
    are written under ``args.save_dir/ckpts``.
    """
    # Round the batch size down to a multiple of num_instance.
    if args.batch_size % args.num_instance != 0:
        new_batch_size = (args.batch_size //
                          args.num_instance) * args.num_instance
        print(
            f"given batch size is {args.batch_size} and num_instances is {args.num_instance}."
            +
            f"Batch size must be divided into {args.num_instance}. Batch size will be replaced into {new_batch_size}"
        )
        args.batch_size = new_batch_size

    # prepare dataset
    train_loader, val_loader, num_query, train_data_len, num_classes = \
        make_data_loader(args)

    model = build_model(args, num_classes)
    print("model size: {:.5f}M".format(
        sum(p.numel() for p in model.parameters()) / 1e6))

    loss_fn, center_criterion = make_loss(args, num_classes)
    optimizer, optimizer_center = make_optimizer(args, model,
                                                 center_criterion)

    if args.cuda:
        model = model.cuda()
        if args.amp:
            if args.center_loss:
                model, [optimizer, optimizer_center] = \
                    amp.initialize(model, [optimizer, optimizer_center],
                                   opt_level="O1")
            else:
                model, optimizer = amp.initialize(model, optimizer,
                                                  opt_level="O1")
        # Move any optimizer state tensors onto the GPU.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        if args.center_loss:
            center_criterion = center_criterion.cuda()
            for state in optimizer_center.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()

    reid_evaluator = ReIDEvaluator(args, model, num_query)

    start_epoch = 0
    global_step = 0
    if args.pretrain != '':  # load pre-trained model
        weights = torch.load(args.pretrain)
        model.load_state_dict(weights["state_dict"])
        if args.center_loss:
            center_criterion.load_state_dict(
                torch.load(args.pretrain.replace(
                    'model', 'center_param'))["state_dict"])
        if args.resume:
            start_epoch = weights["epoch"]
            global_step = weights["global_step"]
            optimizer.load_state_dict(
                torch.load(args.pretrain.replace(
                    'model', 'optimizer'))["state_dict"])
            if args.center_loss:
                optimizer_center.load_state_dict(
                    torch.load(args.pretrain.replace(
                        'model', 'optimizer_center'))["state_dict"])
    print(f'Start epoch: {start_epoch}, Start step: {global_step}')

    scheduler = WarmupMultiStepLR(optimizer, args.steps, args.gamma,
                                  args.warmup_factor, args.warmup_step,
                                  "linear",
                                  -1 if start_epoch == 0 else start_epoch)

    current_epoch = start_epoch
    best_epoch = 0
    best_rank1 = 0
    best_mAP = 0
    if args.resume:
        # Establish the baseline metric of the resumed checkpoint.
        rank, mAP = reid_evaluator.evaluate(val_loader)
        best_rank1 = rank[0]
        best_mAP = mAP
        best_epoch = current_epoch + 1

    batch_time = AverageMeter()
    total_losses = AverageMeter()

    model_save_dir = os.path.join(args.save_dir, 'ckpts')
    os.makedirs(model_save_dir, exist_ok=True)
    summary_writer = SummaryWriter(log_dir=os.path.join(
        args.save_dir, "tensorboard_log"), purge_step=global_step)

    def summary_loss(score, feat, labels, top_name='global'):
        # Sum individual loss terms, logging each to TensorBoard; "accuracy"
        # and "*dist*" entries are logged only, not added to the loss.
        loss = 0.0
        losses = loss_fn(score, feat, labels)
        for loss_name, loss_val in losses.items():
            if loss_name.lower() == "accuracy":
                summary_writer.add_scalar(f"Score/{top_name}/triplet",
                                          loss_val, global_step)
                continue
            if "dist" in loss_name.lower():
                summary_writer.add_histogram(f"Distance/{loss_name}",
                                             loss_val, global_step)
                continue
            loss += loss_val
            summary_writer.add_scalar(f"losses/{top_name}/{loss_name}",
                                      loss_val, global_step)

        # Mean softmax probability assigned to the true class.
        ohe_labels = torch.zeros_like(score)
        ohe_labels.scatter_(1, labels.unsqueeze(1), 1.0)
        cls_score = torch.softmax(score, dim=1)
        cls_score = torch.sum(cls_score * ohe_labels, dim=1).mean()
        summary_writer.add_scalar(f"Score/{top_name}/X-entropy", cls_score,
                                  global_step)
        return loss

    def save_weights(file_name, eph, steps):
        # BUG FIX: snapshot the *current* state dicts at save time. The
        # original saved dictionaries captured before training started (and,
        # when a pretrained checkpoint was loaded, the loaded weights), so
        # every checkpoint contained stale parameters. The original also
        # wrote the center-criterion state to the "optimizer_center" file and
        # the center-optimizer state to "center_param" — the opposite of what
        # the resume/load code above reads.
        torch.save(
            {
                "state_dict": model.state_dict(),
                "epoch": eph + 1,
                "global_step": steps
            }, file_name)
        torch.save({"state_dict": optimizer.state_dict()},
                   file_name.replace("model", "optimizer"))
        if args.center_loss:
            torch.save({"state_dict": center_criterion.state_dict()},
                       file_name.replace("model", "center_param"))
            torch.save({"state_dict": optimizer_center.state_dict()},
                       file_name.replace("model", "optimizer_center"))

    # training start
    for epoch in range(start_epoch, args.max_epoch):
        model.train()
        t0 = time.time()
        for i, (inputs, labels, _, _) in enumerate(train_loader):
            if args.cuda:
                inputs = inputs.cuda()
                labels = labels.cuda()

            cls_scores, features = model(inputs, labels)

            # losses
            total_loss = summary_loss(cls_scores[0], features[0], labels,
                                      'global')
            if args.use_local_feat:
                total_loss += summary_loss(cls_scores[1], features[1],
                                           labels, 'local')

            optimizer.zero_grad()
            if args.center_loss:
                optimizer_center.zero_grad()

            # backward with global loss
            if args.amp:
                optimizers = [optimizer]
                if args.center_loss:
                    optimizers.append(optimizer_center)
                with amp.scale_loss(total_loss, optimizers) as scaled_loss:
                    scaled_loss.backward()
            else:
                # NOTE: detect_anomaly adds significant overhead; kept to
                # preserve the original behavior.
                with torch.autograd.detect_anomaly():
                    total_loss.backward()

            # optimization
            optimizer.step()
            if args.center_loss:
                # Undo the center-loss weighting on gradients before stepping
                # the center optimizer; skip params with no gradient.
                for name, param in center_criterion.named_parameters():
                    try:
                        param.grad.data *= (1. / args.center_loss_weight)
                    except AttributeError:
                        continue
                optimizer_center.step()

            batch_time.update(time.time() - t0)
            total_losses.update(total_loss.item())

            # learning_rate
            current_lr = optimizer.param_groups[0]['lr']
            summary_writer.add_scalar("lr", current_lr, global_step)

            t0 = time.time()
            if (i + 1) % args.log_period == 0:
                print(
                    f"Epoch: [{epoch}][{i+1}/{train_data_len}] " +
                    f"Batch Time {batch_time.val:.3f} ({batch_time.mean:.3f}) "
                    +
                    f"Total_loss {total_losses.val:.3f} ({total_losses.mean:.3f})"
                )
            global_step += 1

        print(
            f"Epoch: [{epoch}]\tEpoch Time {batch_time.sum:.3f} s\tLoss {total_losses.mean:.3f}\tLr {current_lr:.2e}"
        )

        if args.eval_period > 0 and (epoch + 1) % args.eval_period == 0 or (
                epoch + 1) == args.max_epoch:
            rank, mAP = reid_evaluator.evaluate(
                val_loader,
                mode="retrieval" if args.dataset_name == "cub200" else "reid")
            rank_string = ""
            for r in (1, 2, 4, 5, 8, 10, 16, 20):
                rank_string += f"Rank-{r:<3}: {rank[r-1]:.1%}"
                if r != 20:
                    rank_string += " "
            summary_writer.add_text("Recall@K", rank_string, global_step)
            summary_writer.add_scalar("Rank-1", rank[0], (epoch + 1))

            rank1 = rank[0]
            is_best = rank1 > best_rank1
            if is_best:
                best_rank1 = rank1
                best_mAP = mAP
                best_epoch = epoch + 1

            if (epoch + 1) % args.save_period == 0 or (
                    epoch + 1) == args.max_epoch:
                pth_file_name = os.path.join(
                    model_save_dir,
                    f"{args.backbone}_model_{epoch + 1}.pth.tar")
                save_weights(pth_file_name, eph=epoch, steps=global_step)

            if is_best:
                pth_file_name = os.path.join(
                    model_save_dir, f"{args.backbone}_model_best.pth.tar")
                save_weights(pth_file_name, eph=epoch, steps=global_step)

        # end epoch
        current_epoch += 1
        batch_time.reset()
        total_losses.reset()
        torch.cuda.empty_cache()

        # update learning rate
        scheduler.step()

    print(f"Best rank-1 {best_rank1:.1%}, achived at epoch {best_epoch}")
    summary_writer.add_hparams(
        {
            "dataset_name": args.dataset_name,
            "triplet_dim": args.triplet_dim,
            "margin": args.margin,
            "base_lr": args.base_lr,
            "use_attn": args.use_attn,
            "use_mask": args.use_mask,
            "use_local_feat": args.use_local_feat
        }, {
            "mAP": best_mAP,
            "Rank1": best_rank1
        })