def main():
  cfg = Config()

  # Redirect logs to both console and file.
  if cfg.log_to_file:
    ReDirectSTD(cfg.stdout_file, 'stdout', False)
    ReDirectSTD(cfg.stderr_file, 'stderr', False)

  # Lazily create SummaryWriter
  writer = None

  TVT, TMO = set_devices(cfg.sys_device_ids)
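  # TVT presumably transfers Tensors/Variables to the chosen device, and TMO
  # transfers modules/optimizers (judging from how they are used below).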

  if cfg.seed is not None:
    set_seed(cfg.seed)

  # Dump the configuration to the log.
  import pprint
  print('-' * 60)
  print('cfg.__dict__')
  pprint.pprint(cfg.__dict__)
  print('-' * 60)

  ###########
  # Dataset #
  ###########

  train_set = create_dataset(**cfg.train_set_kwargs)

  test_sets = []
  test_set_names = []
  if cfg.dataset == 'combined':
    for name in ['market1501', 'cuhk03', 'duke']:
      cfg.test_set_kwargs['name'] = name
      test_sets.append(create_dataset(**cfg.test_set_kwargs))
      test_set_names.append(name)
  else:
    test_sets.append(create_dataset(**cfg.test_set_kwargs))
    test_set_names.append(cfg.dataset)

  ###########
  # Models  #
  ###########

  model = Model(local_conv_out_channels=cfg.local_conv_out_channels,
                num_classes=len(train_set.ids2labels))
  # Model wrapper
  model_w = DataParallel(model)

  #############################
  # Criteria and Optimizers   #
  #############################

  id_criterion = nn.CrossEntropyLoss()
  g_tri_loss = TripletLoss(margin=cfg.global_margin)
  l_tri_loss = TripletLoss(margin=cfg.local_margin)

  optimizer = optim.Adam(model.parameters(),
                         lr=cfg.base_lr,
                         weight_decay=cfg.weight_decay)

  # Bind them together just to save some code in the following usage.
  modules_optims = [model, optimizer]

  ################################
  # May Resume Models and Optims #
  ################################

  if cfg.resume:
    resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

  # May Transfer Models and Optims to Specified Device. Transferring the
  # optimizer as well copes with the case where a checkpoint is loaded onto
  # a new device.
  TMO(modules_optims)
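
  # A minimal sketch of what transferring an optimizer could involve
  # (hypothetical helper, never called; not this repo's API). Adam keeps
  # per-parameter state tensors (exp_avg, exp_avg_sq) that must live on the
  # same device as the parameters, so a checkpoint loaded on CPU needs its
  # optimizer state moved too. Assumes PyTorch >= 0.4 `tensor.to()`.
  def _example_transfer_optim_state(optimizer, device):
    # Move every tensor held in the optimizer's per-parameter state.
    for state in optimizer.state.values():
      for k, v in state.items():
        if torch.is_tensor(v):
          state[k] = v.to(device)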

  ########
  # Test #
  ########

  def test(load_model_weight=False):
    if load_model_weight:
      if cfg.model_weight_file != '':
        map_location = (lambda storage, loc: storage)
        sd = torch.load(cfg.model_weight_file, map_location=map_location)
        load_state_dict(model, sd)
        print('Loaded model weights from {}'.format(cfg.model_weight_file))
      else:
        load_ckpt(modules_optims, cfg.ckpt_file)

    use_local_distance = (cfg.l_loss_weight > 0) \
                         and cfg.local_dist_own_hard_sample

    for test_set, name in zip(test_sets, test_set_names):
      test_set.set_feat_func(ExtractFeature(model_w, TVT))
      print('\n=========> Test on dataset: {} <=========\n'.format(name))
      test_set.eval(
        normalize_feat=cfg.normalize_feature,
        use_local_distance=use_local_distance)

  if cfg.only_test:
    test(load_model_weight=True)
    return

  ############
  # Training #
  ############

  start_ep = resume_ep if cfg.resume else 0
  for ep in range(start_ep, cfg.total_epochs):

    # Adjust Learning Rate
    if cfg.lr_decay_type == 'exp':
      adjust_lr_exp(
        optimizer,
        cfg.base_lr,
        ep + 1,
        cfg.total_epochs,
        cfg.exp_decay_at_epoch)
    else:
      adjust_lr_staircase(
        optimizer,
        cfg.base_lr,
        ep + 1,
        cfg.staircase_decay_at_epochs,
        cfg.staircase_decay_multiply_factor)

    may_set_mode(modules_optims, 'train')

    g_prec_meter = AverageMeter()
    g_m_meter = AverageMeter()
    g_dist_ap_meter = AverageMeter()
    g_dist_an_meter = AverageMeter()
    g_loss_meter = AverageMeter()

    l_prec_meter = AverageMeter()
    l_m_meter = AverageMeter()
    l_dist_ap_meter = AverageMeter()
    l_dist_an_meter = AverageMeter()
    l_loss_meter = AverageMeter()

    id_loss_meter = AverageMeter()

    loss_meter = AverageMeter()

    ep_st = time.time()
    step = 0
    epoch_done = False
    while not epoch_done:

      step += 1
      step_st = time.time()

      ims, im_names, labels, mirrored, epoch_done = train_set.next_batch()

      ims_var = Variable(TVT(torch.from_numpy(ims).float()))
      labels_t = TVT(torch.from_numpy(labels).long())
      labels_var = Variable(labels_t)

      global_feat, local_feat, logits = model_w(ims_var)

      g_loss, p_inds, n_inds, g_dist_ap, g_dist_an, g_dist_mat = global_loss(
        g_tri_loss, global_feat, labels_t,
        normalize_feature=cfg.normalize_feature)
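      # global_loss presumably performs batch-hard mining: p_inds / n_inds
      # index the hardest positive / negative chosen for each anchor,
      # g_dist_ap / g_dist_an are the corresponding distances, and
      # g_dist_mat is the full pairwise distance matrix.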

      if cfg.l_loss_weight == 0:
        l_loss = 0
      elif cfg.local_dist_own_hard_sample:
        # Let local distance find its own hard samples.
        l_loss, l_dist_ap, l_dist_an, _ = local_loss(
          l_tri_loss, local_feat, None, None, labels_t,
          normalize_feature=cfg.normalize_feature)
      else:
        l_loss, l_dist_ap, l_dist_an = local_loss(
          l_tri_loss, local_feat, p_inds, n_inds, labels_t,
          normalize_feature=cfg.normalize_feature)

      id_loss = 0
      if cfg.id_loss_weight > 0:
        id_loss = id_criterion(logits, labels_var)

      loss = g_loss * cfg.g_loss_weight \
             + l_loss * cfg.l_loss_weight \
             + id_loss * cfg.id_loss_weight

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      ############
      # Step Log #
      ############

      # precision
      g_prec = (g_dist_an > g_dist_ap).data.float().mean()
      # the proportion of triplets that satisfy margin
      g_m = (g_dist_an > g_dist_ap + cfg.global_margin).data.float().mean()
      g_d_ap = g_dist_ap.data.mean()
      g_d_an = g_dist_an.data.mean()

      g_prec_meter.update(g_prec)
      g_m_meter.update(g_m)
      g_dist_ap_meter.update(g_d_ap)
      g_dist_an_meter.update(g_d_an)
      g_loss_meter.update(to_scalar(g_loss))

      if cfg.l_loss_weight > 0:
        # precision
        l_prec = (l_dist_an > l_dist_ap).data.float().mean()
        # the proportion of triplets that satisfy margin
        l_m = (l_dist_an > l_dist_ap + cfg.local_margin).data.float().mean()
        l_d_ap = l_dist_ap.data.mean()
        l_d_an = l_dist_an.data.mean()

        l_prec_meter.update(l_prec)
        l_m_meter.update(l_m)
        l_dist_ap_meter.update(l_d_ap)
        l_dist_an_meter.update(l_d_an)
        l_loss_meter.update(to_scalar(l_loss))

      if cfg.id_loss_weight > 0:
        id_loss_meter.update(to_scalar(id_loss))

      loss_meter.update(to_scalar(loss))

      if step % cfg.log_steps == 0:
        time_log = '\tStep {}/Ep {}, {:.2f}s'.format(
          step, ep + 1, time.time() - step_st, )

        if cfg.g_loss_weight > 0:
          g_log = (', gp {:.2%}, gm {:.2%}, '
                   'gd_ap {:.4f}, gd_an {:.4f}, '
                   'gL {:.4f}'.format(
            g_prec_meter.val, g_m_meter.val,
            g_dist_ap_meter.val, g_dist_an_meter.val,
            g_loss_meter.val, ))
        else:
          g_log = ''

        if cfg.l_loss_weight > 0:
          l_log = (', lp {:.2%}, lm {:.2%}, '
                   'ld_ap {:.4f}, ld_an {:.4f}, '
                   'lL {:.4f}'.format(
            l_prec_meter.val, l_m_meter.val,
            l_dist_ap_meter.val, l_dist_an_meter.val,
            l_loss_meter.val, ))
        else:
          l_log = ''

        if cfg.id_loss_weight > 0:
          id_log = (', idL {:.4f}'.format(id_loss_meter.val))
        else:
          id_log = ''

        total_loss_log = ', loss {:.4f}'.format(loss_meter.val)

        log = time_log + \
              g_log + l_log + id_log + \
              total_loss_log
        print(log)

    #############
    # Epoch Log #
    #############

    time_log = 'Ep {}, {:.2f}s'.format(ep + 1, time.time() - ep_st, )

    if cfg.g_loss_weight > 0:
      g_log = (', gp {:.2%}, gm {:.2%}, '
               'gd_ap {:.4f}, gd_an {:.4f}, '
               'gL {:.4f}'.format(
        g_prec_meter.avg, g_m_meter.avg,
        g_dist_ap_meter.avg, g_dist_an_meter.avg,
        g_loss_meter.avg, ))
    else:
      g_log = ''

    if cfg.l_loss_weight > 0:
      l_log = (', lp {:.2%}, lm {:.2%}, '
               'ld_ap {:.4f}, ld_an {:.4f}, '
               'lL {:.4f}'.format(
        l_prec_meter.avg, l_m_meter.avg,
        l_dist_ap_meter.avg, l_dist_an_meter.avg,
        l_loss_meter.avg, ))
    else:
      l_log = ''

    if cfg.id_loss_weight > 0:
      id_log = (', idL {:.4f}'.format(id_loss_meter.avg))
    else:
      id_log = ''

    total_loss_log = ', loss {:.4f}'.format(loss_meter.avg)

    log = time_log + \
          g_log + l_log + id_log + \
          total_loss_log
    print(log)

    # Log to TensorBoard

    if cfg.log_to_file:
      if writer is None:
        writer = SummaryWriter(log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
      writer.add_scalars(
        'loss',
        dict(global_loss=g_loss_meter.avg,
             local_loss=l_loss_meter.avg,
             id_loss=id_loss_meter.avg,
             loss=loss_meter.avg, ),
        ep)
      writer.add_scalars(
        'tri_precision',
        dict(global_precision=g_prec_meter.avg,
             local_precision=l_prec_meter.avg, ),
        ep)
      writer.add_scalars(
        'satisfy_margin',
        dict(global_satisfy_margin=g_m_meter.avg,
             local_satisfy_margin=l_m_meter.avg, ),
        ep)
      writer.add_scalars(
        'global_dist',
        dict(global_dist_ap=g_dist_ap_meter.avg,
             global_dist_an=g_dist_an_meter.avg, ),
        ep)
      writer.add_scalars(
        'local_dist',
        dict(local_dist_ap=l_dist_ap_meter.avg,
             local_dist_an=l_dist_an_meter.avg, ),
        ep)

    # save ckpt
    if cfg.log_to_file:
      save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

  ########
  # Test #
  ########

  test(load_model_weight=False)
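

# A minimal sketch of the kind of schedule `adjust_lr_exp` could implement
# (hypothetical, for illustration only; the real helper lives elsewhere in
# this repo). The learning rate stays at base_lr until decay starts, then
# shrinks exponentially to ~0.001 * base_lr at the final epoch.
def _example_exp_decayed_lr(base_lr, ep, total_ep, start_decay_at_ep):
  """`ep` is 1-based, matching the `ep + 1` passed by the loops above."""
  if ep < start_decay_at_ep:
    return base_lr
  return base_lr * (0.001 ** (float(ep + 1 - start_decay_at_ep)
                              / (total_ep + 1 - start_decay_at_ep)))
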
def main():
  cfg = Config()

  # Redirect logs to both console and file.
  if cfg.log_to_file:
    ReDirectSTD(cfg.stdout_file, 'stdout', False)
    ReDirectSTD(cfg.stderr_file, 'stderr', False)

  # Lazily create SummaryWriter
  writer = None

  TVTs, TMOs, relative_device_ids = set_devices_for_ml(cfg.sys_device_ids)

  if cfg.seed is not None:
    set_seed(cfg.seed)

  # Dump the configuration to the log.
  import pprint
  print('-' * 60)
  print('cfg.__dict__')
  pprint.pprint(cfg.__dict__)
  print('-' * 60)

  ###########
  # Dataset #
  ###########

  train_set = create_dataset(**cfg.train_set_kwargs)

  test_sets = []
  test_set_names = []
  if cfg.dataset == 'combined':
    for name in ['market1501', 'cuhk03', 'duke']:
      cfg.test_set_kwargs['name'] = name
      test_sets.append(create_dataset(**cfg.test_set_kwargs))
      test_set_names.append(name)
  else:
    test_sets.append(create_dataset(**cfg.test_set_kwargs))
    test_set_names.append(cfg.dataset)

  ###########
  # Models  #
  ###########

  models = [Model(local_conv_out_channels=cfg.local_conv_out_channels,
                  num_classes=len(train_set.ids2labels))
            for _ in range(cfg.num_models)]
  # Model wrappers
  model_ws = [DataParallel(models[i], device_ids=relative_device_ids[i])
              for i in range(cfg.num_models)]

  #############################
  # Criteria and Optimizers   #
  #############################

  id_criterion = nn.CrossEntropyLoss()
  g_tri_loss = TripletLoss(margin=cfg.global_margin)
  l_tri_loss = TripletLoss(margin=cfg.local_margin)

  optimizers = [optim.Adam(m.parameters(),
                           lr=cfg.base_lr,
                           weight_decay=cfg.weight_decay)
                for m in models]

  # Bind them together just to save some code in the following usage.
  modules_optims = models + optimizers

  ################################
  # May Resume Models and Optims #
  ################################

  if cfg.resume:
    resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

  # May Transfer Models and Optims to Specified Device. Transferring the
  # optimizers as well copes with the case where a checkpoint is loaded onto
  # a new device.
  for TMO, model, optimizer in zip(TMOs, models, optimizers):
    TMO([model, optimizer])

  ########
  # Test #
  ########

  # Test each model using different distance settings.
  def test(load_model_weight=False):
    if load_model_weight:
      load_ckpt(modules_optims, cfg.ckpt_file)

    use_local_distance = (cfg.l_loss_weight > 0) \
                         and cfg.local_dist_own_hard_sample

    for i, (model_w, TVT) in enumerate(zip(model_ws, TVTs)):
      for test_set, name in zip(test_sets, test_set_names):
        test_set.set_feat_func(ExtractFeature(model_w, TVT))
        print('\n=========> Test Model #{} on dataset: {} <=========\n'
              .format(i + 1, name))
        test_set.eval(
          normalize_feat=cfg.normalize_feature,
          use_local_distance=use_local_distance)

  if cfg.only_test:
    test(load_model_weight=True)
    return

  ############
  # Training #
  ############

  # Objects that are shared across threads.

  ims_list = [None for _ in range(cfg.num_models)]
  labels_list = [None for _ in range(cfg.num_models)]

  done_list1 = [False for _ in range(cfg.num_models)]
  done_list2 = [False for _ in range(cfg.num_models)]

  probs_list = [None for _ in range(cfg.num_models)]
  g_dist_mat_list = [None for _ in range(cfg.num_models)]
  l_dist_mat_list = [None for _ in range(cfg.num_models)]

  # Two phases for each model:
  # 1) forward and single-model loss;
  # 2) further add mutual loss and backward.
  # The 2nd phase is only ready to start when the 1st is finished for
  # all models.
  run_event1 = threading.Event()
  run_event2 = threading.Event()

  # This event is meant to be set to stop the threads. In practice it is
  # unnecessary: the threads are created with `daemon` set to true, so the
  # interpreter kills them automatically when the main thread exits.
  exit_event = threading.Event()
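
  # For reference: this two-phase rendezvous could also be expressed with a
  # threading.Barrier shared by all workers plus the main thread (a sketch,
  # not used here):
  #
  #   phase_barrier = threading.Barrier(cfg.num_models + 1)
  #   # every party calls phase_barrier.wait() once per phase; the call
  #   # returns only after all cfg.num_models + 1 parties have arrived.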

  # The function to be called by threads.
  def thread_target(i):
    while not exit_event.is_set():
      # If the run event is not set, the thread just waits.
      if not run_event1.wait(0.001): continue

      ######################################
      # Phase 1: Forward and Separate Loss #
      ######################################

      TVT = TVTs[i]
      model_w = model_ws[i]
      ims = ims_list[i]
      labels = labels_list[i]
      optimizer = optimizers[i]

      ims_var = Variable(TVT(torch.from_numpy(ims).float()))
      labels_t = TVT(torch.from_numpy(labels).long())
      labels_var = Variable(labels_t)

      global_feat, local_feat, logits = model_w(ims_var)
      probs = F.softmax(logits, dim=1)
      log_probs = F.log_softmax(logits, dim=1)

      g_loss, p_inds, n_inds, g_dist_ap, g_dist_an, g_dist_mat = global_loss(
        g_tri_loss, global_feat, labels_t,
        normalize_feature=cfg.normalize_feature)

      if cfg.l_loss_weight == 0:
        l_loss, l_dist_mat = 0, 0
      elif cfg.local_dist_own_hard_sample:
        # Let local distance find its own hard samples.
        l_loss, l_dist_ap, l_dist_an, l_dist_mat = local_loss(
          l_tri_loss, local_feat, None, None, labels_t,
          normalize_feature=cfg.normalize_feature)
      else:
        l_loss, l_dist_ap, l_dist_an = local_loss(
          l_tri_loss, local_feat, p_inds, n_inds, labels_t,
          normalize_feature=cfg.normalize_feature)
        l_dist_mat = 0

      id_loss = 0
      if cfg.id_loss_weight > 0:
        id_loss = id_criterion(logits, labels_var)

      probs_list[i] = probs
      g_dist_mat_list[i] = g_dist_mat
      l_dist_mat_list[i] = l_dist_mat

      done_list1[i] = True

      # Wait for the phase-2 event, meanwhile checking whether we need to exit.
      while True:
        phase2_ready = run_event2.wait(0.001)
        if exit_event.is_set():
          return
        if phase2_ready:
          break

      #####################################
      # Phase 2: Mutual Loss and Backward #
      #####################################

      # Probability Mutual Loss (KL Loss)
      pm_loss = 0
      if (cfg.num_models > 1) and (cfg.pm_loss_weight > 0):
        for j in range(cfg.num_models):
          if j != i:
            pm_loss += F.kl_div(log_probs, TVT(probs_list[j]).detach(),
                                size_average=False)
        pm_loss /= 1. * (cfg.num_models - 1) * len(ims)
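      # Note: with log-prob inputs, F.kl_div(log_probs, target,
      # size_average=False) sums target * (log(target) - log_probs) over
      # all elements, so dividing by (num_models - 1) * batch size gives
      # the per-sample KL averaged over peer models.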

      # Global Distance Mutual Loss (L2 Loss)
      gdm_loss = 0
      if (cfg.num_models > 1) and (cfg.gdm_loss_weight > 0):
        for j in range(cfg.num_models):
          if j != i:
            gdm_loss += torch.sum(torch.pow(
              g_dist_mat - TVT(g_dist_mat_list[j]).detach(), 2))
        gdm_loss /= 1. * (cfg.num_models - 1) * len(ims) * len(ims)

      # Local Distance Mutual Loss (L2 Loss)
      ldm_loss = 0
      if (cfg.num_models > 1) \
          and cfg.local_dist_own_hard_sample \
          and (cfg.ldm_loss_weight > 0):
        for j in range(cfg.num_models):
          if j != i:
            ldm_loss += torch.sum(torch.pow(
              l_dist_mat - TVT(l_dist_mat_list[j]).detach(), 2))
        ldm_loss /= 1. * (cfg.num_models - 1) * len(ims) * len(ims)

      loss = g_loss * cfg.g_loss_weight \
             + l_loss * cfg.l_loss_weight \
             + id_loss * cfg.id_loss_weight \
             + pm_loss * cfg.pm_loss_weight \
             + gdm_loss * cfg.gdm_loss_weight \
             + ldm_loss * cfg.ldm_loss_weight

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      ##################################
      # Step Log For One of the Models #
      ##################################

      # These meters are outer-scope variables, created in the epoch loop.

      # Only record stats for the first model.
      if i == 0:

        # precision
        g_prec = (g_dist_an > g_dist_ap).data.float().mean()
        # the proportion of triplets that satisfy margin
        g_m = (g_dist_an > g_dist_ap + cfg.global_margin).data.float().mean()
        g_d_ap = g_dist_ap.data.mean()
        g_d_an = g_dist_an.data.mean()

        g_prec_meter.update(g_prec)
        g_m_meter.update(g_m)
        g_dist_ap_meter.update(g_d_ap)
        g_dist_an_meter.update(g_d_an)
        g_loss_meter.update(to_scalar(g_loss))

        if cfg.l_loss_weight > 0:
          # precision
          l_prec = (l_dist_an > l_dist_ap).data.float().mean()
          # the proportion of triplets that satisfy margin
          l_m = (l_dist_an > l_dist_ap + cfg.local_margin).data.float().mean()
          l_d_ap = l_dist_ap.data.mean()
          l_d_an = l_dist_an.data.mean()

          l_prec_meter.update(l_prec)
          l_m_meter.update(l_m)
          l_dist_ap_meter.update(l_d_ap)
          l_dist_an_meter.update(l_d_an)
          l_loss_meter.update(to_scalar(l_loss))

        if cfg.id_loss_weight > 0:
          id_loss_meter.update(to_scalar(id_loss))

        if (cfg.num_models > 1) and (cfg.pm_loss_weight > 0):
          pm_loss_meter.update(to_scalar(pm_loss))

        if (cfg.num_models > 1) and (cfg.gdm_loss_weight > 0):
          gdm_loss_meter.update(to_scalar(gdm_loss))

        if (cfg.num_models > 1) \
            and cfg.local_dist_own_hard_sample \
            and (cfg.ldm_loss_weight > 0):
          ldm_loss_meter.update(to_scalar(ldm_loss))

        loss_meter.update(to_scalar(loss))

      ####################
      # Wrap Up One Step #
      ####################

      run_event1.clear()
      run_event2.clear()

      done_list2[i] = True

  threads = []
  for i in range(cfg.num_models):
    thread = threading.Thread(target=thread_target, args=(i,))
    # Run the thread in daemon mode so that the main program can exit normally.
    thread.daemon = True
    thread.start()
    threads.append(thread)

  start_ep = resume_ep if cfg.resume else 0
  for ep in range(start_ep, cfg.total_epochs):

    # Adjust Learning Rate
    for optimizer in optimizers:
      if cfg.lr_decay_type == 'exp':
        adjust_lr_exp(
          optimizer,
          cfg.base_lr,
          ep + 1,
          cfg.total_epochs,
          cfg.exp_decay_at_epoch)
      else:
        adjust_lr_staircase(
          optimizer,
          cfg.base_lr,
          ep + 1,
          cfg.staircase_decay_at_epochs,
          cfg.staircase_decay_multiply_factor)

    may_set_mode(modules_optims, 'train')

    epoch_done = False

    g_prec_meter = AverageMeter()
    g_m_meter = AverageMeter()
    g_dist_ap_meter = AverageMeter()
    g_dist_an_meter = AverageMeter()
    g_loss_meter = AverageMeter()

    l_prec_meter = AverageMeter()
    l_m_meter = AverageMeter()
    l_dist_ap_meter = AverageMeter()
    l_dist_an_meter = AverageMeter()
    l_loss_meter = AverageMeter()

    id_loss_meter = AverageMeter()

    # Global Distance Mutual Loss
    gdm_loss_meter = AverageMeter()
    # Local Distance Mutual Loss
    ldm_loss_meter = AverageMeter()
    # Probability Mutual Loss
    pm_loss_meter = AverageMeter()

    loss_meter = AverageMeter()

    ep_st = time.time()
    step = 0
    while not epoch_done:

      step += 1
      step_st = time.time()

      ims, im_names, labels, mirrored, epoch_done = train_set.next_batch()

      for i in range(cfg.num_models):
        ims_list[i] = ims
        labels_list[i] = labels
        done_list1[i] = False
        done_list2[i] = False

      run_event1.set()
      # Busy-wait until all models finish phase 1.
      while not all(done_list1): continue

      run_event2.set()
      # Busy-wait until all models finish phase 2.
      while not all(done_list2): continue

      ############
      # Step Log #
      ############

      if step % cfg.log_steps == 0:
        time_log = '\tStep {}/Ep {}, {:.2f}s'.format(
          step, ep + 1, time.time() - step_st, )

        if cfg.g_loss_weight > 0:
          g_log = (', gp {:.2%}, gm {:.2%}, '
                   'gd_ap {:.4f}, gd_an {:.4f}, '
                   'gL {:.4f}'.format(
            g_prec_meter.val, g_m_meter.val,
            g_dist_ap_meter.val, g_dist_an_meter.val,
            g_loss_meter.val, ))
        else:
          g_log = ''

        if cfg.l_loss_weight > 0:
          l_log = (', lp {:.2%}, lm {:.2%}, '
                   'ld_ap {:.4f}, ld_an {:.4f}, '
                   'lL {:.4f}'.format(
            l_prec_meter.val, l_m_meter.val,
            l_dist_ap_meter.val, l_dist_an_meter.val,
            l_loss_meter.val, ))
        else:
          l_log = ''

        if cfg.id_loss_weight > 0:
          id_log = (', idL {:.4f}'.format(id_loss_meter.val))
        else:
          id_log = ''

        if (cfg.num_models > 1) and (cfg.pm_loss_weight > 0):
          pm_log = (', pmL {:.4f}'.format(pm_loss_meter.val))
        else:
          pm_log = ''

        if (cfg.num_models > 1) and (cfg.gdm_loss_weight > 0):
          gdm_log = (', gdmL {:.4f}'.format(gdm_loss_meter.val))
        else:
          gdm_log = ''

        if (cfg.num_models > 1) \
            and cfg.local_dist_own_hard_sample \
            and (cfg.ldm_loss_weight > 0):
          ldm_log = (', ldmL {:.4f}'.format(ldm_loss_meter.val))
        else:
          ldm_log = ''

        total_loss_log = ', loss {:.4f}'.format(loss_meter.val)

        log = time_log + \
              g_log + l_log + id_log + \
              pm_log + gdm_log + ldm_log + \
              total_loss_log
        print(log)

    #############
    # Epoch Log #
    #############

    time_log = 'Ep {}, {:.2f}s'.format(ep + 1, time.time() - ep_st, )

    if cfg.g_loss_weight > 0:
      g_log = (', gp {:.2%}, gm {:.2%}, '
               'gd_ap {:.4f}, gd_an {:.4f}, '
               'gL {:.4f}'.format(
        g_prec_meter.avg, g_m_meter.avg,
        g_dist_ap_meter.avg, g_dist_an_meter.avg,
        g_loss_meter.avg, ))
    else:
      g_log = ''

    if cfg.l_loss_weight > 0:
      l_log = (', lp {:.2%}, lm {:.2%}, '
               'ld_ap {:.4f}, ld_an {:.4f}, '
               'lL {:.4f}'.format(
        l_prec_meter.avg, l_m_meter.avg,
        l_dist_ap_meter.avg, l_dist_an_meter.avg,
        l_loss_meter.avg, ))
    else:
      l_log = ''

    if cfg.id_loss_weight > 0:
      id_log = (', idL {:.4f}'.format(id_loss_meter.avg))
    else:
      id_log = ''

    if (cfg.num_models > 1) and (cfg.pm_loss_weight > 0):
      pm_log = (', pmL {:.4f}'.format(pm_loss_meter.avg))
    else:
      pm_log = ''

    if (cfg.num_models > 1) and (cfg.gdm_loss_weight > 0):
      gdm_log = (', gdmL {:.4f}'.format(gdm_loss_meter.avg))
    else:
      gdm_log = ''

    if (cfg.num_models > 1) \
        and cfg.local_dist_own_hard_sample \
        and (cfg.ldm_loss_weight > 0):
      ldm_log = (', ldmL {:.4f}'.format(ldm_loss_meter.avg))
    else:
      ldm_log = ''

    total_loss_log = ', loss {:.4f}'.format(loss_meter.avg)

    log = time_log + \
          g_log + l_log + id_log + \
          pm_log + gdm_log + ldm_log + \
          total_loss_log
    print(log)

    # Log to TensorBoard

    if cfg.log_to_file:
      if writer is None:
        writer = SummaryWriter(log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
      writer.add_scalars(
        'loss',
        dict(global_loss=g_loss_meter.avg,
             local_loss=l_loss_meter.avg,
             id_loss=id_loss_meter.avg,
             pm_loss=pm_loss_meter.avg,
             gdm_loss=gdm_loss_meter.avg,
             ldm_loss=ldm_loss_meter.avg,
             loss=loss_meter.avg, ),
        ep)
      writer.add_scalars(
        'tri_precision',
        dict(global_precision=g_prec_meter.avg,
             local_precision=l_prec_meter.avg, ),
        ep)
      writer.add_scalars(
        'satisfy_margin',
        dict(global_satisfy_margin=g_m_meter.avg,
             local_satisfy_margin=l_m_meter.avg, ),
        ep)
      writer.add_scalars(
        'global_dist',
        dict(global_dist_ap=g_dist_ap_meter.avg,
             global_dist_an=g_dist_an_meter.avg, ),
        ep)
      writer.add_scalars(
        'local_dist',
        dict(local_dist_ap=l_dist_ap_meter.avg,
             local_dist_an=l_dist_an_meter.avg, ),
        ep)

    # save ckpt
    if cfg.log_to_file:
      save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

  ########
  # Test #
  ########

  test(load_model_weight=False)
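

# A self-contained sketch of the probability mutual loss used above (deep
# mutual learning), restated against modern PyTorch for clarity. The function
# name and the 'batchmean' reduction are illustrative, not this repo's API.
def _example_prob_mutual_loss(logits_a, logits_b):
  """KL(p_b || p_a) averaged over the batch; the peer's probabilities are
  detached so gradients flow only into model A."""
  log_p_a = F.log_softmax(logits_a, dim=1)
  p_b = F.softmax(logits_b, dim=1).detach()
  # 'batchmean' divides the summed KL by the batch size, mirroring the
  # manual division by len(ims) above; the training loop additionally
  # averages this quantity over the (num_models - 1) peers.
  return F.kl_div(log_p_a, p_b, reduction='batchmean')
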
def main():
    # reranking_mAP_list = []
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    # Lazily create SummaryWriter
    writer = None

    TVT, TMO = set_devices(cfg.sys_device_ids)

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    ###########
    # Dataset #
    ###########

    train_set = create_dataset(**cfg.train_set_kwargs)

    test_sets = []
    test_set_names = []
    if cfg.dataset == 'combined':
        for name in ['market1501', 'cuhk03', 'duke']:
            cfg.test_set_kwargs['name'] = name
            test_sets.append(create_dataset(**cfg.test_set_kwargs))
            test_set_names.append(name)
    else:
        test_sets.append(create_dataset(**cfg.test_set_kwargs))
        test_set_names.append(cfg.dataset)

    ###########
    # Models  #
    ###########

    model = Model(local_conv_out_channels=cfg.local_conv_out_channels,
                  num_classes=len(train_set.ids2labels))
    # Model wrapper
    model_w = DataParallel(model)

    #############################
    # Criteria and Optimizers   #
    #############################

    id_criterion = nn.CrossEntropyLoss()
    g_tri_loss = TripletLoss(margin=cfg.global_margin)
    l_tri_loss = TripletLoss(margin=cfg.local_margin)

    optimizer = optim.Adam(model.parameters(),
                           lr=cfg.base_lr,
                           weight_decay=cfg.weight_decay)

    # Bind them together just to save some codes in the following usage.
    modules_optims = [model, optimizer]

    ################################
    # May Resume Models and Optims #
    ################################

    if cfg.resume:
        resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

    # May Transfer Models and Optims to Specified Device. Transferring optimizer
    # is to cope with the case when you load the checkpoint to a new device.
    TMO(modules_optims)

    ########
    # Test #
    ########

    def test(load_model_weight=False):
        if load_model_weight:
            if cfg.model_weight_file != '':
                map_location = (lambda storage, loc: storage)
                sd = torch.load(cfg.model_weight_file,
                                map_location=map_location)
                load_state_dict(model, sd)
                print('Loaded model weights from {}'.format(
                    cfg.model_weight_file))
            else:
                load_ckpt(modules_optims, cfg.ckpt_file)

        use_local_distance = (cfg.l_loss_weight > 0) \
                             and cfg.local_dist_own_hard_sample

        for test_set, name in zip(test_sets, test_set_names):
            test_set.set_feat_func(ExtractFeature(model_w, TVT))
            print('\n=========> Test on dataset: {} <=========\n'.format(name))
            test_set.eval(normalize_feat=cfg.normalize_feature,
                          use_local_distance=use_local_distance)
            # reranking_mAP_list.append(mAP)

    if cfg.only_test:
        test(load_model_weight=True)
        return

    ############
    # Training #
    ############

    start_ep = resume_ep if cfg.resume else 0
    for ep in range(start_ep, cfg.total_epochs):

        # Adjust Learning Rate
        if cfg.lr_decay_type == 'exp':
            adjust_lr_exp(optimizer, cfg.base_lr, ep + 1, cfg.total_epochs,
                          cfg.exp_decay_at_epoch)
        else:
            adjust_lr_staircase(optimizer, cfg.base_lr, ep + 1,
                                cfg.staircase_decay_at_epochs,
                                cfg.staircase_decay_multiply_factor)

        may_set_mode(modules_optims, 'train')

        g_prec_meter = AverageMeter()
        g_m_meter = AverageMeter()
        g_dist_ap_meter = AverageMeter()
        g_dist_an_meter = AverageMeter()
        g_loss_meter = AverageMeter()

        l_prec_meter = AverageMeter()
        l_m_meter = AverageMeter()
        l_dist_ap_meter = AverageMeter()
        l_dist_an_meter = AverageMeter()
        l_loss_meter = AverageMeter()

        id_loss_meter = AverageMeter()

        loss_meter = AverageMeter()

        ep_st = time.time()
        step = 0
        epoch_done = False
        while not epoch_done:

            step += 1
            step_st = time.time()

            ims, im_names, labels, mirrored, epoch_done = train_set.next_batch(
            )

            ims_var = Variable(TVT(torch.from_numpy(ims).float()))
            labels_t = TVT(torch.from_numpy(labels).long())
            labels_var = Variable(labels_t)

            global_feat, local_feat, logits = model_w(ims_var)

            g_loss, p_inds, n_inds, g_dist_ap, g_dist_an, g_dist_mat = global_loss(
                g_tri_loss,
                global_feat,
                labels_t,
                normalize_feature=cfg.normalize_feature)

            if cfg.l_loss_weight == 0:
                l_loss = 0
            elif cfg.local_dist_own_hard_sample:
                # Let local distance find its own hard samples.
                l_loss, l_dist_ap, l_dist_an, _ = local_loss(
                    l_tri_loss,
                    local_feat,
                    None,
                    None,
                    labels_t,
                    normalize_feature=cfg.normalize_feature)
            else:
                l_loss, l_dist_ap, l_dist_an = local_loss(
                    l_tri_loss,
                    local_feat,
                    p_inds,
                    n_inds,
                    labels_t,
                    normalize_feature=cfg.normalize_feature)

            id_loss = 0
            if cfg.id_loss_weight > 0:
                id_loss = id_criterion(logits, labels_var)

            loss = g_loss * cfg.g_loss_weight \
                   + l_loss * cfg.l_loss_weight \
                   + id_loss * cfg.id_loss_weight

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ############
            # Step Log #
            ############

            # precision
            g_prec = (g_dist_an > g_dist_ap).data.float().mean()
            # the proportion of triplets that satisfy margin
            g_m = (g_dist_an >
                   g_dist_ap + cfg.global_margin).data.float().mean()
            g_d_ap = g_dist_ap.data.mean()
            g_d_an = g_dist_an.data.mean()

            g_prec_meter.update(g_prec)
            g_m_meter.update(g_m)
            g_dist_ap_meter.update(g_d_ap)
            g_dist_an_meter.update(g_d_an)
            g_loss_meter.update(to_scalar(g_loss))

            if cfg.l_loss_weight > 0:
                # precision
                l_prec = (l_dist_an > l_dist_ap).data.float().mean()
                # the proportion of triplets that satisfy margin
                l_m = (l_dist_an >
                       l_dist_ap + cfg.local_margin).data.float().mean()
                l_d_ap = l_dist_ap.data.mean()
                l_d_an = l_dist_an.data.mean()

                l_prec_meter.update(l_prec)
                l_m_meter.update(l_m)
                l_dist_ap_meter.update(l_d_ap)
                l_dist_an_meter.update(l_d_an)
                l_loss_meter.update(to_scalar(l_loss))

            if cfg.id_loss_weight > 0:
                id_loss_meter.update(to_scalar(id_loss))

            loss_meter.update(to_scalar(loss))

            if step % cfg.log_steps == 0:
                time_log = '\tStep {}/Ep {}, {:.2f}s'.format(
                    step,
                    ep + 1,
                    time.time() - step_st,
                )

                if cfg.g_loss_weight > 0:
                    g_log = (', gp {:.2%}, gm {:.2%}, '
                             'gd_ap {:.4f}, gd_an {:.4f}, '
                             'gL {:.4f}'.format(
                                 g_prec_meter.val,
                                 g_m_meter.val,
                                 g_dist_ap_meter.val,
                                 g_dist_an_meter.val,
                                 g_loss_meter.val,
                             ))
                else:
                    g_log = ''

                if cfg.l_loss_weight > 0:
                    l_log = (', lp {:.2%}, lm {:.2%}, '
                             'ld_ap {:.4f}, ld_an {:.4f}, '
                             'lL {:.4f}'.format(
                                 l_prec_meter.val,
                                 l_m_meter.val,
                                 l_dist_ap_meter.val,
                                 l_dist_an_meter.val,
                                 l_loss_meter.val,
                             ))
                else:
                    l_log = ''

                if cfg.id_loss_weight > 0:
                    id_log = (', idL {:.4f}'.format(id_loss_meter.val))
                else:
                    id_log = ''

                total_loss_log = ', loss {:.4f}'.format(loss_meter.val)

                log = time_log + \
                      g_log + l_log + id_log + \
                      total_loss_log
                print(log)

        #############
        # Epoch Log #
        #############

        time_log = 'Ep {}, {:.2f}s'.format(
            ep + 1,
            time.time() - ep_st,
        )

        if cfg.g_loss_weight > 0:
            g_log = (', gp {:.2%}, gm {:.2%}, '
                     'gd_ap {:.4f}, gd_an {:.4f}, '
                     'gL {:.4f}'.format(
                         g_prec_meter.avg,
                         g_m_meter.avg,
                         g_dist_ap_meter.avg,
                         g_dist_an_meter.avg,
                         g_loss_meter.avg,
                     ))
        else:
            g_log = ''

        if cfg.l_loss_weight > 0:
            l_log = (', lp {:.2%}, lm {:.2%}, '
                     'ld_ap {:.4f}, ld_an {:.4f}, '
                     'lL {:.4f}'.format(
                         l_prec_meter.avg,
                         l_m_meter.avg,
                         l_dist_ap_meter.avg,
                         l_dist_an_meter.avg,
                         l_loss_meter.avg,
                     ))
        else:
            l_log = ''

        if cfg.id_loss_weight > 0:
            id_log = (', idL {:.4f}'.format(id_loss_meter.avg))
        else:
            id_log = ''

        total_loss_log = ', loss {:.4f}'.format(loss_meter.avg)

        log = time_log + \
              g_log + l_log + id_log + \
              total_loss_log
        print(log)

        # Log to TensorBoard

        if cfg.log_to_file:
            if writer is None:
                writer = SummaryWriter(
                    log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
            writer.add_scalars(
                'loss',
                dict(
                    global_loss=g_loss_meter.avg,
                    local_loss=l_loss_meter.avg,
                    id_loss=id_loss_meter.avg,
                    loss=loss_meter.avg,
                ), ep)
            writer.add_scalars(
                'tri_precision',
                dict(
                    global_precision=g_prec_meter.avg,
                    local_precision=l_prec_meter.avg,
                ), ep)
            writer.add_scalars(
                'satisfy_margin',
                dict(
                    global_satisfy_margin=g_m_meter.avg,
                    local_satisfy_margin=l_m_meter.avg,
                ), ep)
            writer.add_scalars(
                'global_dist',
                dict(
                    global_dist_ap=g_dist_ap_meter.avg,
                    global_dist_an=g_dist_an_meter.avg,
                ), ep)
            writer.add_scalars(
                'local_dist',
                dict(
                    local_dist_ap=l_dist_ap_meter.avg,
                    local_dist_an=l_dist_an_meter.avg,
                ), ep)

        # save ckpt
        if cfg.log_to_file:
            save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)
        #if (ep+1)%1==0:
        #  test(load_model_weight=False)
        # if (ep+1)%10==0:
        #   print(reranking_mAP_list)

    ########
    # Test #
    ########

    test(load_model_weight=False)
def main():
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    # Lazily create SummaryWriter
    writer = None

    TVTs, TMOs, relative_device_ids = set_devices_for_ml(cfg.sys_device_ids)

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    ###########
    # Dataset #
    ###########

    train_set = create_dataset(**cfg.train_set_kwargs)

    test_sets = []
    test_set_names = []
    if cfg.dataset == 'combined':
        for name in ['market1501', 'cuhk03', 'duke']:
            cfg.test_set_kwargs['name'] = name
            test_sets.append(create_dataset(**cfg.test_set_kwargs))
            test_set_names.append(name)
    else:
        test_sets.append(create_dataset(**cfg.test_set_kwargs))
        test_set_names.append(cfg.dataset)

    ###########
    # Models  #
    ###########

    models = [
        Model(local_conv_out_channels=cfg.local_conv_out_channels,
              num_classes=len(train_set.ids2labels))
        for _ in range(cfg.num_models)
    ]
    # Model wrappers
    model_ws = [
        DataParallel(models[i], device_ids=relative_device_ids[i])
        for i in range(cfg.num_models)
    ]

    #############################
    # Criteria and Optimizers   #
    #############################

    id_criterion = nn.CrossEntropyLoss()
    g_tri_loss = TripletLoss(margin=cfg.global_margin)
    l_tri_loss = TripletLoss(margin=cfg.local_margin)

    optimizers = [
        optim.Adam(m.parameters(),
                   lr=cfg.base_lr,
                   weight_decay=cfg.weight_decay) for m in models
    ]

    # Bind them together just to save some codes in the following usage.
    modules_optims = models + optimizers

    ################################
    # May Resume Models and Optims #
    ################################

    if cfg.resume:
        resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

    # May Transfer Models and Optims to Specified Device. Transferring optimizers
    # is to cope with the case when you load the checkpoint to a new device.
    for TMO, model, optimizer in zip(TMOs, models, optimizers):
        TMO([model, optimizer])

    ########
    # Test #
    ########

    # Test each model using different distance settings.
    def test(load_model_weight=False):
        if load_model_weight:
            load_ckpt(modules_optims, cfg.ckpt_file)

        use_local_distance = (cfg.l_loss_weight > 0) \
                             and cfg.local_dist_own_hard_sample

        for i, (model_w, TVT) in enumerate(zip(model_ws, TVTs)):
            for test_set, name in zip(test_sets, test_set_names):
                test_set.set_feat_func(ExtractFeature(model_w, TVT))
                print(
                    '\n=========> Test Model #{} on dataset: {} <=========\n'.
                    format(i + 1, name))
                test_set.eval(normalize_feat=cfg.normalize_feature,
                              use_local_distance=use_local_distance)

    if cfg.only_test:
        test(load_model_weight=True)
        return

    ############
    # Training #
    ############

    # Storing things that can be accessed cross threads.

    ims_list = [None for _ in range(cfg.num_models)]
    labels_list = [None for _ in range(cfg.num_models)]

    done_list1 = [False for _ in range(cfg.num_models)]
    done_list2 = [False for _ in range(cfg.num_models)]

    probs_list = [None for _ in range(cfg.num_models)]
    g_dist_mat_list = [None for _ in range(cfg.num_models)]
    l_dist_mat_list = [None for _ in range(cfg.num_models)]

    # Two phases for each model:
    # 1) forward and single-model loss;
    # 2) further add mutual loss and backward.
    # The 2nd phase is only ready to start when the 1st is finished for
    # all models.
    run_event1 = threading.Event()
    run_event2 = threading.Event()

    # This event is meant to be set to stop threads. However, as I found, with
    # `daemon` set to true when creating threads, manually stopping is
    # unnecessary. I guess some main-thread variables required by sub-threads
    # are destroyed when the main thread ends, thus the sub-threads throw errors
    # and exit too.
    # Real reason should be further explored.
    exit_event = threading.Event()

    # The function to be called by threads.
    def thread_target(i):
        while not exit_event.isSet():
            # If the run event is not set, the thread just waits.
            if not run_event1.wait(0.001): continue

            ######################################
            # Phase 1: Forward and Separate Loss #
            ######################################

            TVT = TVTs[i]
            model_w = model_ws[i]
            ims = ims_list[i]
            labels = labels_list[i]
            optimizer = optimizers[i]

            ims_var = Variable(TVT(torch.from_numpy(ims).float()))
            labels_t = TVT(torch.from_numpy(labels).long())
            labels_var = Variable(labels_t)

            global_feat, local_feat, logits = model_w(ims_var)
            probs = F.softmax(logits)
            log_probs = F.log_softmax(logits)

            g_loss, p_inds, n_inds, g_dist_ap, g_dist_an, g_dist_mat = global_loss(
                g_tri_loss,
                global_feat,
                labels_t,
                normalize_feature=cfg.normalize_feature)

            if cfg.l_loss_weight == 0:
                l_loss, l_dist_mat = 0, 0
            elif cfg.local_dist_own_hard_sample:
                # Let local distance find its own hard samples.
                l_loss, l_dist_ap, l_dist_an, l_dist_mat = local_loss(
                    l_tri_loss,
                    local_feat,
                    None,
                    None,
                    labels_t,
                    normalize_feature=cfg.normalize_feature)
            else:
                l_loss, l_dist_ap, l_dist_an = local_loss(
                    l_tri_loss,
                    local_feat,
                    p_inds,
                    n_inds,
                    labels_t,
                    normalize_feature=cfg.normalize_feature)
                l_dist_mat = 0

            id_loss = 0
            if cfg.id_loss_weight > 0:
                id_loss = id_criterion(logits, labels_var)

            probs_list[i] = probs
            g_dist_mat_list[i] = g_dist_mat
            l_dist_mat_list[i] = l_dist_mat

            done_list1[i] = True

            # Wait for event to be set, meanwhile checking if need to exit.
            while True:
                phase2_ready = run_event2.wait(0.001)
                if exit_event.isSet():
                    return
                if phase2_ready:
                    break

            #####################################
            # Phase 2: Mutual Loss and Backward #
            #####################################

            # Probability Mutual Loss (KL Loss)
            pm_loss = 0
            if (cfg.num_models > 1) and (cfg.pm_loss_weight > 0):
                for j in range(cfg.num_models):
                    if j != i:
                        pm_loss += F.kl_div(log_probs,
                                            TVT(probs_list[j]).detach(), False)
                pm_loss /= 1. * (cfg.num_models - 1) * len(ims)

            # Global Distance Mutual Loss (L2 Loss)
            gdm_loss = 0
            if (cfg.num_models > 1) and (cfg.gdm_loss_weight > 0):
                for j in range(cfg.num_models):
                    if j != i:
                        gdm_loss += torch.sum(
                            torch.pow(
                                g_dist_mat - TVT(g_dist_mat_list[j]).detach(),
                                2))
                gdm_loss /= 1. * (cfg.num_models - 1) * len(ims) * len(ims)

            # Local Distance Mutual Loss (L2 Loss)
            ldm_loss = 0
            if (cfg.num_models > 1) \
                and cfg.local_dist_own_hard_sample \
                and (cfg.ldm_loss_weight > 0):
                for j in range(cfg.num_models):
                    if j != i:
                        ldm_loss += torch.sum(
                            torch.pow(
                                l_dist_mat - TVT(l_dist_mat_list[j]).detach(),
                                2))
                ldm_loss /= 1. * (cfg.num_models - 1) * len(ims) * len(ims)

            loss = g_loss * cfg.g_loss_weight \
                   + l_loss * cfg.l_loss_weight \
                   + id_loss * cfg.id_loss_weight \
                   + pm_loss * cfg.pm_loss_weight \
                   + gdm_loss * cfg.gdm_loss_weight \
                   + ldm_loss * cfg.ldm_loss_weight

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ##################################
            # Step Log For One of the Models #
            ##################################

            # These meters are outer-scope variables

            # Just record for the first model
            if i == 0:

                # precision
                g_prec = (g_dist_an > g_dist_ap).data.float().mean()
                # the proportion of triplets that satisfy margin
                g_m = (g_dist_an >
                       g_dist_ap + cfg.global_margin).data.float().mean()
                g_d_ap = g_dist_ap.data.mean()
                g_d_an = g_dist_an.data.mean()

                g_prec_meter.update(g_prec)
                g_m_meter.update(g_m)
                g_dist_ap_meter.update(g_d_ap)
                g_dist_an_meter.update(g_d_an)
                g_loss_meter.update(to_scalar(g_loss))

                if cfg.l_loss_weight > 0:
                    # precision
                    l_prec = (l_dist_an > l_dist_ap).data.float().mean()
                    # the proportion of triplets that satisfy margin
                    l_m = (l_dist_an >
                           l_dist_ap + cfg.local_margin).data.float().mean()
                    l_d_ap = l_dist_ap.data.mean()
                    l_d_an = l_dist_an.data.mean()

                    l_prec_meter.update(l_prec)
                    l_m_meter.update(l_m)
                    l_dist_ap_meter.update(l_d_ap)
                    l_dist_an_meter.update(l_d_an)
                    l_loss_meter.update(to_scalar(l_loss))

                if cfg.id_loss_weight > 0:
                    id_loss_meter.update(to_scalar(id_loss))

                if (cfg.num_models > 1) and (cfg.pm_loss_weight > 0):
                    pm_loss_meter.update(to_scalar(pm_loss))

                if (cfg.num_models > 1) and (cfg.gdm_loss_weight > 0):
                    gdm_loss_meter.update(to_scalar(gdm_loss))

                if (cfg.num_models > 1) \
                    and cfg.local_dist_own_hard_sample \
                    and (cfg.ldm_loss_weight > 0):
                    ldm_loss_meter.update(to_scalar(ldm_loss))

                loss_meter.update(to_scalar(loss))

            ###################
            # End Up One Step #
            ###################

            run_event1.clear()
            run_event2.clear()

            done_list2[i] = True

    threads = []
    for i in range(cfg.num_models):
        thread = threading.Thread(target=thread_target, args=(i, ))
        # Set the thread in daemon mode, so that the main program ends normally.
        thread.daemon = True
        thread.start()
        threads.append(thread)

    start_ep = resume_ep if cfg.resume else 0
    for ep in range(start_ep, cfg.total_epochs):

        # Adjust Learning Rate
        for optimizer in optimizers:
            if cfg.lr_decay_type == 'exp':
                adjust_lr_exp(optimizer, cfg.base_lr, ep + 1, cfg.total_epochs,
                              cfg.exp_decay_at_epoch)
            else:
                adjust_lr_staircase(optimizer, cfg.base_lr, ep + 1,
                                    cfg.staircase_decay_at_epochs,
                                    cfg.staircase_decay_multiply_factor)

        may_set_mode(modules_optims, 'train')

        epoch_done = False

        g_prec_meter = AverageMeter()
        g_m_meter = AverageMeter()
        g_dist_ap_meter = AverageMeter()
        g_dist_an_meter = AverageMeter()
        g_loss_meter = AverageMeter()

        l_prec_meter = AverageMeter()
        l_m_meter = AverageMeter()
        l_dist_ap_meter = AverageMeter()
        l_dist_an_meter = AverageMeter()
        l_loss_meter = AverageMeter()

        id_loss_meter = AverageMeter()

        # Global Distance Mutual Loss
        gdm_loss_meter = AverageMeter()
        # Local Distance Mutual Loss
        ldm_loss_meter = AverageMeter()
        # Probability Mutual Loss
        pm_loss_meter = AverageMeter()

        loss_meter = AverageMeter()

        ep_st = time.time()
        step = 0
        while not epoch_done:

            step += 1
            step_st = time.time()

            ims, im_names, labels, mirrored, epoch_done = train_set.next_batch(
            )

            for i in range(cfg.num_models):
                ims_list[i] = ims
                labels_list[i] = labels
                done_list1[i] = False
                done_list2[i] = False

            run_event1.set()
            # Waiting for phase 1 done
            while not all(done_list1):
                continue

            run_event2.set()
            # Waiting for phase 2 done
            while not all(done_list2):
                continue

            ############
            # Step Log #
            ############

            if step % cfg.log_steps == 0:
                time_log = '\tStep {}/Ep {}, {:.2f}s'.format(
                    step,
                    ep + 1,
                    time.time() - step_st,
                )

                if cfg.g_loss_weight > 0:
                    g_log = (', gp {:.2%}, gm {:.2%}, '
                             'gd_ap {:.4f}, gd_an {:.4f}, '
                             'gL {:.4f}'.format(
                                 g_prec_meter.val,
                                 g_m_meter.val,
                                 g_dist_ap_meter.val,
                                 g_dist_an_meter.val,
                                 g_loss_meter.val,
                             ))
                else:
                    g_log = ''

                if cfg.l_loss_weight > 0:
                    l_log = (', lp {:.2%}, lm {:.2%}, '
                             'ld_ap {:.4f}, ld_an {:.4f}, '
                             'lL {:.4f}'.format(
                                 l_prec_meter.val,
                                 l_m_meter.val,
                                 l_dist_ap_meter.val,
                                 l_dist_an_meter.val,
                                 l_loss_meter.val,
                             ))
                else:
                    l_log = ''

                if cfg.id_loss_weight > 0:
                    id_log = (', idL {:.4f}'.format(id_loss_meter.val))
                else:
                    id_log = ''

                if (cfg.num_models > 1) and (cfg.pm_loss_weight > 0):
                    pm_log = (', pmL {:.4f}'.format(pm_loss_meter.val))
                else:
                    pm_log = ''

                if (cfg.num_models > 1) and (cfg.gdm_loss_weight > 0):
                    gdm_log = (', gdmL {:.4f}'.format(gdm_loss_meter.val))
                else:
                    gdm_log = ''

                if (cfg.num_models > 1) \
                    and cfg.local_dist_own_hard_sample \
                    and (cfg.ldm_loss_weight > 0):
                    ldm_log = (', ldmL {:.4f}'.format(ldm_loss_meter.val))
                else:
                    ldm_log = ''

                total_loss_log = ', loss {:.4f}'.format(loss_meter.val)

                log = time_log + \
                      g_log + l_log + id_log + \
                      pm_log + gdm_log + ldm_log + \
                      total_loss_log
                print(log)

        #############
        # Epoch Log #
        #############

        time_log = 'Ep {}, {:.2f}s'.format(
            ep + 1,
            time.time() - ep_st,
        )

        if cfg.g_loss_weight > 0:
            g_log = (', gp {:.2%}, gm {:.2%}, '
                     'gd_ap {:.4f}, gd_an {:.4f}, '
                     'gL {:.4f}'.format(
                         g_prec_meter.avg,
                         g_m_meter.avg,
                         g_dist_ap_meter.avg,
                         g_dist_an_meter.avg,
                         g_loss_meter.avg,
                     ))
        else:
            g_log = ''

        if cfg.l_loss_weight > 0:
            l_log = (', lp {:.2%}, lm {:.2%}, '
                     'ld_ap {:.4f}, ld_an {:.4f}, '
                     'lL {:.4f}'.format(
                         l_prec_meter.avg,
                         l_m_meter.avg,
                         l_dist_ap_meter.avg,
                         l_dist_an_meter.avg,
                         l_loss_meter.avg,
                     ))
        else:
            l_log = ''

        if cfg.id_loss_weight > 0:
            id_log = (', idL {:.4f}'.format(id_loss_meter.avg))
        else:
            id_log = ''

        if (cfg.num_models > 1) and (cfg.pm_loss_weight > 0):
            pm_log = (', pmL {:.4f}'.format(pm_loss_meter.avg))
        else:
            pm_log = ''

        if (cfg.num_models > 1) and (cfg.gdm_loss_weight > 0):
            gdm_log = (', gdmL {:.4f}'.format(gdm_loss_meter.avg))
        else:
            gdm_log = ''

        if (cfg.num_models > 1) \
            and cfg.local_dist_own_hard_sample \
            and (cfg.ldm_loss_weight > 0):
            ldm_log = (', ldmL {:.4f}'.format(ldm_loss_meter.avg))
        else:
            ldm_log = ''

        total_loss_log = ', loss {:.4f}'.format(loss_meter.avg)

        log = time_log + \
              g_log + l_log + id_log + \
              pm_log + gdm_log + ldm_log + \
              total_loss_log
        print(log)

        # Log to TensorBoard

        if cfg.log_to_file:
            if writer is None:
                writer = SummaryWriter(
                    log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
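            # add_scalars groups several curves under one main tag, so the
            # global and local variants of each metric share a chart in
            # TensorBoard.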
            writer.add_scalars(
                'loss',
                dict(
                    global_loss=g_loss_meter.avg,
                    local_loss=l_loss_meter.avg,
                    id_loss=id_loss_meter.avg,
                    pm_loss=pm_loss_meter.avg,
                    gdm_loss=gdm_loss_meter.avg,
                    ldm_loss=ldm_loss_meter.avg,
                    loss=loss_meter.avg,
                ), ep)
            writer.add_scalars(
                'tri_precision',
                dict(
                    global_precision=g_prec_meter.avg,
                    local_precision=l_prec_meter.avg,
                ), ep)
            writer.add_scalars(
                'satisfy_margin',
                dict(
                    global_satisfy_margin=g_m_meter.avg,
                    local_satisfy_margin=l_m_meter.avg,
                ), ep)
            writer.add_scalars(
                'global_dist',
                dict(
                    global_dist_ap=g_dist_ap_meter.avg,
                    global_dist_an=g_dist_an_meter.avg,
                ), ep)
            writer.add_scalars(
                'local_dist',
                dict(
                    local_dist_ap=l_dist_ap_meter.avg,
                    local_dist_an=l_dist_an_meter.avg,
                ), ep)

        # Save a checkpoint every epoch; the score argument is fixed at 0
        # because this script runs no validation.
        if cfg.log_to_file:
            save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

    ########
    # Test #
    ########

    test(load_model_weight=False)


def main():
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.log_file, 'stdout', False)
        ReDirectSTD(cfg.log_err_file, 'stderr', False)

    TVT, TMO = set_devices(cfg.sys_device_ids)
    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Dump the configurations to log.
    import pprint
    pprint.pprint(cfg.__dict__)

    if cfg.log_to_file:
        writer = SummaryWriter(log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
    else:
        writer = None

    ###########
    # Models  #
    ###########

    model = Model(local_conv_out_channels=cfg.local_conv_out_channels)
    model_w = get_model_wrapper(model, len(cfg.sys_device_ids) > 1)

    #############################
    # Criteria and Optimizers   #
    #############################

    g_tri_loss = TripletLoss(margin=cfg.global_margin)
    l_tri_loss = TripletLoss(margin=cfg.local_margin)

    optimizer = optim.Adam(model.parameters(),
                           lr=cfg.lr,
                           weight_decay=cfg.weight_decay)

    modules_optims = [model, optimizer]

    ################################
    # May Resume Models and Optims #
    ################################

    if cfg.resume:
        resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

    # May Transfer Models and Optims to Specified Device
    TMO(modules_optims)

    ###########
    # Dataset #
    ###########

    def feature_func(ims):
        """A function to be called in the val/test set, to extract features."""
        # Set eval mode.
        # Force all BN layers to use global mean and variance, also disable
        # dropout.
        may_set_mode(modules_optims, 'eval')
        ims = Variable(TVT(torch.from_numpy(ims).float()))
        global_feat, local_feat = model_w(ims)
        global_feat = global_feat.data.cpu().numpy()
        local_feat = local_feat.data.cpu().numpy()
        return global_feat, local_feat
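
    # Hypothetical illustration of how a dataset object might push batches
    # of images through feature_func (the real batching lives inside the
    # dataset class; image_batches is an assumed name):
    #
    #   feats = [feature_func(batch) for batch in image_batches]
    #   global_feat = np.concatenate([g for g, _ in feats])
    #   local_feat = np.concatenate([l for _, l in feats])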

    train_set, val_set, test_set = None, None, None
    if not cfg.only_test:
        train_set = create_dataset(**cfg.train_set_kwargs)
        # val_set = create_dataset(**cfg.val_set_kwargs)
        # val_set.set_feat_func(feature_func)
    if cfg.only_test or cfg.test:
        test_set = create_dataset(**cfg.test_set_kwargs)
        test_set.set_feat_func(feature_func)

    ########
    # Test #
    ########

    if cfg.only_test:
        print('=====> Test')
        load_ckpt(modules_optims, cfg.ckpt_file)
        mAP, cmc_scores, mq_mAP, mq_cmc_scores = test_set.eval(
            normalize_feat=True,
            global_weight=cfg.g_test_weight,
            local_weight=cfg.l_test_weight)
        return

    ############
    # Training #
    ############

    # best_score is only consumed by the (currently commented-out)
    # best-checkpoint copy in the save step below.
    best_score = scores if cfg.resume else 0
    start_ep = resume_ep if cfg.resume else 0
    for ep in range(start_ep, cfg.num_epochs):
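        # adjust_lr is assumed to keep cfg.lr constant until
        # cfg.start_decay_epoch and then anneal it. A plausible exponential
        # schedule (a sketch, not necessarily the actual implementation):
        #
        #   if ep >= cfg.start_decay_epoch:
        #       frac = (float(ep + 1 - cfg.start_decay_epoch)
        #               / (cfg.num_epochs + 1 - cfg.start_decay_epoch))
        #       for g in optimizer.param_groups:
        #           g['lr'] = cfg.lr * (0.001 ** frac)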
        adjust_lr(optimizer, cfg.lr, ep, cfg.num_epochs, cfg.start_decay_epoch)
        may_set_mode(modules_optims, 'train')

        epoch_done = False

        g_prec_meter = AverageMeter()
        g_m_meter = AverageMeter()
        g_dist_ap_meter = AverageMeter()
        g_dist_an_meter = AverageMeter()
        g_loss_meter = AverageMeter()

        l_prec_meter = AverageMeter()
        l_m_meter = AverageMeter()
        l_dist_ap_meter = AverageMeter()
        l_dist_an_meter = AverageMeter()
        l_loss_meter = AverageMeter()

        loss_meter = AverageMeter()

        ep_st = time.time()
        step = 0
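        # next_batch() reports epoch completion through its last return
        # value, so the loop length follows the sampler rather than a
        # fixed step count.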
        while not epoch_done:

            step += 1
            step_st = time.time()

            ims, im_names, labels, mirrored, epoch_done = \
                train_set.next_batch()
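
            # TVT (from set_devices) presumably transfers tensors to the
            # configured device, turning the numpy batch into a
            # device-resident Variable before the forward pass.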

            ims_var = Variable(TVT(torch.from_numpy(ims).float()))
            labels_t = TVT(torch.from_numpy(labels).long())
            global_feat, local_feat = model_w(ims_var)

            g_loss, p_inds, n_inds, g_dist_ap, g_dist_an = global_loss(
                g_tri_loss, global_feat, labels_t)

            if cfg.l_loss_weight == 0:
                # Local branch disabled; zero placeholders keep the
                # weighted sum below well-defined.
                l_loss, l_prec, l_m = 0, 0, 0
            else:
                l_loss, l_dist_ap, l_dist_an = local_loss(
                    l_tri_loss, local_feat, p_inds, n_inds, labels_t)

                # Let local distance find its own hard samples.
                # l_loss, l_dist_ap, l_dist_an = local_loss(
                #   l_tri_loss, local_feat, None, None, labels_t)
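
            # Final objective: a weighted sum of the global and local
            # triplet terms; a zero weight disables the corresponding term.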

            loss = g_loss * cfg.g_loss_weight + l_loss * cfg.l_loss_weight

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()

            # Step logs

            # precision
            g_prec = (g_dist_an > g_dist_ap).data.float().mean()
            # the proportion of triplets that satisfy margin
            g_m = (g_dist_an >
                   g_dist_ap + cfg.global_margin).data.float().mean()
            g_d_ap = g_dist_ap.data.mean()
            g_d_an = g_dist_an.data.mean()

            g_prec_meter.update(g_prec)
            g_m_meter.update(g_m)
            g_dist_ap_meter.update(g_d_ap)
            g_dist_an_meter.update(g_d_an)
            g_loss_meter.update(to_scalar(g_loss))

            if cfg.l_loss_weight > 0:
                # precision
                l_prec = (l_dist_an > l_dist_ap).data.float().mean()
                # the proportion of triplets that satisfy margin
                l_m = (l_dist_an >
                       l_dist_ap + cfg.local_margin).data.float().mean()
                l_d_ap = l_dist_ap.data.mean()
                l_d_an = l_dist_an.data.mean()

                l_prec_meter.update(l_prec)
                l_m_meter.update(l_m)
                l_dist_ap_meter.update(l_d_ap)
                l_dist_an_meter.update(l_d_an)
                l_loss_meter.update(to_scalar(l_loss))

            loss_meter.update(to_scalar(loss))

            if step % cfg.log_steps == 0:
                print(
                    '\tStep {}/Ep {}, {:.2f}s, '
                    'gp {:.4f}, gm {:.4f}, gd_ap {:.4f}, gd_an {:.4f}, g_loss {:.4f}, '
                    'lp {:.4f}, lm {:.4f}, ld_ap {:.4f}, ld_an {:.4f}, l_loss {:.4f}, '
                    'loss: {:.4f}'.format(
                        step, ep + 1,
                        time.time() - step_st, g_prec_meter.val, g_m_meter.val,
                        g_dist_ap_meter.val, g_dist_an_meter.val,
                        g_loss_meter.val, l_prec_meter.val, l_m_meter.val,
                        l_dist_ap_meter.val, l_dist_an_meter.val,
                        l_loss_meter.val, loss_meter.val))

        # Epoch logs
        print(
            'Ep {}, {:.2f}s, '
            'gp {:.4f}, gm {:.4f}, gd_ap {:.4f}, gd_an {:.4f}, g_loss {:.4f}, '
            'lp {:.4f}, lm {:.4f}, ld_ap {:.4f}, ld_an {:.4f}, l_loss {:.4f}, '
            'loss: {:.4f}'.format(ep + 1,
                                  time.time() - ep_st, g_prec_meter.avg,
                                  g_m_meter.avg, g_dist_ap_meter.avg,
                                  g_dist_an_meter.avg, g_loss_meter.avg,
                                  l_prec_meter.avg, l_m_meter.avg,
                                  l_dist_ap_meter.avg, l_dist_an_meter.avg,
                                  l_loss_meter.avg, loss_meter.avg))

        if cfg.log_to_file:
            writer.add_scalars(
                'loss',
                dict(
                    global_loss=g_loss_meter.avg,
                    local_loss=l_loss_meter.avg,
                    loss=loss_meter.avg,
                ), ep)
            writer.add_scalars(
                'tri_precision',
                dict(
                    global_precision=g_prec_meter.avg,
                    local_precision=l_prec_meter.avg,
                ), ep)
            writer.add_scalars(
                'satisfy_margin',
                dict(
                    global_proportion=g_m_meter.avg,
                    local_proportion=l_m_meter.avg,
                ), ep)
            writer.add_scalars(
                'global_dist',
                dict(
                    global_dist_ap=g_dist_ap_meter.avg,
                    global_dist_an=g_dist_an_meter.avg,
                ), ep)
            writer.add_scalars(
                'local_dist',
                dict(
                    local_dist_ap=l_dist_ap_meter.avg,
                    local_dist_an=l_dist_an_meter.avg,
                ), ep)

        mAP = 0
        # print('=====> Validation')
        # mAP, cmc_scores, mq_mAP, mq_cmc_scores = val_set.eval(
        #   normalize_feat=True,
        #   global_weight=cfg.g_test_weight,
        #   local_weight=cfg.l_test_weight)

        # save ckpt
        if cfg.save_ckpt:
            save_ckpt(modules_optims, ep + 1, mAP, cfg.ckpt_file)
            # if mAP > best_score:
            #   best_score = mAP
            #   shutil.copy(cfg.ckpt_file, cfg.best_ckpt_file)

    ########
    # Test #
    ########

    if cfg.test:
        print('=====> Test')
        mAP, cmc_scores, mq_mAP, mq_cmc_scores = test_set.eval(
            normalize_feat=True,
            global_weight=cfg.g_test_weight,
            local_weight=cfg.l_test_weight)
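

if __name__ == '__main__':
    main()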