def main():
  cfg = Config()

  # Redirect logs to both console and file.
  if cfg.log_to_file:
    ReDirectSTD(cfg.stdout_file, 'stdout', False)
    ReDirectSTD(cfg.stderr_file, 'stderr', False)

  # Lazily create SummaryWriter
  writer = None

  TVT, TMO = set_devices(cfg.sys_device_ids)

  if cfg.seed is not None:
    set_seed(cfg.seed)

  # Dump the configurations to log.
  import pprint
  print('-' * 60)
  print('cfg.__dict__')
  pprint.pprint(cfg.__dict__)
  print('-' * 60)

  ###########
  # Dataset #
  ###########

  train_set = create_dataset(**cfg.train_set_kwargs)

  test_sets = []
  test_set_names = []
  if cfg.dataset == 'combined':
    for name in ['market1501', 'cuhk03', 'duke']:
      cfg.test_set_kwargs['name'] = name
      test_sets.append(create_dataset(**cfg.test_set_kwargs))
      test_set_names.append(name)
  else:
    test_sets.append(create_dataset(**cfg.test_set_kwargs))
    test_set_names.append(cfg.dataset)

  ##########
  # Models #
  ##########

  model = Model(local_conv_out_channels=cfg.local_conv_out_channels,
                num_classes=len(train_set.ids2labels))
  # Model wrapper
  model_w = DataParallel(model)

  ############################
  # Criteria and Optimizers  #
  ############################

  id_criterion = nn.CrossEntropyLoss()
  g_tri_loss = TripletLoss(margin=cfg.global_margin)
  l_tri_loss = TripletLoss(margin=cfg.local_margin)

  optimizer = optim.Adam(model.parameters(),
                         lr=cfg.base_lr,
                         weight_decay=cfg.weight_decay)

  # Bind them together just to save some code in the following usage.
  modules_optims = [model, optimizer]

  ################################
  # May Resume Models and Optims #
  ################################

  if cfg.resume:
    resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

  # May Transfer Models and Optims to Specified Device. Transferring optimizer
  # is to cope with the case when you load the checkpoint to a new device.
  TMO(modules_optims)

  ########
  # Test #
  ########

  def test(load_model_weight=False):
    if load_model_weight:
      if cfg.model_weight_file != '':
        map_location = (lambda storage, loc: storage)
        sd = torch.load(cfg.model_weight_file, map_location=map_location)
        load_state_dict(model, sd)
        print('Loaded model weights from {}'.format(cfg.model_weight_file))
      else:
        load_ckpt(modules_optims, cfg.ckpt_file)

    use_local_distance = (cfg.l_loss_weight > 0) \
                         and cfg.local_dist_own_hard_sample

    for test_set, name in zip(test_sets, test_set_names):
      test_set.set_feat_func(ExtractFeature(model_w, TVT))
      print('\n=========> Test on dataset: {} <=========\n'.format(name))
      test_set.eval(
        normalize_feat=cfg.normalize_feature,
        use_local_distance=use_local_distance)

  if cfg.only_test:
    test(load_model_weight=True)
    return

  ############
  # Training #
  ############

  start_ep = resume_ep if cfg.resume else 0
  for ep in range(start_ep, cfg.total_epochs):

    # Adjust Learning Rate
    if cfg.lr_decay_type == 'exp':
      adjust_lr_exp(
        optimizer,
        cfg.base_lr,
        ep + 1,
        cfg.total_epochs,
        cfg.exp_decay_at_epoch)
    else:
      adjust_lr_staircase(
        optimizer,
        cfg.base_lr,
        ep + 1,
        cfg.staircase_decay_at_epochs,
        cfg.staircase_decay_multiply_factor)

    may_set_mode(modules_optims, 'train')

    g_prec_meter = AverageMeter()
    g_m_meter = AverageMeter()
    g_dist_ap_meter = AverageMeter()
    g_dist_an_meter = AverageMeter()
    g_loss_meter = AverageMeter()

    l_prec_meter = AverageMeter()
    l_m_meter = AverageMeter()
    l_dist_ap_meter = AverageMeter()
    l_dist_an_meter = AverageMeter()
    l_loss_meter = AverageMeter()

    id_loss_meter = AverageMeter()

    loss_meter = AverageMeter()

    ep_st = time.time()
    step = 0
    epoch_done = False
    while not epoch_done:

      step += 1
      step_st = time.time()

      ims, im_names, labels, mirrored, epoch_done = train_set.next_batch()

      ims_var = Variable(TVT(torch.from_numpy(ims).float()))
      labels_t = TVT(torch.from_numpy(labels).long())
      labels_var = Variable(labels_t)

      global_feat, local_feat, logits = model_w(ims_var)

      g_loss, p_inds, n_inds, g_dist_ap, g_dist_an, g_dist_mat = global_loss(
        g_tri_loss, global_feat, labels_t,
        normalize_feature=cfg.normalize_feature)

      if cfg.l_loss_weight == 0:
        l_loss = 0
      elif cfg.local_dist_own_hard_sample:
        # Let local distance find its own hard samples.
        l_loss, l_dist_ap, l_dist_an, _ = local_loss(
          l_tri_loss, local_feat, None, None, labels_t,
          normalize_feature=cfg.normalize_feature)
      else:
        l_loss, l_dist_ap, l_dist_an = local_loss(
          l_tri_loss, local_feat, p_inds, n_inds, labels_t,
          normalize_feature=cfg.normalize_feature)

      id_loss = 0
      if cfg.id_loss_weight > 0:
        id_loss = id_criterion(logits, labels_var)

      loss = g_loss * cfg.g_loss_weight \
             + l_loss * cfg.l_loss_weight \
             + id_loss * cfg.id_loss_weight

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      ############
      # Step Log #
      ############

      # precision
      g_prec = (g_dist_an > g_dist_ap).data.float().mean()
      # the proportion of triplets that satisfy margin
      g_m = (g_dist_an > g_dist_ap + cfg.global_margin).data.float().mean()
      g_d_ap = g_dist_ap.data.mean()
      g_d_an = g_dist_an.data.mean()

      g_prec_meter.update(g_prec)
      g_m_meter.update(g_m)
      g_dist_ap_meter.update(g_d_ap)
      g_dist_an_meter.update(g_d_an)
      g_loss_meter.update(to_scalar(g_loss))

      if cfg.l_loss_weight > 0:
        # precision
        l_prec = (l_dist_an > l_dist_ap).data.float().mean()
        # the proportion of triplets that satisfy margin
        l_m = (l_dist_an > l_dist_ap + cfg.local_margin).data.float().mean()
        l_d_ap = l_dist_ap.data.mean()
        l_d_an = l_dist_an.data.mean()

        l_prec_meter.update(l_prec)
        l_m_meter.update(l_m)
        l_dist_ap_meter.update(l_d_ap)
        l_dist_an_meter.update(l_d_an)
        l_loss_meter.update(to_scalar(l_loss))

      if cfg.id_loss_weight > 0:
        id_loss_meter.update(to_scalar(id_loss))

      loss_meter.update(to_scalar(loss))

      if step % cfg.log_steps == 0:
        time_log = '\tStep {}/Ep {}, {:.2f}s'.format(
          step, ep + 1, time.time() - step_st)

        if cfg.g_loss_weight > 0:
          g_log = (', gp {:.2%}, gm {:.2%}, '
                   'gd_ap {:.4f}, gd_an {:.4f}, '
                   'gL {:.4f}'.format(
            g_prec_meter.val, g_m_meter.val,
            g_dist_ap_meter.val, g_dist_an_meter.val,
            g_loss_meter.val))
        else:
          g_log = ''

        if cfg.l_loss_weight > 0:
          l_log = (', lp {:.2%}, lm {:.2%}, '
                   'ld_ap {:.4f}, ld_an {:.4f}, '
                   'lL {:.4f}'.format(
            l_prec_meter.val, l_m_meter.val,
            l_dist_ap_meter.val, l_dist_an_meter.val,
            l_loss_meter.val))
        else:
          l_log = ''

        if cfg.id_loss_weight > 0:
          id_log = (', idL {:.4f}'.format(id_loss_meter.val))
        else:
          id_log = ''

        total_loss_log = ', loss {:.4f}'.format(loss_meter.val)

        log = time_log + \
              g_log + l_log + id_log + \
              total_loss_log
        print(log)

    #############
    # Epoch Log #
    #############

    time_log = 'Ep {}, {:.2f}s'.format(ep + 1, time.time() - ep_st)

    if cfg.g_loss_weight > 0:
      g_log = (', gp {:.2%}, gm {:.2%}, '
               'gd_ap {:.4f}, gd_an {:.4f}, '
               'gL {:.4f}'.format(
        g_prec_meter.avg, g_m_meter.avg,
        g_dist_ap_meter.avg, g_dist_an_meter.avg,
        g_loss_meter.avg))
    else:
      g_log = ''

    if cfg.l_loss_weight > 0:
      l_log = (', lp {:.2%}, lm {:.2%}, '
               'ld_ap {:.4f}, ld_an {:.4f}, '
               'lL {:.4f}'.format(
        l_prec_meter.avg, l_m_meter.avg,
        l_dist_ap_meter.avg, l_dist_an_meter.avg,
        l_loss_meter.avg))
    else:
      l_log = ''

    if cfg.id_loss_weight > 0:
      id_log = (', idL {:.4f}'.format(id_loss_meter.avg))
    else:
      id_log = ''

    total_loss_log = ', loss {:.4f}'.format(loss_meter.avg)

    log = time_log + \
          g_log + l_log + id_log + \
          total_loss_log
    print(log)

    # Log to TensorBoard
    if cfg.log_to_file:
      if writer is None:
        writer = SummaryWriter(log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
      writer.add_scalars(
        'loss',
        dict(global_loss=g_loss_meter.avg,
             local_loss=l_loss_meter.avg,
             id_loss=id_loss_meter.avg,
             loss=loss_meter.avg),
        ep)
      writer.add_scalars(
        'tri_precision',
        dict(global_precision=g_prec_meter.avg,
             local_precision=l_prec_meter.avg),
        ep)
      writer.add_scalars(
        'satisfy_margin',
        dict(global_satisfy_margin=g_m_meter.avg,
             local_satisfy_margin=l_m_meter.avg),
        ep)
      writer.add_scalars(
        'global_dist',
        dict(global_dist_ap=g_dist_ap_meter.avg,
             global_dist_an=g_dist_an_meter.avg),
        ep)
      writer.add_scalars(
        'local_dist',
        dict(local_dist_ap=l_dist_ap_meter.avg,
             local_dist_an=l_dist_an_meter.avg),
        ep)

    # save ckpt
    if cfg.log_to_file:
      save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

  ########
  # Test #
  ########

  test(load_model_weight=False)
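
# The epoch loop above dispatches to adjust_lr_exp or adjust_lr_staircase
# depending on cfg.lr_decay_type. Those helpers are defined elsewhere in the
# repo; the sketch below only illustrates the usual pattern such helpers
# follow (mutating optimizer.param_groups in place). The target decay factor
# of 0.001 and the exact interpolation are illustrative assumptions, not
# necessarily the repo's actual values.


def adjust_lr_exp_sketch(optimizer, base_lr, ep, total_ep, start_decay_at_ep):
  """Keep base_lr until start_decay_at_ep, then decay exponentially towards
  base_lr * 0.001 at the final epoch (assumed target factor)."""
  if ep < start_decay_at_ep:
    lr = base_lr
  else:
    lr = base_lr * (0.001 ** (float(ep + 1 - start_decay_at_ep)
                              / (total_ep + 1 - start_decay_at_ep)))
  for g in optimizer.param_groups:
    g['lr'] = lr


def adjust_lr_staircase_sketch(optimizer, base_lr, ep, decay_at_epochs, factor):
  """Multiply base_lr by `factor` once for every decay boundary already passed."""
  num_decays = sum(1 for boundary in decay_at_epochs if ep >= boundary)
  for g in optimizer.param_groups:
    g['lr'] = base_lr * (factor ** num_decays)
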
def main():
  cfg = Config()

  # Redirect logs to both console and file.
  if cfg.log_to_file:
    ReDirectSTD(cfg.stdout_file, 'stdout', False)
    ReDirectSTD(cfg.stderr_file, 'stderr', False)

  # Lazily create SummaryWriter
  writer = None

  TVTs, TMOs, relative_device_ids = set_devices_for_ml(cfg.sys_device_ids)

  if cfg.seed is not None:
    set_seed(cfg.seed)

  # Dump the configurations to log.
  import pprint
  print('-' * 60)
  print('cfg.__dict__')
  pprint.pprint(cfg.__dict__)
  print('-' * 60)

  ###########
  # Dataset #
  ###########

  train_set = create_dataset(**cfg.train_set_kwargs)

  test_sets = []
  test_set_names = []
  if cfg.dataset == 'combined':
    for name in ['market1501', 'cuhk03', 'duke']:
      cfg.test_set_kwargs['name'] = name
      test_sets.append(create_dataset(**cfg.test_set_kwargs))
      test_set_names.append(name)
  else:
    test_sets.append(create_dataset(**cfg.test_set_kwargs))
    test_set_names.append(cfg.dataset)

  ##########
  # Models #
  ##########

  models = [Model(local_conv_out_channels=cfg.local_conv_out_channels,
                  num_classes=len(train_set.ids2labels))
            for _ in range(cfg.num_models)]
  # Model wrappers
  model_ws = [DataParallel(models[i], device_ids=relative_device_ids[i])
              for i in range(cfg.num_models)]

  ############################
  # Criteria and Optimizers  #
  ############################

  id_criterion = nn.CrossEntropyLoss()
  g_tri_loss = TripletLoss(margin=cfg.global_margin)
  l_tri_loss = TripletLoss(margin=cfg.local_margin)

  optimizers = [optim.Adam(m.parameters(),
                           lr=cfg.base_lr,
                           weight_decay=cfg.weight_decay)
                for m in models]

  # Bind them together just to save some code in the following usage.
  modules_optims = models + optimizers

  ################################
  # May Resume Models and Optims #
  ################################

  if cfg.resume:
    resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

  # May Transfer Models and Optims to Specified Device. Transferring optimizers
  # is to cope with the case when you load the checkpoint to a new device.
  for TMO, model, optimizer in zip(TMOs, models, optimizers):
    TMO([model, optimizer])

  ########
  # Test #
  ########

  # Test each model using different distance settings.
  def test(load_model_weight=False):
    if load_model_weight:
      load_ckpt(modules_optims, cfg.ckpt_file)

    use_local_distance = (cfg.l_loss_weight > 0) \
                         and cfg.local_dist_own_hard_sample

    for i, (model_w, TVT) in enumerate(zip(model_ws, TVTs)):
      for test_set, name in zip(test_sets, test_set_names):
        test_set.set_feat_func(ExtractFeature(model_w, TVT))
        print('\n=========> Test Model #{} on dataset: {} <=========\n'
              .format(i + 1, name))
        test_set.eval(
          normalize_feat=cfg.normalize_feature,
          use_local_distance=use_local_distance)

  if cfg.only_test:
    test(load_model_weight=True)
    return

  ############
  # Training #
  ############

  # Storing things that can be accessed cross threads.
  ims_list = [None for _ in range(cfg.num_models)]
  labels_list = [None for _ in range(cfg.num_models)]
  done_list1 = [False for _ in range(cfg.num_models)]
  done_list2 = [False for _ in range(cfg.num_models)]
  probs_list = [None for _ in range(cfg.num_models)]
  g_dist_mat_list = [None for _ in range(cfg.num_models)]
  l_dist_mat_list = [None for _ in range(cfg.num_models)]

  # Two phases for each model:
  # 1) forward and single-model loss;
  # 2) further add mutual loss and backward.
  # The 2nd phase is only ready to start when the 1st is finished for
  # all models.
  run_event1 = threading.Event()
  run_event2 = threading.Event()

  # This event is meant to be set to stop threads. However, as I found, with
  # `daemon` set to true when creating threads, manually stopping is
  # unnecessary. I guess some main-thread variables required by sub-threads
  # are destroyed when the main thread ends, thus the sub-threads throw errors
  # and exit too.
  # The real reason should be further explored.
  exit_event = threading.Event()

  # The function to be called by threads.
  def thread_target(i):
    while not exit_event.isSet():
      # If the run event is not set, the thread just waits.
      if not run_event1.wait(0.001):
        continue

      ######################################
      # Phase 1: Forward and Separate Loss #
      ######################################

      TVT = TVTs[i]
      model_w = model_ws[i]
      ims = ims_list[i]
      labels = labels_list[i]
      optimizer = optimizers[i]

      ims_var = Variable(TVT(torch.from_numpy(ims).float()))
      labels_t = TVT(torch.from_numpy(labels).long())
      labels_var = Variable(labels_t)

      global_feat, local_feat, logits = model_w(ims_var)
      probs = F.softmax(logits)
      log_probs = F.log_softmax(logits)

      g_loss, p_inds, n_inds, g_dist_ap, g_dist_an, g_dist_mat = global_loss(
        g_tri_loss, global_feat, labels_t,
        normalize_feature=cfg.normalize_feature)

      if cfg.l_loss_weight == 0:
        l_loss, l_dist_mat = 0, 0
      elif cfg.local_dist_own_hard_sample:
        # Let local distance find its own hard samples.
        l_loss, l_dist_ap, l_dist_an, l_dist_mat = local_loss(
          l_tri_loss, local_feat, None, None, labels_t,
          normalize_feature=cfg.normalize_feature)
      else:
        l_loss, l_dist_ap, l_dist_an = local_loss(
          l_tri_loss, local_feat, p_inds, n_inds, labels_t,
          normalize_feature=cfg.normalize_feature)
        l_dist_mat = 0

      id_loss = 0
      if cfg.id_loss_weight > 0:
        id_loss = id_criterion(logits, labels_var)

      probs_list[i] = probs
      g_dist_mat_list[i] = g_dist_mat
      l_dist_mat_list[i] = l_dist_mat
      done_list1[i] = True

      # Wait for event to be set, meanwhile checking if need to exit.
      while True:
        phase2_ready = run_event2.wait(0.001)
        if exit_event.isSet():
          return
        if phase2_ready:
          break

      #####################################
      # Phase 2: Mutual Loss and Backward #
      #####################################

      # Probability Mutual Loss (KL Loss)
      pm_loss = 0
      if (cfg.num_models > 1) and (cfg.pm_loss_weight > 0):
        for j in range(cfg.num_models):
          if j != i:
            pm_loss += F.kl_div(log_probs, TVT(probs_list[j]).detach(), False)
        pm_loss /= 1. * (cfg.num_models - 1) * len(ims)

      # Global Distance Mutual Loss (L2 Loss)
      gdm_loss = 0
      if (cfg.num_models > 1) and (cfg.gdm_loss_weight > 0):
        for j in range(cfg.num_models):
          if j != i:
            gdm_loss += torch.sum(torch.pow(
              g_dist_mat - TVT(g_dist_mat_list[j]).detach(), 2))
        gdm_loss /= 1. * (cfg.num_models - 1) * len(ims) * len(ims)

      # Local Distance Mutual Loss (L2 Loss)
      ldm_loss = 0
      if (cfg.num_models > 1) \
          and cfg.local_dist_own_hard_sample \
          and (cfg.ldm_loss_weight > 0):
        for j in range(cfg.num_models):
          if j != i:
            ldm_loss += torch.sum(torch.pow(
              l_dist_mat - TVT(l_dist_mat_list[j]).detach(), 2))
        ldm_loss /= 1. * (cfg.num_models - 1) * len(ims) * len(ims)

      loss = g_loss * cfg.g_loss_weight \
             + l_loss * cfg.l_loss_weight \
             + id_loss * cfg.id_loss_weight \
             + pm_loss * cfg.pm_loss_weight \
             + gdm_loss * cfg.gdm_loss_weight \
             + ldm_loss * cfg.ldm_loss_weight

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      ##################################
      # Step Log For One of the Models #
      ##################################

      # These meters are outer-scope variables.
      # Just record for the first model.
      if i == 0:
        # precision
        g_prec = (g_dist_an > g_dist_ap).data.float().mean()
        # the proportion of triplets that satisfy margin
        g_m = (g_dist_an > g_dist_ap + cfg.global_margin).data.float().mean()
        g_d_ap = g_dist_ap.data.mean()
        g_d_an = g_dist_an.data.mean()

        g_prec_meter.update(g_prec)
        g_m_meter.update(g_m)
        g_dist_ap_meter.update(g_d_ap)
        g_dist_an_meter.update(g_d_an)
        g_loss_meter.update(to_scalar(g_loss))

        if cfg.l_loss_weight > 0:
          # precision
          l_prec = (l_dist_an > l_dist_ap).data.float().mean()
          # the proportion of triplets that satisfy margin
          l_m = (l_dist_an > l_dist_ap + cfg.local_margin).data.float().mean()
          l_d_ap = l_dist_ap.data.mean()
          l_d_an = l_dist_an.data.mean()

          l_prec_meter.update(l_prec)
          l_m_meter.update(l_m)
          l_dist_ap_meter.update(l_d_ap)
          l_dist_an_meter.update(l_d_an)
          l_loss_meter.update(to_scalar(l_loss))

        if cfg.id_loss_weight > 0:
          id_loss_meter.update(to_scalar(id_loss))

        if (cfg.num_models > 1) and (cfg.pm_loss_weight > 0):
          pm_loss_meter.update(to_scalar(pm_loss))
        if (cfg.num_models > 1) and (cfg.gdm_loss_weight > 0):
          gdm_loss_meter.update(to_scalar(gdm_loss))
        if (cfg.num_models > 1) \
            and cfg.local_dist_own_hard_sample \
            and (cfg.ldm_loss_weight > 0):
          ldm_loss_meter.update(to_scalar(ldm_loss))

        loss_meter.update(to_scalar(loss))

      ###################
      # End Up One Step #
      ###################

      run_event1.clear()
      run_event2.clear()
      done_list2[i] = True

  threads = []
  for i in range(cfg.num_models):
    thread = threading.Thread(target=thread_target, args=(i,))
    # Set the thread in daemon mode, so that the main program ends normally.
    thread.daemon = True
    thread.start()
    threads.append(thread)

  start_ep = resume_ep if cfg.resume else 0
  for ep in range(start_ep, cfg.total_epochs):

    # Adjust Learning Rate
    for optimizer in optimizers:
      if cfg.lr_decay_type == 'exp':
        adjust_lr_exp(
          optimizer,
          cfg.base_lr,
          ep + 1,
          cfg.total_epochs,
          cfg.exp_decay_at_epoch)
      else:
        adjust_lr_staircase(
          optimizer,
          cfg.base_lr,
          ep + 1,
          cfg.staircase_decay_at_epochs,
          cfg.staircase_decay_multiply_factor)

    may_set_mode(modules_optims, 'train')

    epoch_done = False

    g_prec_meter = AverageMeter()
    g_m_meter = AverageMeter()
    g_dist_ap_meter = AverageMeter()
    g_dist_an_meter = AverageMeter()
    g_loss_meter = AverageMeter()

    l_prec_meter = AverageMeter()
    l_m_meter = AverageMeter()
    l_dist_ap_meter = AverageMeter()
    l_dist_an_meter = AverageMeter()
    l_loss_meter = AverageMeter()

    id_loss_meter = AverageMeter()

    # Global Distance Mutual Loss
    gdm_loss_meter = AverageMeter()
    # Local Distance Mutual Loss
    ldm_loss_meter = AverageMeter()
    # Probability Mutual Loss
    pm_loss_meter = AverageMeter()

    loss_meter = AverageMeter()

    ep_st = time.time()
    step = 0
    while not epoch_done:

      step += 1
      step_st = time.time()

      ims, im_names, labels, mirrored, epoch_done = train_set.next_batch()

      for i in range(cfg.num_models):
        ims_list[i] = ims
        labels_list[i] = labels
        done_list1[i] = False
        done_list2[i] = False

      run_event1.set()
      # Waiting for phase 1 done
      while not all(done_list1):
        continue
      run_event2.set()
      # Waiting for phase 2 done
      while not all(done_list2):
        continue

      ############
      # Step Log #
      ############

      if step % cfg.log_steps == 0:
        time_log = '\tStep {}/Ep {}, {:.2f}s'.format(
          step, ep + 1, time.time() - step_st)

        if cfg.g_loss_weight > 0:
          g_log = (', gp {:.2%}, gm {:.2%}, '
                   'gd_ap {:.4f}, gd_an {:.4f}, '
                   'gL {:.4f}'.format(
            g_prec_meter.val, g_m_meter.val,
            g_dist_ap_meter.val, g_dist_an_meter.val,
            g_loss_meter.val))
        else:
          g_log = ''

        if cfg.l_loss_weight > 0:
          l_log = (', lp {:.2%}, lm {:.2%}, '
                   'ld_ap {:.4f}, ld_an {:.4f}, '
                   'lL {:.4f}'.format(
            l_prec_meter.val, l_m_meter.val,
            l_dist_ap_meter.val, l_dist_an_meter.val,
            l_loss_meter.val))
        else:
          l_log = ''

        if cfg.id_loss_weight > 0:
          id_log = (', idL {:.4f}'.format(id_loss_meter.val))
        else:
          id_log = ''

        if (cfg.num_models > 1) and (cfg.pm_loss_weight > 0):
          pm_log = (', pmL {:.4f}'.format(pm_loss_meter.val))
        else:
          pm_log = ''

        if (cfg.num_models > 1) and (cfg.gdm_loss_weight > 0):
          gdm_log = (', gdmL {:.4f}'.format(gdm_loss_meter.val))
        else:
          gdm_log = ''

        if (cfg.num_models > 1) \
            and cfg.local_dist_own_hard_sample \
            and (cfg.ldm_loss_weight > 0):
          ldm_log = (', ldmL {:.4f}'.format(ldm_loss_meter.val))
        else:
          ldm_log = ''

        total_loss_log = ', loss {:.4f}'.format(loss_meter.val)

        log = time_log + \
              g_log + l_log + id_log + \
              pm_log + gdm_log + ldm_log + \
              total_loss_log
        print(log)

    #############
    # Epoch Log #
    #############

    time_log = 'Ep {}, {:.2f}s'.format(ep + 1, time.time() - ep_st)

    if cfg.g_loss_weight > 0:
      g_log = (', gp {:.2%}, gm {:.2%}, '
               'gd_ap {:.4f}, gd_an {:.4f}, '
               'gL {:.4f}'.format(
        g_prec_meter.avg, g_m_meter.avg,
        g_dist_ap_meter.avg, g_dist_an_meter.avg,
        g_loss_meter.avg))
    else:
      g_log = ''

    if cfg.l_loss_weight > 0:
      l_log = (', lp {:.2%}, lm {:.2%}, '
               'ld_ap {:.4f}, ld_an {:.4f}, '
               'lL {:.4f}'.format(
        l_prec_meter.avg, l_m_meter.avg,
        l_dist_ap_meter.avg, l_dist_an_meter.avg,
        l_loss_meter.avg))
    else:
      l_log = ''

    if cfg.id_loss_weight > 0:
      id_log = (', idL {:.4f}'.format(id_loss_meter.avg))
    else:
      id_log = ''

    if (cfg.num_models > 1) and (cfg.pm_loss_weight > 0):
      pm_log = (', pmL {:.4f}'.format(pm_loss_meter.avg))
    else:
      pm_log = ''

    if (cfg.num_models > 1) and (cfg.gdm_loss_weight > 0):
      gdm_log = (', gdmL {:.4f}'.format(gdm_loss_meter.avg))
    else:
      gdm_log = ''

    if (cfg.num_models > 1) \
        and cfg.local_dist_own_hard_sample \
        and (cfg.ldm_loss_weight > 0):
      ldm_log = (', ldmL {:.4f}'.format(ldm_loss_meter.avg))
    else:
      ldm_log = ''

    total_loss_log = ', loss {:.4f}'.format(loss_meter.avg)

    log = time_log + \
          g_log + l_log + id_log + \
          pm_log + gdm_log + ldm_log + \
          total_loss_log
    print(log)

    # Log to TensorBoard
    if cfg.log_to_file:
      if writer is None:
        writer = SummaryWriter(log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
      writer.add_scalars(
        'loss',
        dict(global_loss=g_loss_meter.avg,
             local_loss=l_loss_meter.avg,
             id_loss=id_loss_meter.avg,
             pm_loss=pm_loss_meter.avg,
             gdm_loss=gdm_loss_meter.avg,
             ldm_loss=ldm_loss_meter.avg,
             loss=loss_meter.avg),
        ep)
      writer.add_scalars(
        'tri_precision',
        dict(global_precision=g_prec_meter.avg,
             local_precision=l_prec_meter.avg),
        ep)
      writer.add_scalars(
        'satisfy_margin',
        dict(global_satisfy_margin=g_m_meter.avg,
             local_satisfy_margin=l_m_meter.avg),
        ep)
      writer.add_scalars(
        'global_dist',
        dict(global_dist_ap=g_dist_ap_meter.avg,
             global_dist_an=g_dist_an_meter.avg),
        ep)
      writer.add_scalars(
        'local_dist',
        dict(local_dist_ap=l_dist_ap_meter.avg,
             local_dist_an=l_dist_an_meter.avg),
        ep)

    # save ckpt
    if cfg.log_to_file:
      save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

  ########
  # Test #
  ########

  test(load_model_weight=False)
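
# Note on the Probability Mutual Loss in Phase 2 above: with the summing
# (non-averaging) form of F.kl_div, each accumulated term is
#   sum_n sum_k  p_j[n, k] * (log p_j[n, k] - log p_i[n, k]),
# i.e. the batch-summed KL(peer || self). Dividing by
# (num_models - 1) * len(ims) therefore gives the KL divergence averaged over
# peers and over samples. A minimal standalone sketch of that computation,
# written against the newer reduction='sum' API (function and variable names
# here are illustrative only, not part of the repo):

import torch
import torch.nn.functional as F


def mean_peer_kl(own_logits, peer_probs_list):
  """Average KL(peer_probs || own_probs) over peer models and batch samples."""
  log_probs = F.log_softmax(own_logits, dim=1)
  n = own_logits.size(0)
  kl = sum(F.kl_div(log_probs, p.detach(), reduction='sum')
           for p in peer_probs_list)
  return kl / (len(peer_probs_list) * n)
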
def main():
  cfg = Config()

  # Redirect logs to both console and file.
  if cfg.log_to_file:
    ReDirectSTD(cfg.log_file, 'stdout', False)
    ReDirectSTD(cfg.log_err_file, 'stderr', False)

  TVT, TMO = set_devices(cfg.sys_device_ids)

  if cfg.seed is not None:
    set_seed(cfg.seed)

  # Dump the configurations to log.
  import pprint
  pprint.pprint(cfg.__dict__)

  if cfg.log_to_file:
    writer = SummaryWriter(log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
  else:
    writer = None

  ##########
  # Models #
  ##########

  model = Model(local_conv_out_channels=cfg.local_conv_out_channels)
  model_w = get_model_wrapper(model, len(cfg.sys_device_ids) > 1)

  ############################
  # Criteria and Optimizers  #
  ############################

  g_tri_loss = TripletLoss(margin=cfg.global_margin)
  l_tri_loss = TripletLoss(margin=cfg.local_margin)

  optimizer = optim.Adam(model.parameters(),
                         lr=cfg.lr,
                         weight_decay=cfg.weight_decay)

  modules_optims = [model, optimizer]

  ################################
  # May Resume Models and Optims #
  ################################

  if cfg.resume:
    resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

  # May Transfer Models and Optims to Specified Device
  TMO(modules_optims)

  ###########
  # Dataset #
  ###########

  def feature_func(ims):
    """A function to be called in the val/test set, to extract features."""
    # Set eval mode.
    # Force all BN layers to use global mean and variance, also disable
    # dropout.
    may_set_mode(modules_optims, 'eval')
    ims = Variable(TVT(torch.from_numpy(ims).float()))
    global_feat, local_feat = model_w(ims)
    global_feat = global_feat.data.cpu().numpy()
    local_feat = local_feat.data.cpu().numpy()
    return global_feat, local_feat

  train_set, val_set, test_set = None, None, None
  if not cfg.only_test:
    train_set = create_dataset(**cfg.train_set_kwargs)
    # val_set = create_dataset(**cfg.val_set_kwargs)
    # val_set.set_feat_func(feature_func)
  if cfg.only_test or cfg.test:
    test_set = create_dataset(**cfg.test_set_kwargs)
    test_set.set_feat_func(feature_func)

  ########
  # Test #
  ########

  if cfg.only_test:
    print('=====> Test')
    load_ckpt(modules_optims, cfg.ckpt_file)
    mAP, cmc_scores, mq_mAP, mq_cmc_scores = test_set.eval(
      normalize_feat=True,
      global_weight=cfg.g_test_weight,
      local_weight=cfg.l_test_weight)
    return

  ############
  # Training #
  ############

  best_score = scores if cfg.resume else 0
  start_ep = resume_ep if cfg.resume else 0
  for ep in range(start_ep, cfg.num_epochs):
    adjust_lr(optimizer, cfg.lr, ep, cfg.num_epochs, cfg.start_decay_epoch)
    may_set_mode(modules_optims, 'train')

    epoch_done = False

    g_prec_meter = AverageMeter()
    g_m_meter = AverageMeter()
    g_dist_ap_meter = AverageMeter()
    g_dist_an_meter = AverageMeter()
    g_loss_meter = AverageMeter()

    l_prec_meter = AverageMeter()
    l_m_meter = AverageMeter()
    l_dist_ap_meter = AverageMeter()
    l_dist_an_meter = AverageMeter()
    l_loss_meter = AverageMeter()

    loss_meter = AverageMeter()

    ep_st = time.time()
    step = 0
    while not epoch_done:

      step += 1
      step_st = time.time()

      ims, im_names, labels, mirrored, epoch_done = train_set.next_batch()

      ims_var = Variable(TVT(torch.from_numpy(ims).float()))
      labels_t = TVT(torch.from_numpy(labels).long())

      global_feat, local_feat = model_w(ims_var)

      g_loss, p_inds, n_inds, g_dist_ap, g_dist_an = global_loss(
        g_tri_loss, global_feat, labels_t)

      if cfg.l_loss_weight == 0:
        l_loss, l_prec, l_m = 0, 0, 0
      else:
        l_loss, l_dist_ap, l_dist_an = local_loss(
          l_tri_loss, local_feat, p_inds, n_inds, labels_t)
        # Let local distance find its own hard samples.
        # l_loss, l_dist_ap, l_dist_an = local_loss(
        #   l_tri_loss, local_feat, None, None, labels_t)

      loss = g_loss * cfg.g_loss_weight + l_loss * cfg.l_loss_weight

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      # Step logs

      # precision
      g_prec = (g_dist_an > g_dist_ap).data.float().mean()
      # the proportion of triplets that satisfy margin
      g_m = (g_dist_an > g_dist_ap + cfg.global_margin).data.float().mean()
      g_d_ap = g_dist_ap.data.mean()
      g_d_an = g_dist_an.data.mean()

      g_prec_meter.update(g_prec)
      g_m_meter.update(g_m)
      g_dist_ap_meter.update(g_d_ap)
      g_dist_an_meter.update(g_d_an)
      g_loss_meter.update(to_scalar(g_loss))

      if cfg.l_loss_weight > 0:
        # precision
        l_prec = (l_dist_an > l_dist_ap).data.float().mean()
        # the proportion of triplets that satisfy margin
        l_m = (l_dist_an > l_dist_ap + cfg.local_margin).data.float().mean()
        l_d_ap = l_dist_ap.data.mean()
        l_d_an = l_dist_an.data.mean()

        l_prec_meter.update(l_prec)
        l_m_meter.update(l_m)
        l_dist_ap_meter.update(l_d_ap)
        l_dist_an_meter.update(l_d_an)
        l_loss_meter.update(to_scalar(l_loss))

      loss_meter.update(to_scalar(loss))

      if step % cfg.log_steps == 0:
        print('\tStep {}/Ep {}, {:.2f}s, '
              'gp {:.4f}, gm {:.4f}, gd_ap {:.4f}, gd_an {:.4f}, g_loss {:.4f}, '
              'lp {:.4f}, lm {:.4f}, ld_ap {:.4f}, ld_an {:.4f}, l_loss {:.4f}, '
              'loss: {:.4f}'.format(
          step, ep + 1, time.time() - step_st,
          g_prec_meter.val, g_m_meter.val,
          g_dist_ap_meter.val, g_dist_an_meter.val, g_loss_meter.val,
          l_prec_meter.val, l_m_meter.val,
          l_dist_ap_meter.val, l_dist_an_meter.val, l_loss_meter.val,
          loss_meter.val))

    # Epoch logs
    print('Ep {}, {:.2f}s, '
          'gp {:.4f}, gm {:.4f}, gd_ap {:.4f}, gd_an {:.4f}, g_loss {:.4f}, '
          'lp {:.4f}, lm {:.4f}, ld_ap {:.4f}, ld_an {:.4f}, l_loss {:.4f}, '
          'loss: {:.4f}'.format(
      ep + 1, time.time() - ep_st,
      g_prec_meter.avg, g_m_meter.avg,
      g_dist_ap_meter.avg, g_dist_an_meter.avg, g_loss_meter.avg,
      l_prec_meter.avg, l_m_meter.avg,
      l_dist_ap_meter.avg, l_dist_an_meter.avg, l_loss_meter.avg,
      loss_meter.avg))

    if cfg.log_to_file:
      writer.add_scalars(
        'loss',
        dict(global_loss=g_loss_meter.avg,
             local_loss=l_loss_meter.avg,
             loss=loss_meter.avg),
        ep)
      writer.add_scalars(
        'tri_precision',
        dict(global_precision=g_prec_meter.avg,
             local_precision=l_prec_meter.avg),
        ep)
      writer.add_scalars(
        'satisfy_margin',
        dict(global_proportion=g_m_meter.avg,
             local_proportion=l_m_meter.avg),
        ep)
      writer.add_scalars(
        'global_dist',
        dict(global_dist_ap=g_dist_ap_meter.avg,
             global_dist_an=g_dist_an_meter.avg),
        ep)
      writer.add_scalars(
        'local_dist',
        dict(local_dist_ap=l_dist_ap_meter.avg,
             local_dist_an=l_dist_an_meter.avg),
        ep)

    mAP = 0
    # print('=====> Validation')
    # mAP, cmc_scores, mq_mAP, mq_cmc_scores = val_set.eval(
    #   normalize_feat=True,
    #   global_weight=cfg.g_test_weight,
    #   local_weight=cfg.l_test_weight)

    # save ckpt
    if cfg.save_ckpt:
      save_ckpt(modules_optims, ep + 1, mAP, cfg.ckpt_file)
    # if mAP > best_score:
    #   best_score = mAP
    #   shutil.copy(cfg.ckpt_file, cfg.best_ckpt_file)

  ########
  # Test #
  ########

  if cfg.test:
    print('=====> Test')
    mAP, cmc_scores, mq_mAP, mq_cmc_scores = test_set.eval(
      normalize_feat=True,
      global_weight=cfg.g_test_weight,
      local_weight=cfg.l_test_weight)
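
# All of the scripts above construct TripletLoss(margin=...) and feed it
# anchor-positive / anchor-negative distances through global_loss and
# local_loss. The class itself is defined elsewhere in the repo; the sketch
# below only shows the standard margin-based hinge formulation that such a
# criterion typically computes (an assumption for illustration, not
# necessarily the repo's exact implementation).

import torch


def triplet_hinge_loss(dist_ap, dist_an, margin):
  """Mean hinge loss max(0, d_ap - d_an + margin) over a batch of triplets."""
  return torch.clamp(dist_ap - dist_an + margin, min=0).mean()

# With this formulation, the logged quantities above have a direct reading:
# "precision" is the fraction of triplets with d_an > d_ap, and
# "satisfy margin" is the fraction with d_an > d_ap + margin (zero loss).
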