def main():
    """Train and/or evaluate the re-identification model described by ``Config``.

    The function, in order:
      1. builds the config, optionally redirects stdout/stderr, selects
         devices and seeds the RNGs;
      2. creates the training set and one or more test sets;
      3. builds the model, the losses and the Adam optimizer, optionally
         resuming from ``cfg.ckpt_file``;
      4. if ``cfg.only_test`` is set, evaluates and returns;
      5. otherwise trains for ``cfg.total_epochs`` epochs, logging per step
         and per epoch, writing TensorBoard scalars and checkpoints when
         ``cfg.log_to_file`` is set, and finally evaluates.

    NOTE(review): this file appears to define ``main`` a second time further
    down; if both are module-level, the later definition replaces this one.
    """
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    # Lazily create SummaryWriter
    writer = None

    TVT, TMO = set_devices(cfg.sys_device_ids)

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    ###########
    # Dataset #
    ###########

    train_set = create_dataset(**cfg.train_set_kwargs)

    test_sets = []
    test_set_names = []
    if cfg.dataset == 'combined':
        for name in ['market1501', 'cuhk03', 'duke']:
            cfg.test_set_kwargs['name'] = name
            test_sets.append(create_dataset(**cfg.test_set_kwargs))
            test_set_names.append(name)
    else:
        test_sets.append(create_dataset(**cfg.test_set_kwargs))
        test_set_names.append(cfg.dataset)

    ###########
    # Models  #
    ###########

    model = Model(local_conv_out_channels=cfg.local_conv_out_channels,
                  num_classes=len(train_set.ids2labels))
    # Model wrapper
    model_w = DataParallel(model)

    #############################
    # Criteria and Optimizers   #
    #############################

    id_criterion = SoftmaxEntropyLoss()
    g_tri_loss = TripletLoss(margin=cfg.global_margin)
    l_tri_loss = TripletLoss(margin=cfg.local_margin)

    optimizer = optim.Adam(model.parameters(),
                           lr=cfg.base_lr,
                           weight_decay=cfg.weight_decay)

    # Bind them together just to save some codes in the following usage.
    modules_optims = [model, optimizer]

    ################################
    # May Resume Models and Optims #
    ################################

    if cfg.resume:
        resume_ep, _ = load_ckpt(modules_optims, cfg.ckpt_file)

    # May Transfer Models and Optims to Specified Device. Transferring
    # optimizer is to cope with the case when you load the checkpoint to a
    # new device.
    TMO(modules_optims)

    ########
    # Test #
    ########

    def test(load_model_weight=False):
        """Evaluate on every test set, optionally loading weights first."""
        if load_model_weight:
            if cfg.model_weight_file != '':
                map_location = (lambda storage, loc: storage)
                sd = torch.load(cfg.model_weight_file,
                                map_location=map_location)
                load_state_dict(model, sd)
                print('Loaded model weights from {}'.format(
                    cfg.model_weight_file))
            else:
                load_ckpt(modules_optims, cfg.ckpt_file)

        use_local_distance = (cfg.l_loss_weight > 0) \
            and cfg.local_dist_own_hard_sample

        for test_set, name in zip(test_sets, test_set_names):
            test_set.set_feat_func(ExtractFeature(model_w, TVT))
            print('\n=========> Test on dataset: {} <=========\n'.format(name))
            test_set.eval(normalize_feat=cfg.normalize_feature,
                          use_local_distance=use_local_distance)

    if cfg.only_test:
        test(load_model_weight=True)
        return

    ############
    # Training #
    ############

    meter_names = (
        'g_prec', 'g_m', 'g_dist_ap', 'g_dist_an', 'g_loss',
        'l_prec', 'l_m', 'l_dist_ap', 'l_dist_an', 'l_loss',
        'id_loss', 'sift_loss', 'loss',
    )

    def format_log(prefix, meters, field):
        """Build one log line from ``meters``.

        ``field`` is the meter attribute to read: ``'val'`` for the step log,
        ``'avg'`` for the epoch log. Terms whose loss weight is not positive
        are omitted.
        """
        def read(key):
            return getattr(meters[key], field)

        pieces = [prefix]
        if cfg.g_loss_weight > 0:
            pieces.append(
                ', gp {:.2%}, gm {:.2%}, '
                'gd_ap {:.4f}, gd_an {:.4f}, '
                'gL {:.4f}'.format(
                    read('g_prec'), read('g_m'),
                    read('g_dist_ap'), read('g_dist_an'),
                    read('g_loss'),
                ))
        if cfg.l_loss_weight > 0:
            pieces.append(
                ', lp {:.2%}, lm {:.2%}, '
                'ld_ap {:.4f}, ld_an {:.4f}, '
                'lL {:.4f}'.format(
                    read('l_prec'), read('l_m'),
                    read('l_dist_ap'), read('l_dist_an'),
                    read('l_loss'),
                ))
        if cfg.id_loss_weight > 0:
            pieces.append(', idL {:.4f}'.format(read('id_loss')))
        if cfg.sift_loss_weight > 0:
            pieces.append(', sL {:.4f}'.format(read('sift_loss')))
        pieces.append(', loss {:.4f}'.format(read('loss')))
        return ''.join(pieces)

    # The SIFT extractor is only needed when the SIFT loss is enabled; build
    # it once instead of on every step.
    sift_func = ExtractSift() if cfg.sift_loss_weight > 0 else None

    start_ep = resume_ep if cfg.resume else 0
    for ep in range(start_ep, cfg.total_epochs):

        # Adjust Learning Rate
        if cfg.lr_decay_type == 'exp':
            adjust_lr_exp(optimizer,
                          cfg.base_lr,
                          ep + 1,
                          cfg.total_epochs,
                          cfg.exp_decay_at_epoch)
        else:
            adjust_lr_staircase(optimizer,
                                cfg.base_lr,
                                ep + 1,
                                cfg.staircase_decay_at_epochs,
                                cfg.staircase_decay_multiply_factor)

        may_set_mode(modules_optims, 'train')

        # Fresh meters for every epoch.
        meters = {name: AverageMeter() for name in meter_names}

        ep_st = time.time()
        step = 0
        epoch_done = False
        while not epoch_done:

            step += 1
            step_st = time.time()

            # Only images, identity labels and the end-of-epoch flag are used.
            ims, _, labels, _, _, epoch_done = train_set.next_batch()

            ims_var = Variable(TVT(torch.from_numpy(ims).float()))
            labels_t = TVT(torch.from_numpy(labels).long())
            labels_var = Variable(labels_t)

            _, global_feat, local_feat, logits = model_w(ims_var)

            g_loss, p_inds, n_inds, g_dist_ap, g_dist_an, _ = global_loss(
                g_tri_loss, global_feat, labels_t,
                normalize_feature=cfg.normalize_feature)

            if cfg.l_loss_weight == 0:
                l_loss = 0
            elif cfg.local_dist_own_hard_sample:
                # Let local distance find its own hard samples.
                l_loss, l_dist_ap, l_dist_an, _ = local_loss(
                    l_tri_loss, local_feat, None, None, labels_t,
                    normalize_feature=cfg.normalize_feature)
            else:
                l_loss, l_dist_ap, l_dist_an = local_loss(
                    l_tri_loss, local_feat, p_inds, n_inds, labels_t,
                    normalize_feature=cfg.normalize_feature)

            id_loss = 0
            if cfg.id_loss_weight > 0:
                id_loss = id_criterion(logits, labels_var)

            sift_loss = 0
            if cfg.sift_loss_weight > 0:
                # Extract SIFT descriptors only when they contribute to the
                # loss; the extraction result is unused otherwise.
                sift = torch.from_numpy(sift_func(ims_var)).cuda()
                sift_loss = torch.norm(
                    F.softmax(global_feat, dim=1) - F.softmax(sift, dim=1))

            loss = g_loss * cfg.g_loss_weight \
                + l_loss * cfg.l_loss_weight \
                + id_loss * cfg.id_loss_weight \
                + sift_loss * cfg.sift_loss_weight

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ############
            # Step Log #
            ############

            # precision
            g_prec = (g_dist_an > g_dist_ap).data.float().mean()
            # the proportion of triplets that satisfy margin
            g_m = (g_dist_an > g_dist_ap + cfg.global_margin) \
                .data.float().mean()

            meters['g_prec'].update(g_prec)
            meters['g_m'].update(g_m)
            meters['g_dist_ap'].update(g_dist_ap.data.mean())
            meters['g_dist_an'].update(g_dist_an.data.mean())
            meters['g_loss'].update(to_scalar(g_loss))

            if cfg.l_loss_weight > 0:
                # precision
                l_prec = (l_dist_an > l_dist_ap).data.float().mean()
                # the proportion of triplets that satisfy margin
                l_m = (l_dist_an > l_dist_ap + cfg.local_margin) \
                    .data.float().mean()

                meters['l_prec'].update(l_prec)
                meters['l_m'].update(l_m)
                meters['l_dist_ap'].update(l_dist_ap.data.mean())
                meters['l_dist_an'].update(l_dist_an.data.mean())
                meters['l_loss'].update(to_scalar(l_loss))

            if cfg.id_loss_weight > 0:
                meters['id_loss'].update(to_scalar(id_loss))

            if cfg.sift_loss_weight > 0:
                meters['sift_loss'].update(to_scalar(sift_loss))

            meters['loss'].update(to_scalar(loss))

            if step % cfg.log_steps == 0:
                time_log = '\tStep {}/Ep {}, {:.2f}s'.format(
                    step, ep + 1, time.time() - step_st, )
                print(format_log(time_log, meters, 'val'))

        #############
        # Epoch Log #
        #############

        time_log = 'Ep {}, {:.2f}s'.format(ep + 1, time.time() - ep_st, )
        print(format_log(time_log, meters, 'avg'))

        if cfg.log_to_file:
            # Log to TensorBoard
            if writer is None:
                writer = SummaryWriter(
                    log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
            writer.add_scalars(
                'loss',
                dict(
                    global_loss=meters['g_loss'].avg,
                    local_loss=meters['l_loss'].avg,
                    id_loss=meters['id_loss'].avg,
                    sift_loss=meters['sift_loss'].avg,
                    loss=meters['loss'].avg,
                ),
                ep)
            writer.add_scalars(
                'tri_precision',
                dict(
                    global_precision=meters['g_prec'].avg,
                    local_precision=meters['l_prec'].avg,
                ),
                ep)
            writer.add_scalars(
                'satisfy_margin',
                dict(
                    global_satisfy_margin=meters['g_m'].avg,
                    local_satisfy_margin=meters['l_m'].avg,
                ),
                ep)
            writer.add_scalars(
                'global_dist',
                dict(
                    global_dist_ap=meters['g_dist_ap'].avg,
                    global_dist_an=meters['g_dist_an'].avg,
                ),
                ep)
            writer.add_scalars(
                'local_dist',
                dict(
                    local_dist_ap=meters['l_dist_ap'].avg,
                    local_dist_an=meters['l_dist_an'].avg,
                ),
                ep)

            # save ckpt
            save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

    ########
    # Test #
    ########

    test(load_model_weight=False)
def main():
    """Train and/or evaluate the multi-loss re-identification model.

    Compared with a plain triplet/ID setup, this entry point also trains a
    camera-view classifier, a center loss with its own SGD optimizer, a part
    alignment loss and a SIFT consistency loss. Each term is weighted by the
    matching ``cfg.*_loss_weight`` and is skipped when that weight is not
    positive.

    The function, in order:
      1. builds the config, optionally redirects stdout/stderr, selects
         devices and seeds the RNGs;
      2. creates the training set and one or more test sets;
      3. builds the model, the losses and both optimizers, optionally
         resuming model/optimizer state from ``cfg.ckpt_file``;
      4. if ``cfg.only_test`` is set, evaluates and returns;
      5. otherwise trains for ``cfg.total_epochs`` epochs, logging per step
         and per epoch, writing TensorBoard scalars and checkpoints when
         ``cfg.log_to_file`` is set, and finally evaluates.

    NOTE(review): this file appears to contain an earlier ``main`` as well;
    if both are module-level, this later definition is the one that is used.
    """
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    # Lazily create SummaryWriter
    writer = None

    TVT, TMO = set_devices(cfg.sys_device_ids)

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    ###########
    # Dataset #
    ###########

    train_set = create_dataset(**cfg.train_set_kwargs)

    test_sets = []
    test_set_names = []
    if cfg.dataset == 'combined':
        for name in ['market1501', 'cuhk03', 'duke']:
            cfg.test_set_kwargs['name'] = name
            test_sets.append(create_dataset(**cfg.test_set_kwargs))
            test_set_names.append(name)
    else:
        test_sets.append(create_dataset(**cfg.test_set_kwargs))
        test_set_names.append(cfg.dataset)

    ###########
    # Models  #
    ###########

    # Number of camera classes for the view classifier.
    if cfg.dataset == 'market1501':
        cams = 6
    elif cfg.dataset == 'cuhk03':
        cams = 2
    else:
        cams = 8
    ids = len(train_set.ids2labels)

    model = Model(local_conv_out_channels=cfg.local_conv_out_channels,
                  num_classes=ids,
                  cam_classes=cams)
    # Model wrapper
    model_w = DataParallel(model)

    #############################
    # Criteria and Optimizers   #
    #############################

    view_criterion = SoftmaxEntropyLoss()
    id_criterion = SoftmaxEntropyLoss()
    g_tri_loss = TripletLoss(margin=cfg.global_margin)
    l_tri_loss = TripletLoss(margin=cfg.local_margin)
    center_loss = CenterLoss(num_classes=ids, feat_dim=2048, use_gpu=True)

    optimizer = optim.Adam(model.parameters(),
                           lr=cfg.base_lr,
                           weight_decay=cfg.weight_decay)
    optimizer_centloss = torch.optim.SGD(center_loss.parameters(), lr=0.001)

    # Bind them together just to save some codes in the following usage.
    # NOTE(review): center_loss / optimizer_centloss are not part of this
    # list, so they are not passed to load_ckpt / save_ckpt / TMO here —
    # confirm that losing the learned centers on resume is intended.
    modules_optims = [model, optimizer]

    ################################
    # May Resume Models and Optims #
    ################################

    if cfg.resume:
        resume_ep, _ = load_ckpt(modules_optims, cfg.ckpt_file)

    # May Transfer Models and Optims to Specified Device. Transferring
    # optimizer is to cope with the case when you load the checkpoint to a
    # new device.
    TMO(modules_optims)

    ########
    # Test #
    ########

    def test(load_model_weight=False):
        """Evaluate on every test set, optionally loading weights first."""
        if load_model_weight:
            if cfg.model_weight_file != '':
                map_location = (lambda storage, loc: storage)
                sd = torch.load(cfg.model_weight_file,
                                map_location=map_location)
                load_state_dict(model, sd)
                print('Loaded model weights from {}'.format(cfg.model_weight_file))
            else:
                load_ckpt(modules_optims, cfg.ckpt_file)

        use_local_distance = (cfg.l_loss_weight > 0) \
            and cfg.local_dist_own_hard_sample

        for test_set, name in zip(test_sets, test_set_names):
            test_set.set_feat_func(ExtractFeature(model_w, TVT))
            print('\n=========> Test on dataset: {} <=========\n'.format(name))
            test_set.eval(
                normalize_feat=cfg.normalize_feature,
                use_local_distance=use_local_distance)

    if cfg.only_test:
        test(load_model_weight=True)
        return

    ############
    # Training #
    ############

    meter_names = (
        'g_prec', 'g_m', 'g_dist_ap', 'g_dist_an', 'g_loss',
        'l_prec', 'l_m', 'l_dist_ap', 'l_dist_an', 'l_loss',
        'id_loss', 'sift_loss', 'c_loss', 'a_loss', 'view_loss', 'loss',
    )

    def smooth_one_hot(class_indices, num_classes):
        """Return label-smoothed targets on the GPU.

        ``class_indices`` is a 1-D numpy integer array. Each row of the
        result holds 0.8 at the labelled class and ``0.2 / (num_classes - 1)``
        everywhere else. The number of rows follows the length of
        ``class_indices`` rather than the configured batch size, so a batch
        of any size is handled.
        """
        index = torch.from_numpy(class_indices).long().view(-1, 1)
        off_value = 0.2 / (num_classes - 1)
        target = torch.full((index.size(0), num_classes), off_value)
        target.scatter_(1, index, 0.8)
        return target.cuda()

    def format_log(prefix, meters, field):
        """Build one log line from ``meters``.

        ``field`` is the meter attribute to read: ``'val'`` for the step log,
        ``'avg'`` for the epoch log. Terms whose loss weight is not positive
        are omitted.
        """
        def read(key):
            return getattr(meters[key], field)

        pieces = [prefix]
        if cfg.g_loss_weight > 0:
            pieces.append(
                ', gp {:.2%}, gm {:.2%}, '
                'gd_ap {:.4f}, gd_an {:.4f}, '
                'gL {:.4f}'.format(
                    read('g_prec'), read('g_m'),
                    read('g_dist_ap'), read('g_dist_an'),
                    read('g_loss'),
                ))
        if cfg.l_loss_weight > 0:
            pieces.append(
                ', lp {:.2%}, lm {:.2%}, '
                'ld_ap {:.4f}, ld_an {:.4f}, '
                'lL {:.4f}'.format(
                    read('l_prec'), read('l_m'),
                    read('l_dist_ap'), read('l_dist_an'),
                    read('l_loss'),
                ))
        # Single-value terms, in the order they appear in the log line.
        scalar_terms = (
            (cfg.id_loss_weight, ', idL {:.4f}', 'id_loss'),
            (cfg.align_loss_weight, ', aL {:.4f}', 'a_loss'),
            (cfg.sift_loss_weight, ', sL {:.4f}', 'sift_loss'),
            (cfg.c_loss_weight, ', cL {:.4f}', 'c_loss'),
            (cfg.view_loss_weight, ', viewL {:.4f}', 'view_loss'),
        )
        for weight, template, key in scalar_terms:
            if weight > 0:
                pieces.append(template.format(read(key)))
        pieces.append(', loss {:.4f}'.format(read('loss')))
        return ''.join(pieces)

    # The SIFT extractor is only needed when the SIFT loss is enabled; build
    # it once instead of on every step.
    sift_func = ExtractSift() if cfg.sift_loss_weight > 0 else None

    start_ep = resume_ep if cfg.resume else 0
    for ep in range(start_ep, cfg.total_epochs):

        # Adjust Learning Rate of both optimizers with the same schedule.
        # NOTE(review): the schedule is driven by cfg.base_lr for the center
        # loss optimizer too, not by the lr=0.001 it was constructed with —
        # confirm this is intended.
        if cfg.lr_decay_type == 'exp':
            for opt in (optimizer, optimizer_centloss):
                adjust_lr_exp(
                    opt,
                    cfg.base_lr,
                    ep + 1,
                    cfg.total_epochs,
                    cfg.exp_decay_at_epoch)
        else:
            for opt in (optimizer, optimizer_centloss):
                adjust_lr_staircase(
                    opt,
                    cfg.base_lr,
                    ep + 1,
                    cfg.staircase_decay_at_epochs,
                    cfg.staircase_decay_multiply_factor)

        may_set_mode(modules_optims, 'train')

        # Fresh meters for every epoch.
        meters = {name: AverageMeter() for name in meter_names}

        ep_st = time.time()
        step = 0
        epoch_done = False
        while not epoch_done:

            step += 1
            step_st = time.time()

            ims, _, labels, cam_labels, _, epoch_done = train_set.next_batch()

            ims_var = Variable(TVT(torch.from_numpy(ims).float()))
            labels_t = TVT(torch.from_numpy(labels).long())
            labels_var = Variable(labels_t)

            # The model returns nine outputs; only the ones feeding a loss
            # are kept.
            (_, feat_part1, feat_part2, global_feat, local_feat,
             _, _, logits1, view_logits) = model_w(ims_var)

            g_loss, p_inds, n_inds, g_dist_ap, g_dist_an, _ = global_loss(
                g_tri_loss, global_feat, labels_t,
                normalize_feature=cfg.normalize_feature)

            if cfg.l_loss_weight == 0:
                l_loss = 0
            elif cfg.local_dist_own_hard_sample:
                # Let local distance find its own hard samples.
                l_loss, l_dist_ap, l_dist_an, _ = local_loss(
                    l_tri_loss, local_feat, None, None, labels_t,
                    normalize_feature=cfg.normalize_feature)
            else:
                l_loss, l_dist_ap, l_dist_an = local_loss(
                    l_tri_loss, local_feat, p_inds, n_inds, labels_t,
                    normalize_feature=cfg.normalize_feature)

            id_loss = 0
            if cfg.id_loss_weight > 0:
                # Identity classification against label-smoothed targets.
                id_loss = id_criterion(logits1, smooth_one_hot(labels, ids))

            a_loss = 0
            if cfg.align_loss_weight > 0:
                # Channel-averaged maps of the two parts, flattened per sample.
                map1 = torch.mean(feat_part1, 1).view(feat_part1.size(0), -1)
                map2 = torch.mean(feat_part2, 1).view(feat_part2.size(0), -1)
                a_loss = torch.norm(torch.sigmoid(map1) - torch.sigmoid(map2))

            sift_loss = 0
            if cfg.sift_loss_weight > 0:
                # Extract SIFT descriptors only when they contribute to the
                # loss; the extraction result is unused otherwise.
                sift = torch.from_numpy(sift_func(ims_var)).cuda()
                sift_loss = torch.norm(
                    F.softmax(global_feat, dim=1) - F.softmax(sift, dim=1))

            c_loss = 0
            if cfg.c_loss_weight > 0:
                c_loss = center_loss(normalize(global_feat, axis=-1),
                                     labels_var)

            view_loss = 0
            if cfg.view_loss_weight > 0:
                # cuhk03 camera labels are used as they are; for the other
                # datasets they are shifted down by one before being used as
                # class indices (presumably they are 1-based — TODO confirm).
                if cfg.dataset == 'cuhk03':
                    cam_index = cam_labels
                else:
                    cam_index = cam_labels - 1
                view_loss = view_criterion(view_logits,
                                           smooth_one_hot(cam_index, cams))

            loss = g_loss * cfg.g_loss_weight \
                + l_loss * cfg.l_loss_weight \
                + id_loss * cfg.id_loss_weight \
                + sift_loss * cfg.sift_loss_weight \
                + c_loss * cfg.c_loss_weight \
                + view_loss * cfg.view_loss_weight \
                + a_loss * cfg.align_loss_weight

            optimizer.zero_grad()
            optimizer_centloss.zero_grad()
            loss.backward()
            optimizer.step()

            # Update the class centers only when the center loss took part in
            # the backward pass. Without this guard a zero weight leaves the
            # center gradients unset and the rescaling below fails.
            if cfg.c_loss_weight > 0:
                # Undo the loss weight so the centers are updated with the
                # unweighted gradient.
                for param in center_loss.parameters():
                    param.grad.data *= (1.0 / cfg.c_loss_weight)
                optimizer_centloss.step()

            ############
            # Step Log #
            ############

            # precision
            g_prec = (g_dist_an > g_dist_ap).data.float().mean()
            # the proportion of triplets that satisfy margin
            g_m = (g_dist_an > g_dist_ap + cfg.global_margin) \
                .data.float().mean()

            meters['g_prec'].update(g_prec)
            meters['g_m'].update(g_m)
            meters['g_dist_ap'].update(g_dist_ap.data.mean())
            meters['g_dist_an'].update(g_dist_an.data.mean())
            meters['g_loss'].update(to_scalar(g_loss))

            if cfg.l_loss_weight > 0:
                # precision
                l_prec = (l_dist_an > l_dist_ap).data.float().mean()
                # the proportion of triplets that satisfy margin
                l_m = (l_dist_an > l_dist_ap + cfg.local_margin) \
                    .data.float().mean()

                meters['l_prec'].update(l_prec)
                meters['l_m'].update(l_m)
                meters['l_dist_ap'].update(l_dist_ap.data.mean())
                meters['l_dist_an'].update(l_dist_an.data.mean())
                meters['l_loss'].update(to_scalar(l_loss))

            if cfg.id_loss_weight > 0:
                meters['id_loss'].update(to_scalar(id_loss))

            if cfg.sift_loss_weight > 0:
                meters['sift_loss'].update(to_scalar(sift_loss))

            if cfg.c_loss_weight > 0:
                meters['c_loss'].update(to_scalar(c_loss))

            if cfg.view_loss_weight > 0:
                meters['view_loss'].update(to_scalar(view_loss))

            if cfg.align_loss_weight > 0:
                meters['a_loss'].update(to_scalar(a_loss))

            meters['loss'].update(to_scalar(loss))

            if step % cfg.log_steps == 0:
                time_log = '\tStep {}/Ep {}, {:.2f}s'.format(
                    step, ep + 1, time.time() - step_st, )
                print(format_log(time_log, meters, 'val'))

        #############
        # Epoch Log #
        #############

        time_log = 'Ep {}, {:.2f}s'.format(ep + 1, time.time() - ep_st, )
        print(format_log(time_log, meters, 'avg'))

        if cfg.log_to_file:
            # Log to TensorBoard
            if writer is None:
                writer = SummaryWriter(
                    log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
            writer.add_scalars(
                'loss',
                dict(global_loss=meters['g_loss'].avg,
                     local_loss=meters['l_loss'].avg,
                     id_loss=meters['id_loss'].avg,
                     sift_loss=meters['sift_loss'].avg,
                     c_loss=meters['c_loss'].avg,
                     view_loss=meters['view_loss'].avg,
                     a_loss=meters['a_loss'].avg,
                     loss=meters['loss'].avg, ),
                ep)
            writer.add_scalars(
                'tri_precision',
                dict(global_precision=meters['g_prec'].avg,
                     local_precision=meters['l_prec'].avg, ),
                ep)
            writer.add_scalars(
                'satisfy_margin',
                dict(global_satisfy_margin=meters['g_m'].avg,
                     local_satisfy_margin=meters['l_m'].avg, ),
                ep)
            writer.add_scalars(
                'global_dist',
                dict(global_dist_ap=meters['g_dist_ap'].avg,
                     global_dist_an=meters['g_dist_an'].avg, ),
                ep)
            writer.add_scalars(
                'local_dist',
                dict(local_dist_ap=meters['l_dist_ap'].avg,
                     local_dist_an=meters['l_dist_an'].avg, ),
                ep)

            # save ckpt
            save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

    ########
    # Test #
    ########

    test(load_model_weight=False)