Example #1
0
  def test(load_model_weight=False):
    """Run evaluation of the current model on each configured test set.

    When `load_model_weight` is true, weights are restored first: from
    `cfg.model_weight_file` if that path is non-empty (every storage is
    mapped onto the CPU), otherwise from the checkpoint at `cfg.ckpt_file`.
    """
    if load_model_weight:
      weight_file = cfg.model_weight_file
      if weight_file == '':
        load_ckpt(modules_optims, cfg.ckpt_file)
      else:
        # Returning `storage` unchanged keeps tensors on the CPU.
        state_dict = torch.load(
          weight_file, map_location=lambda storage, loc: storage)
        load_state_dict(model, state_dict)
        print('Loaded model weights from {}'.format(weight_file))

    # Local distance is used at test time only if the local branch was
    # trained (non-zero weight) and mined its own hard samples.
    use_local_distance = (
      cfg.l_loss_weight > 0 and cfg.local_dist_own_hard_sample)

    for name, test_set in zip(test_set_names, test_sets):
      test_set.set_feat_func(ExtractFeature(model_w, TVT))
      print('\n=========> Test on dataset: {} <=========\n'.format(name))
      test_set.eval(normalize_feat=cfg.normalize_feature,
                    use_local_distance=use_local_distance)
Example #2
0
def main():
    """Train the re-ID model configured by ``Config()`` and then evaluate it.

    Builds the train/test sets, the model, its losses and an Adam optimizer.
    It optionally resumes from ``cfg.ckpt_file``. It then trains for
    ``cfg.total_epochs`` epochs on a weighted sum of the global triplet,
    local triplet, identity and SIFT losses, and finally calls the nested
    ``test()`` on every test set. When ``cfg.only_test`` is set, it only
    loads weights and evaluates.

    Step and epoch statistics are printed. When ``cfg.log_to_file`` is set,
    stdout/stderr are also mirrored to files, epoch statistics go to
    TensorBoard, and a checkpoint is saved after every epoch.
    """
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    # Lazily create SummaryWriter
    writer = None

    TVT, TMO = set_devices(cfg.sys_device_ids)

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    ###########
    # Dataset #
    ###########

    train_set = create_dataset(**cfg.train_set_kwargs)

    test_sets = []
    test_set_names = []
    if cfg.dataset == 'combined':
        # A model trained on the combined set is tested on each component.
        for name in ['market1501', 'cuhk03', 'duke']:
            cfg.test_set_kwargs['name'] = name
            test_sets.append(create_dataset(**cfg.test_set_kwargs))
            test_set_names.append(name)
    else:
        test_sets.append(create_dataset(**cfg.test_set_kwargs))
        test_set_names.append(cfg.dataset)

    ###########
    # Models  #
    ###########

    model = Model(local_conv_out_channels=cfg.local_conv_out_channels,
                  num_classes=len(train_set.ids2labels))
    # Model wrapper
    model_w = DataParallel(model)

    #############################
    # Criteria and Optimizers   #
    #############################

    id_criterion = SoftmaxEntropyLoss()
    g_tri_loss = TripletLoss(margin=cfg.global_margin)
    l_tri_loss = TripletLoss(margin=cfg.local_margin)

    optimizer = optim.Adam(model.parameters(),
                           lr=cfg.base_lr,
                           weight_decay=cfg.weight_decay)

    # Bind them together just to save some codes in the following usage.
    modules_optims = [model, optimizer]

    ################################
    # May Resume Models and Optims #
    ################################

    if cfg.resume:
        resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

    # May Transfer Models and Optims to Specified Device. Transferring optimizer
    # is to cope with the case when you load the checkpoint to a new device.
    TMO(modules_optims)

    ########
    # Test #
    ########

    def test(load_model_weight=False):
        """Evaluate the current model on every test set.

        With ``load_model_weight``, weights are restored first: from
        ``cfg.model_weight_file`` if it is non-empty, otherwise from the
        checkpoint ``cfg.ckpt_file``.
        """
        if load_model_weight:
            if cfg.model_weight_file != '':
                # Keep every storage on the CPU while deserializing.
                map_location = (lambda storage, loc: storage)
                sd = torch.load(cfg.model_weight_file,
                                map_location=map_location)
                load_state_dict(model, sd)
                print('Loaded model weights from {}'.format(
                    cfg.model_weight_file))
            else:
                load_ckpt(modules_optims, cfg.ckpt_file)

        use_local_distance = (cfg.l_loss_weight > 0) \
                             and cfg.local_dist_own_hard_sample

        for test_set, name in zip(test_sets, test_set_names):
            test_set.set_feat_func(ExtractFeature(model_w, TVT))
            print('\n=========> Test on dataset: {} <=========\n'.format(name))
            test_set.eval(normalize_feat=cfg.normalize_feature,
                          use_local_distance=use_local_distance)

    if cfg.only_test:
        test(load_model_weight=True)
        return

    ############
    # Training #
    ############

    # SIFT targets are only needed when the SIFT loss is active, so the
    # extractor is built once here instead of on every training step.
    sift_func = ExtractSift() if cfg.sift_loss_weight > 0 else None

    # One AverageMeter per logged statistic; recreated every epoch.
    meter_names = ('g_prec', 'g_m', 'g_dist_ap', 'g_dist_an', 'g_loss',
                   'l_prec', 'l_m', 'l_dist_ap', 'l_dist_an', 'l_loss',
                   'id_loss', 'sift_loss', 'loss')

    def format_log(head, meters, stat):
        """Return `head` followed by the statistics of every enabled loss.

        `meters` maps a statistic name to its AverageMeter. `stat` is the
        meter attribute to report: 'val' in the per-step log, 'avg' in the
        per-epoch log. Only meters of enabled losses are read.
        """
        def get(name):
            return getattr(meters[name], stat)

        log = head
        if cfg.g_loss_weight > 0:
            log += (', gp {:.2%}, gm {:.2%}, '
                    'gd_ap {:.4f}, gd_an {:.4f}, '
                    'gL {:.4f}'.format(get('g_prec'), get('g_m'),
                                       get('g_dist_ap'), get('g_dist_an'),
                                       get('g_loss')))
        if cfg.l_loss_weight > 0:
            log += (', lp {:.2%}, lm {:.2%}, '
                    'ld_ap {:.4f}, ld_an {:.4f}, '
                    'lL {:.4f}'.format(get('l_prec'), get('l_m'),
                                       get('l_dist_ap'), get('l_dist_an'),
                                       get('l_loss')))
        if cfg.id_loss_weight > 0:
            log += ', idL {:.4f}'.format(get('id_loss'))
        if cfg.sift_loss_weight > 0:
            log += ', sL {:.4f}'.format(get('sift_loss'))
        log += ', loss {:.4f}'.format(get('loss'))
        return log

    start_ep = resume_ep if cfg.resume else 0
    for ep in range(start_ep, cfg.total_epochs):

        # Adjust Learning Rate
        if cfg.lr_decay_type == 'exp':
            adjust_lr_exp(optimizer, cfg.base_lr, ep + 1, cfg.total_epochs,
                          cfg.exp_decay_at_epoch)
        else:
            adjust_lr_staircase(optimizer, cfg.base_lr, ep + 1,
                                cfg.staircase_decay_at_epochs,
                                cfg.staircase_decay_multiply_factor)

        may_set_mode(modules_optims, 'train')

        meters = {name: AverageMeter() for name in meter_names}

        ep_st = time.time()
        step = 0
        epoch_done = False
        while not epoch_done:

            step += 1
            step_st = time.time()

            ims, im_names, labels, cam_labels, mirrored, epoch_done = \
                train_set.next_batch()

            ims_var = Variable(TVT(torch.from_numpy(ims).float()))
            labels_t = TVT(torch.from_numpy(labels).long())
            labels_var = Variable(labels_t)

            feat, global_feat, local_feat, logits = model_w(ims_var)

            g_loss, p_inds, n_inds, g_dist_ap, g_dist_an, g_dist_mat = \
                global_loss(g_tri_loss, global_feat, labels_t,
                            normalize_feature=cfg.normalize_feature)

            if cfg.l_loss_weight == 0:
                l_loss = 0
            elif cfg.local_dist_own_hard_sample:
                # Let local distance find its own hard samples.
                l_loss, l_dist_ap, l_dist_an, _ = local_loss(
                    l_tri_loss, local_feat, None, None, labels_t,
                    normalize_feature=cfg.normalize_feature)
            else:
                # Reuse the p_inds / n_inds returned by the global loss.
                l_loss, l_dist_ap, l_dist_an = local_loss(
                    l_tri_loss, local_feat, p_inds, n_inds, labels_t,
                    normalize_feature=cfg.normalize_feature)

            id_loss = 0
            if cfg.id_loss_weight > 0:
                id_loss = id_criterion(logits, labels_var)

            sift_loss = 0
            if cfg.sift_loss_weight > 0:
                # Move the SIFT target with TVT, like every other batch
                # tensor, rather than a hard-coded .cuda().
                sift = TVT(torch.from_numpy(sift_func(ims_var)))
                sift_loss = torch.norm(
                    F.softmax(global_feat, dim=1) - F.softmax(sift, dim=1))

            loss = g_loss * cfg.g_loss_weight \
                + l_loss * cfg.l_loss_weight \
                + id_loss * cfg.id_loss_weight \
                + sift_loss * cfg.sift_loss_weight

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ############
            # Step Log #
            ############

            # precision
            g_prec = (g_dist_an > g_dist_ap).data.float().mean()
            # the proportion of triplets that satisfy margin
            g_m = (g_dist_an >
                   g_dist_ap + cfg.global_margin).data.float().mean()
            meters['g_prec'].update(g_prec)
            meters['g_m'].update(g_m)
            meters['g_dist_ap'].update(g_dist_ap.data.mean())
            meters['g_dist_an'].update(g_dist_an.data.mean())
            meters['g_loss'].update(to_scalar(g_loss))

            if cfg.l_loss_weight > 0:
                # precision
                l_prec = (l_dist_an > l_dist_ap).data.float().mean()
                # the proportion of triplets that satisfy margin
                l_m = (l_dist_an >
                       l_dist_ap + cfg.local_margin).data.float().mean()
                meters['l_prec'].update(l_prec)
                meters['l_m'].update(l_m)
                meters['l_dist_ap'].update(l_dist_ap.data.mean())
                meters['l_dist_an'].update(l_dist_an.data.mean())
                meters['l_loss'].update(to_scalar(l_loss))

            if cfg.id_loss_weight > 0:
                meters['id_loss'].update(to_scalar(id_loss))

            if cfg.sift_loss_weight > 0:
                meters['sift_loss'].update(to_scalar(sift_loss))

            meters['loss'].update(to_scalar(loss))

            if step % cfg.log_steps == 0:
                time_log = '\tStep {}/Ep {}, {:.2f}s'.format(
                    step, ep + 1, time.time() - step_st)
                print(format_log(time_log, meters, 'val'))

        #############
        # Epoch Log #
        #############

        time_log = 'Ep {}, {:.2f}s'.format(ep + 1, time.time() - ep_st)
        print(format_log(time_log, meters, 'avg'))

        # Log to TensorBoard

        if cfg.log_to_file:
            if writer is None:
                writer = SummaryWriter(
                    log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
            writer.add_scalars(
                'loss',
                dict(
                    global_loss=meters['g_loss'].avg,
                    local_loss=meters['l_loss'].avg,
                    id_loss=meters['id_loss'].avg,
                    sift_loss=meters['sift_loss'].avg,
                    loss=meters['loss'].avg,
                ), ep)
            writer.add_scalars(
                'tri_precision',
                dict(
                    global_precision=meters['g_prec'].avg,
                    local_precision=meters['l_prec'].avg,
                ), ep)
            writer.add_scalars(
                'satisfy_margin',
                dict(
                    global_satisfy_margin=meters['g_m'].avg,
                    local_satisfy_margin=meters['l_m'].avg,
                ), ep)
            writer.add_scalars(
                'global_dist',
                dict(
                    global_dist_ap=meters['g_dist_ap'].avg,
                    global_dist_an=meters['g_dist_an'].avg,
                ), ep)
            writer.add_scalars(
                'local_dist',
                dict(
                    local_dist_ap=meters['l_dist_ap'].avg,
                    local_dist_an=meters['l_dist_an'].avg,
                ), ep)

        # save ckpt
        if cfg.log_to_file:
            save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

    ########
    # Test #
    ########

    test(load_model_weight=False)
Example #3
0
def main():
  """Train the multi-loss re-ID model configured by ``Config()``, then test.

  Besides the global and local triplet losses, this variant can add:
  a label-smoothed identity loss on ``logits1``; a camera ("view")
  classification loss with label-smoothed targets; a center loss on the
  normalized global feature; an alignment loss between the two part-feature
  maps; and a SIFT-consistency loss. Each term is scaled by its
  ``cfg.*_weight``. When ``cfg.only_test`` is set, it only loads weights and
  evaluates.

  When ``cfg.log_to_file`` is set, stdout/stderr are mirrored to files,
  epoch statistics go to TensorBoard, and a checkpoint is saved every epoch.
  """
  cfg = Config()

  # Redirect logs to both console and file.
  if cfg.log_to_file:
    ReDirectSTD(cfg.stdout_file, 'stdout', False)
    ReDirectSTD(cfg.stderr_file, 'stderr', False)

  # Lazily create SummaryWriter
  writer = None

  TVT, TMO = set_devices(cfg.sys_device_ids)

  if cfg.seed is not None:
    set_seed(cfg.seed)

  # Dump the configurations to log.
  import pprint
  print('-' * 60)
  print('cfg.__dict__')
  pprint.pprint(cfg.__dict__)
  print('-' * 60)

  ###########
  # Dataset #
  ###########

  train_set = create_dataset(**cfg.train_set_kwargs)

  test_sets = []
  test_set_names = []
  if cfg.dataset == 'combined':
    for name in ['market1501', 'cuhk03', 'duke']:
      cfg.test_set_kwargs['name'] = name
      test_sets.append(create_dataset(**cfg.test_set_kwargs))
      test_set_names.append(name)
  else:
    test_sets.append(create_dataset(**cfg.test_set_kwargs))
    test_set_names.append(cfg.dataset)

  ###########
  # Models  #
  ###########

  # Number of camera classes for the view head. Every dataset other than
  # market1501 and cuhk03 (including 'combined') uses 8.
  if cfg.dataset == 'market1501':
    cams = 6
  elif cfg.dataset == 'cuhk03':
    cams = 2
  else:
    cams = 8

  # Number of training identities (classes of the ID head).
  ids = len(train_set.ids2labels)

  model = Model(local_conv_out_channels=cfg.local_conv_out_channels,
                num_classes=ids, cam_classes=cams)
  # Model wrapper
  model_w = DataParallel(model)

  #############################
  # Criteria and Optimizers   #
  #############################

  view_criterion = SoftmaxEntropyLoss()
  id_criterion = SoftmaxEntropyLoss()
  g_tri_loss = TripletLoss(margin=cfg.global_margin)
  l_tri_loss = TripletLoss(margin=cfg.local_margin)
  # NOTE(review): feat_dim=2048 presumably equals the width of global_feat,
  # and use_gpu=True does not follow cfg.sys_device_ids -- confirm.
  center_loss = CenterLoss(num_classes=ids, feat_dim=2048, use_gpu=True)

  optimizer = optim.Adam(model.parameters(),
                         lr=cfg.base_lr,
                         weight_decay=cfg.weight_decay)
  optimizer_centloss = torch.optim.SGD(center_loss.parameters(), lr=0.001)

  # Bind them together just to save some codes in the following usage.
  # NOTE(review): center_loss / optimizer_centloss are not in this list, so
  # they are not checkpointed, resumed or transferred by TMO.
  modules_optims = [model, optimizer]

  ################################
  # May Resume Models and Optims #
  ################################

  if cfg.resume:
    resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

  # May Transfer Models and Optims to Specified Device. Transferring optimizer
  # is to cope with the case when you load the checkpoint to a new device.
  TMO(modules_optims)

  ########
  # Test #
  ########

  def test(load_model_weight=False):
    """Evaluate the current model on every test set.

    With ``load_model_weight``, weights are restored first: from
    ``cfg.model_weight_file`` if it is non-empty, otherwise from the
    checkpoint ``cfg.ckpt_file``.
    """
    if load_model_weight:
      if cfg.model_weight_file != '':
        # Keep every storage on the CPU while deserializing.
        map_location = (lambda storage, loc: storage)
        sd = torch.load(cfg.model_weight_file, map_location=map_location)
        load_state_dict(model, sd)
        print('Loaded model weights from {}'.format(cfg.model_weight_file))
      else:
        load_ckpt(modules_optims, cfg.ckpt_file)

    use_local_distance = (cfg.l_loss_weight > 0) \
                         and cfg.local_dist_own_hard_sample

    for test_set, name in zip(test_sets, test_set_names):
      test_set.set_feat_func(ExtractFeature(model_w, TVT))
      print('\n=========> Test on dataset: {} <=========\n'.format(name))
      test_set.eval(
        normalize_feat=cfg.normalize_feature,
        use_local_distance=use_local_distance)

  if cfg.only_test:
    test(load_model_weight=True)
    return

  ############
  # Training #
  ############

  # SIFT targets are only needed when the SIFT loss is active, so the
  # extractor is built once here instead of on every training step.
  sift_func = ExtractSift() if cfg.sift_loss_weight > 0 else None

  def smooth_one_hot(label_array, num_classes):
    """Turn a numpy array of integer labels into label-smoothed targets.

    Returns a float tensor of shape (len(label_array), num_classes), moved
    with TVT. It holds 0.8 at each label and 0.2 / (num_classes - 1)
    everywhere else, so every row sums to 1. The row count comes from the
    labels themselves rather than from ids_per_batch * ims_per_id, so a
    batch of any size is handled.
    """
    index = torch.from_numpy(label_array).long().view(-1, 1)
    targets = torch.zeros(index.size(0), num_classes)
    targets.fill_(0.2 / (num_classes - 1))
    targets.scatter_(1, index, 0.8)
    return TVT(targets)

  # One AverageMeter per logged statistic; recreated every epoch.
  meter_names = ('g_prec', 'g_m', 'g_dist_ap', 'g_dist_an', 'g_loss',
                 'l_prec', 'l_m', 'l_dist_ap', 'l_dist_an', 'l_loss',
                 'id_loss', 'sift_loss', 'c_loss', 'a_loss', 'view_loss',
                 'loss')

  def format_log(head, meters, stat):
    """Return `head` followed by the statistics of every enabled loss.

    `meters` maps a statistic name to its AverageMeter. `stat` is the meter
    attribute to report: 'val' in the per-step log, 'avg' in the per-epoch
    log. Only meters of enabled losses are read.
    """
    def get(name):
      return getattr(meters[name], stat)

    log = head
    if cfg.g_loss_weight > 0:
      log += (', gp {:.2%}, gm {:.2%}, '
              'gd_ap {:.4f}, gd_an {:.4f}, '
              'gL {:.4f}'.format(get('g_prec'), get('g_m'),
                                 get('g_dist_ap'), get('g_dist_an'),
                                 get('g_loss')))
    if cfg.l_loss_weight > 0:
      log += (', lp {:.2%}, lm {:.2%}, '
              'ld_ap {:.4f}, ld_an {:.4f}, '
              'lL {:.4f}'.format(get('l_prec'), get('l_m'),
                                 get('l_dist_ap'), get('l_dist_an'),
                                 get('l_loss')))
    if cfg.id_loss_weight > 0:
      log += ', idL {:.4f}'.format(get('id_loss'))
    if cfg.align_loss_weight > 0:
      log += ', aL {:.4f}'.format(get('a_loss'))
    if cfg.sift_loss_weight > 0:
      log += ', sL {:.4f}'.format(get('sift_loss'))
    if cfg.c_loss_weight > 0:
      log += ', cL {:.4f}'.format(get('c_loss'))
    if cfg.view_loss_weight > 0:
      log += ', viewL {:.4f}'.format(get('view_loss'))
    log += ', loss {:.4f}'.format(get('loss'))
    return log

  start_ep = resume_ep if cfg.resume else 0
  for ep in range(start_ep, cfg.total_epochs):

    # Adjust Learning Rate. Both optimizers follow the same schedule driven
    # by cfg.base_lr (presumably replacing the SGD optimizer's initial lr).
    for opt in (optimizer, optimizer_centloss):
      if cfg.lr_decay_type == 'exp':
        adjust_lr_exp(opt, cfg.base_lr, ep + 1, cfg.total_epochs,
                      cfg.exp_decay_at_epoch)
      else:
        adjust_lr_staircase(opt, cfg.base_lr, ep + 1,
                            cfg.staircase_decay_at_epochs,
                            cfg.staircase_decay_multiply_factor)

    may_set_mode(modules_optims, 'train')

    meters = {name: AverageMeter() for name in meter_names}

    ep_st = time.time()
    step = 0
    epoch_done = False
    while not epoch_done:

      step += 1
      step_st = time.time()

      ims, im_names, labels, cam_labels, mirrored, epoch_done = \
        train_set.next_batch()

      ims_var = Variable(TVT(torch.from_numpy(ims).float()))
      labels_t = TVT(torch.from_numpy(labels).long())
      # Hard labels feed the center loss; smoothed targets feed the ID loss.
      labels_var1 = Variable(labels_t)
      labels_var = smooth_one_hot(labels, ids)

      # Camera labels are shifted down by one except on cuhk03 (presumably
      # the other datasets number their cameras from 1).
      if cfg.dataset != 'cuhk03':
        cam_labels = cam_labels - 1
      cam_labels_var = smooth_one_hot(cam_labels, cams)

      (feat, feat_part1, feat_part2, global_feat, local_feat, feature,
       logits, logits1, view_logits) = model_w(ims_var)

      g_loss, p_inds, n_inds, g_dist_ap, g_dist_an, g_dist_mat = global_loss(
        g_tri_loss, global_feat, labels_t,
        normalize_feature=cfg.normalize_feature)

      if cfg.l_loss_weight == 0:
        l_loss = 0
      elif cfg.local_dist_own_hard_sample:
        # Let local distance find its own hard samples.
        l_loss, l_dist_ap, l_dist_an, _ = local_loss(
          l_tri_loss, local_feat, None, None, labels_t,
          normalize_feature=cfg.normalize_feature)
      else:
        # Reuse the p_inds / n_inds returned by the global loss.
        l_loss, l_dist_ap, l_dist_an = local_loss(
          l_tri_loss, local_feat, p_inds, n_inds, labels_t,
          normalize_feature=cfg.normalize_feature)

      id_loss = 0
      if cfg.id_loss_weight > 0:
        id_loss = id_criterion(logits1, labels_var)

      a_loss = 0
      if cfg.align_loss_weight > 0:
        # Mean over dim 1 of each part feature, flattened per sample; the
        # loss pulls the two maps' sigmoid activations together.
        map1 = torch.mean(feat_part1, 1).view(feat_part1.size(0), -1)
        map2 = torch.mean(feat_part2, 1).view(feat_part2.size(0), -1)
        a_loss = torch.norm(torch.sigmoid(map1) - torch.sigmoid(map2))

      sift_loss = 0
      if cfg.sift_loss_weight > 0:
        # Move the SIFT target with TVT rather than a hard-coded .cuda().
        sift = TVT(torch.from_numpy(sift_func(ims_var)))
        sift_loss = torch.norm(
          F.softmax(global_feat, dim=1) - F.softmax(sift, dim=1))

      c_loss = 0
      if cfg.c_loss_weight > 0:
        c_loss = center_loss(normalize(global_feat, axis=-1), labels_var1)

      view_loss = 0
      if cfg.view_loss_weight > 0:
        view_loss = view_criterion(view_logits, cam_labels_var)

      loss = g_loss * cfg.g_loss_weight \
             + l_loss * cfg.l_loss_weight \
             + id_loss * cfg.id_loss_weight \
             + sift_loss * cfg.sift_loss_weight \
             + c_loss * cfg.c_loss_weight \
             + view_loss * cfg.view_loss_weight \
             + a_loss * cfg.align_loss_weight

      optimizer.zero_grad()
      optimizer_centloss.zero_grad()
      loss.backward()
      optimizer.step()
      if cfg.c_loss_weight > 0:
        # The centers only receive gradient through c_loss * c_loss_weight;
        # dividing the weight back out updates them with the unweighted
        # center-loss gradient. The guard matters: with the center loss
        # disabled the centers have no gradient (and the weight is 0).
        for param in center_loss.parameters():
          param.grad.data *= (1.0 / cfg.c_loss_weight)
        optimizer_centloss.step()

      ############
      # Step Log #
      ############

      # precision
      g_prec = (g_dist_an > g_dist_ap).data.float().mean()
      # the proportion of triplets that satisfy margin
      g_m = (g_dist_an > g_dist_ap + cfg.global_margin).data.float().mean()
      meters['g_prec'].update(g_prec)
      meters['g_m'].update(g_m)
      meters['g_dist_ap'].update(g_dist_ap.data.mean())
      meters['g_dist_an'].update(g_dist_an.data.mean())
      meters['g_loss'].update(to_scalar(g_loss))

      if cfg.l_loss_weight > 0:
        # precision
        l_prec = (l_dist_an > l_dist_ap).data.float().mean()
        # the proportion of triplets that satisfy margin
        l_m = (l_dist_an > l_dist_ap + cfg.local_margin).data.float().mean()
        meters['l_prec'].update(l_prec)
        meters['l_m'].update(l_m)
        meters['l_dist_ap'].update(l_dist_ap.data.mean())
        meters['l_dist_an'].update(l_dist_an.data.mean())
        meters['l_loss'].update(to_scalar(l_loss))

      if cfg.id_loss_weight > 0:
        meters['id_loss'].update(to_scalar(id_loss))

      if cfg.sift_loss_weight > 0:
        meters['sift_loss'].update(to_scalar(sift_loss))

      if cfg.c_loss_weight > 0:
        meters['c_loss'].update(to_scalar(c_loss))

      if cfg.view_loss_weight > 0:
        meters['view_loss'].update(to_scalar(view_loss))

      if cfg.align_loss_weight > 0:
        meters['a_loss'].update(to_scalar(a_loss))

      meters['loss'].update(to_scalar(loss))

      if step % cfg.log_steps == 0:
        time_log = '\tStep {}/Ep {}, {:.2f}s'.format(
          step, ep + 1, time.time() - step_st)
        print(format_log(time_log, meters, 'val'))

    #############
    # Epoch Log #
    #############

    time_log = 'Ep {}, {:.2f}s'.format(ep + 1, time.time() - ep_st)
    print(format_log(time_log, meters, 'avg'))

    # Log to TensorBoard

    if cfg.log_to_file:
      if writer is None:
        writer = SummaryWriter(log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
      writer.add_scalars(
        'loss',
        dict(global_loss=meters['g_loss'].avg,
             local_loss=meters['l_loss'].avg,
             id_loss=meters['id_loss'].avg,
             sift_loss=meters['sift_loss'].avg,
             c_loss=meters['c_loss'].avg,
             view_loss=meters['view_loss'].avg,
             a_loss=meters['a_loss'].avg,
             loss=meters['loss'].avg, ),
        ep)
      writer.add_scalars(
        'tri_precision',
        dict(global_precision=meters['g_prec'].avg,
             local_precision=meters['l_prec'].avg, ),
        ep)
      writer.add_scalars(
        'satisfy_margin',
        dict(global_satisfy_margin=meters['g_m'].avg,
             local_satisfy_margin=meters['l_m'].avg, ),
        ep)
      writer.add_scalars(
        'global_dist',
        dict(global_dist_ap=meters['g_dist_ap'].avg,
             global_dist_an=meters['g_dist_an'].avg, ),
        ep)
      writer.add_scalars(
        'local_dist',
        dict(local_dist_ap=meters['l_dist_ap'].avg,
             local_dist_an=meters['l_dist_an'].avg, ),
        ep)

    # save ckpt
    if cfg.log_to_file:
      save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

  ########
  # Test #
  ########

  test(load_model_weight=False)