コード例 #1
0
ファイル: train.py プロジェクト: fountaindream/VCFL
  def test(load_model_weight=False):
    """Evaluate the model on every configured test set.

    Args:
      load_model_weight: when true, weights are loaded before evaluating --
        from ``cfg.model_weight_file`` if that path is non-empty, otherwise
        from the checkpoint ``cfg.ckpt_file`` via ``load_ckpt``.

    Relies on names from the enclosing scope: ``cfg``, ``model``,
    ``model_w``, ``modules_optims``, ``TVT``, ``test_sets`` and
    ``test_set_names``.
    """
    if load_model_weight:
      weight_path = cfg.model_weight_file
      if weight_path == '':
        # No standalone weight file configured: fall back to the checkpoint.
        load_ckpt(modules_optims, cfg.ckpt_file)
      else:
        # Keep every storage where it was deserialized (no device remapping).
        state = torch.load(weight_path,
                           map_location=lambda storage, loc: storage)
        load_state_dict(model, state)
        print('Loaded model weights from {}'.format(weight_path))

    # Local distance is used at test time only when the local loss is active
    # and mines its own hard samples.
    use_local_distance = (cfg.l_loss_weight > 0) \
                         and cfg.local_dist_own_hard_sample

    for dataset, dataset_name in zip(test_sets, test_set_names):
      extractor = ExtractFeature(model_w, TVT)
      dataset.set_feat_func(extractor)
      banner = '\n=========> Test on dataset: {} <=========\n'
      print(banner.format(dataset_name))
      dataset.eval(
        normalize_feat=cfg.normalize_feature,
        use_local_distance=use_local_distance)
コード例 #2
0
ファイル: train_whole.py プロジェクト: fountaindream/VCFL
def main():
    """Train the generator/discriminator re-ID model, then evaluate it.

    Everything is driven by a ``Config()`` instance:

    * builds the train set and one or more test sets;
    * builds ``Model``, wraps it in ``Generator`` and pairs it with a
      ``Discriminator`` that classifies global features into 4 view classes;
    * optionally resumes from ``cfg.ckpt_file``;
    * if ``cfg.only_test`` is set, evaluates and returns;
    * otherwise runs ``cfg.total_epochs`` epochs, combining the global/local
      triplet, identity, SIFT, centre and view-adversarial losses according
      to their ``cfg.*_loss_weight`` values, logging per step and per epoch,
      and (when ``cfg.log_to_file`` is set) writing TensorBoard scalars and a
      checkpoint after every epoch;
    * finally evaluates on the test sets.
    """
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    # Lazily create SummaryWriter
    writer = None

    TVT, TMO = set_devices(cfg.sys_device_ids)

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    ###########
    # Dataset #
    ###########

    train_set = create_dataset(**cfg.train_set_kwargs)

    test_sets = []
    test_set_names = []
    if cfg.dataset == 'combined':
        for name in ['market1501', 'cuhk03', 'duke']:
            cfg.test_set_kwargs['name'] = name
            test_sets.append(create_dataset(**cfg.test_set_kwargs))
            test_set_names.append(name)
    else:
        test_sets.append(create_dataset(**cfg.test_set_kwargs))
        test_set_names.append(cfg.dataset)

    ###########
    # Models  #
    ###########

    # The bare model is not wrapped in DataParallel itself; it is wrapped
    # through the Generator below.
    model = Model(local_conv_out_channels=cfg.local_conv_out_channels,
                  num_classes=len(train_set.ids2labels))

    #############################
    # Criteria and Optimizers   #
    #############################

    id_criterion = SoftmaxEntropyLoss()
    id_criterion1 = nn.CrossEntropyLoss()
    g_tri_loss = TripletLoss(margin=cfg.global_margin)
    l_tri_loss = TripletLoss(margin=cfg.local_margin)
    center_loss = CenterLoss(num_classes=len(train_set.ids2labels),
                             feat_dim=2048,
                             use_gpu=True)

    # Discriminator (view classifier) hyper-parameters.
    d_input_size = 2048
    d_hidden_size = 128
    d_output_size = 4  # number of view classes
    d_learning_rate = 1e-4
    sgd_momentum = 0.9

    # Number of discriminator updates per training step.
    d_steps = 2

    discriminator_activation_function = nn.Sigmoid()

    G = Generator(model)
    D = Discriminator(input_size=d_input_size,
                      hidden_size=d_hidden_size,
                      output_size=d_output_size,
                      f=discriminator_activation_function)
    G_w = DataParallel(G)
    D_w = DataParallel(D)

    g_optimizer = optim.Adam(G.parameters(),
                             lr=cfg.base_lr,
                             weight_decay=cfg.weight_decay)
    d_optimizer = optim.SGD(D.parameters(),
                            lr=d_learning_rate,
                            momentum=sgd_momentum)

    optimizer_centloss = torch.optim.SGD(center_loss.parameters(), lr=0.001)

    # Bind them together just to save some codes in the following usage.
    modules_optims = [model, g_optimizer]

    ################################
    # May Resume Models and Optims #
    ################################

    if cfg.resume:
        resume_ep, _ = load_ckpt(modules_optims, cfg.ckpt_file)

    # May Transfer Models and Optims to Specified Device. Transferring optimizer
    # is to cope with the case when you load the checkpoint to a new device.
    TMO(modules_optims)

    ########
    # Test #
    ########

    def test(load_model_weight=False):
        """Evaluate on all test sets, optionally loading weights first.

        With ``load_model_weight`` set, weights come from
        ``cfg.model_weight_file`` when that path is non-empty, otherwise from
        the checkpoint ``cfg.ckpt_file``.
        """
        if load_model_weight:
            if cfg.model_weight_file != '':
                map_location = (lambda storage, loc: storage)
                sd = torch.load(cfg.model_weight_file,
                                map_location=map_location)
                load_state_dict(model, sd)
                print('Loaded model weights from {}'.format(
                    cfg.model_weight_file))
            else:
                load_ckpt(modules_optims, cfg.ckpt_file)

        use_local_distance = (cfg.l_loss_weight > 0) \
                             and cfg.local_dist_own_hard_sample

        for test_set, name in zip(test_sets, test_set_names):
            test_set.set_feat_func(ExtractFeature(G_w, TVT))
            print('\n=========> Test on dataset: {} <=========\n'.format(name))
            test_set.eval(normalize_feat=cfg.normalize_feature,
                          use_local_distance=use_local_distance)

    if cfg.only_test:
        test(load_model_weight=True)
        return

    ############
    # Training #
    ############

    start_ep = resume_ep if cfg.resume else 0
    for ep in range(start_ep, cfg.total_epochs):

        # Adjust Learning Rate (same schedule for both optimizers).
        # NOTE(review): this schedules optimizer_centloss from cfg.base_lr,
        # not from the 0.001 it was created with -- confirm that is intended.
        if cfg.lr_decay_type == 'exp':
            adjust_lr_exp(g_optimizer, cfg.base_lr, ep + 1, cfg.total_epochs,
                          cfg.exp_decay_at_epoch)
            adjust_lr_exp(optimizer_centloss, cfg.base_lr, ep + 1,
                          cfg.total_epochs, cfg.exp_decay_at_epoch)
        else:
            adjust_lr_staircase(g_optimizer, cfg.base_lr, ep + 1,
                                cfg.staircase_decay_at_epochs,
                                cfg.staircase_decay_multiply_factor)
            adjust_lr_staircase(optimizer_centloss, cfg.base_lr, ep + 1,
                                cfg.staircase_decay_at_epochs,
                                cfg.staircase_decay_multiply_factor)

        may_set_mode(modules_optims, 'train')

        g_prec_meter = AverageMeter()
        g_m_meter = AverageMeter()
        g_dist_ap_meter = AverageMeter()
        g_dist_an_meter = AverageMeter()
        g_loss_meter = AverageMeter()

        l_prec_meter = AverageMeter()
        l_m_meter = AverageMeter()
        l_dist_ap_meter = AverageMeter()
        l_dist_an_meter = AverageMeter()
        l_loss_meter = AverageMeter()

        id_loss_meter = AverageMeter()

        sift_loss_meter = AverageMeter()
        c_loss_meter = AverageMeter()

        d_err_meter = AverageMeter()
        g_err_meter = AverageMeter()

        loss_meter = AverageMeter()

        ep_st = time.time()
        step = 0
        epoch_done = False
        while not epoch_done:

            step += 1
            step_st = time.time()

            ims, im_names, labels, mirrored, epoch_done = \
                train_set.next_batch()

            ims_var = Variable(TVT(torch.from_numpy(ims).float()))
            labels_t = TVT(torch.from_numpy(labels).long())
            labels_var = Variable(labels_t)

            global_feat, local_feat, logits = G_w(ims_var)

            g_loss, p_inds, n_inds, g_dist_ap, g_dist_an, g_dist_mat = global_loss(
                g_tri_loss,
                global_feat,
                labels_t,
                normalize_feature=cfg.normalize_feature)

            if cfg.l_loss_weight == 0:
                l_loss = 0
            elif cfg.local_dist_own_hard_sample:
                # Let local distance find its own hard samples.
                l_loss, l_dist_ap, l_dist_an, _ = local_loss(
                    l_tri_loss,
                    local_feat,
                    None,
                    None,
                    labels_t,
                    normalize_feature=cfg.normalize_feature)
            else:
                # Reuse the hard samples mined by the global distance.
                l_loss, l_dist_ap, l_dist_an = local_loss(
                    l_tri_loss,
                    local_feat,
                    p_inds,
                    n_inds,
                    labels_t,
                    normalize_feature=cfg.normalize_feature)

            id_loss = 0
            if cfg.id_loss_weight > 0:
                id_loss = id_criterion1(logits, labels_var)

            sift_loss = 0
            if cfg.sift_loss_weight > 0:
                # SIFT descriptors feed only this term, so they are extracted
                # only when the term is enabled.
                sift = torch.from_numpy(ExtractSift()(ims_var)).cuda()
                sift_loss = torch.norm(
                    normalize(global_feat, axis=-1) - normalize(sift, axis=-1))

            c_loss = 0
            if cfg.c_loss_weight > 0:
                c_loss = center_loss(normalize(global_feat, axis=-1),
                                     labels_var)

            ###################
            # View classifier #
            ###################

            g_error = 0
            if cfg.dg_loss_weight > 0:
                # Size the targets from the actual batch rather than a fixed
                # constant so they always match the discriminator output.
                batch_size = ims.shape[0]
                # "Real" target: one-hot view label, cycling 0,1,2,3 along
                # the batch (sample m gets view m % 4).
                # NOTE(review): presumably the sampler orders each identity's
                # images by view in groups of 4 -- confirm against the dataset.
                real_label = np.eye(4, dtype='float32')[
                    np.arange(batch_size) % 4]
                # "Fake" target: uniform distribution over the 4 views.
                fake_label = np.full((batch_size, 4), 0.25, dtype='float32')
                view_real_label = torch.from_numpy(real_label).cuda()
                view_fake_label = torch.from_numpy(fake_label).cuda()

                for d_index in range(d_steps):
                    # 1. Train D to predict the true view of each feature.
                    d_optimizer.zero_grad()

                    d_data = global_feat
                    d_decision = D_w(d_data)

                    d_error = id_criterion(d_decision, view_real_label)
                    # retain_graph: the generator graph behind global_feat is
                    # backpropagated through again below.
                    d_error.backward(retain_graph=True)
                    # Only D's parameters are updated here.
                    d_optimizer.step()

                # 2. Train G on D's response (but DO NOT train D on these
                # labels): push D's last decision towards the uniform target.
                # NOTE(review): d_decision was computed before the final
                # d_optimizer.step(); newer PyTorch versions may reject the
                # later backward through in-place-updated weights -- verify.
                g_error = id_criterion(d_decision, view_fake_label)

            loss = g_loss * cfg.g_loss_weight \
                   + l_loss * cfg.l_loss_weight \
                   + id_loss * cfg.id_loss_weight \
                   + sift_loss * cfg.sift_loss_weight \
                   + c_loss * cfg.c_loss_weight \
                   + g_error * cfg.dg_loss_weight

            g_optimizer.zero_grad()
            optimizer_centloss.zero_grad()
            loss.backward(retain_graph=True)
            g_optimizer.step()
            if cfg.c_loss_weight > 0:
                # Undo the c_loss_weight factor on the centre gradients so the
                # centre update does not scale with the loss weight. Skipped
                # when the centre loss is disabled (no gradients exist and the
                # division would be by zero).
                for param in center_loss.parameters():
                    param.grad.data *= (1 / cfg.c_loss_weight)
                optimizer_centloss.step()

            ############
            # Step Log #
            ############

            # precision
            g_prec = (g_dist_an > g_dist_ap).data.float().mean()
            # the proportion of triplets that satisfy margin
            g_m = (g_dist_an >
                   g_dist_ap + cfg.global_margin).data.float().mean()
            g_d_ap = g_dist_ap.data.mean()
            g_d_an = g_dist_an.data.mean()

            g_prec_meter.update(g_prec)
            g_m_meter.update(g_m)
            g_dist_ap_meter.update(g_d_ap)
            g_dist_an_meter.update(g_d_an)
            g_loss_meter.update(to_scalar(g_loss))

            if cfg.l_loss_weight > 0:
                # precision
                l_prec = (l_dist_an > l_dist_ap).data.float().mean()
                # the proportion of triplets that satisfy margin
                l_m = (l_dist_an >
                       l_dist_ap + cfg.local_margin).data.float().mean()
                l_d_ap = l_dist_ap.data.mean()
                l_d_an = l_dist_an.data.mean()

                l_prec_meter.update(l_prec)
                l_m_meter.update(l_m)
                l_dist_ap_meter.update(l_d_ap)
                l_dist_an_meter.update(l_d_an)
                l_loss_meter.update(to_scalar(l_loss))

            if cfg.id_loss_weight > 0:
                id_loss_meter.update(to_scalar(id_loss))

            if cfg.sift_loss_weight > 0:
                sift_loss_meter.update(to_scalar(sift_loss))

            if cfg.c_loss_weight > 0:
                c_loss_meter.update(to_scalar(c_loss))

            if cfg.dg_loss_weight > 0:
                d_err_meter.update(to_scalar(d_error))
                g_err_meter.update(to_scalar(g_error))

            loss_meter.update(to_scalar(loss))

            if step % cfg.log_steps == 0:
                time_log = '\tStep {}/Ep {}, {:.2f}s'.format(
                    step,
                    ep + 1,
                    time.time() - step_st,
                )

                if cfg.g_loss_weight > 0:
                    g_log = (', gp {:.2%}, gm {:.2%}, '
                             'gd_ap {:.4f}, gd_an {:.4f}, '
                             'gL {:.4f}'.format(
                                 g_prec_meter.val,
                                 g_m_meter.val,
                                 g_dist_ap_meter.val,
                                 g_dist_an_meter.val,
                                 g_loss_meter.val,
                             ))
                else:
                    g_log = ''

                if cfg.l_loss_weight > 0:
                    l_log = (', lp {:.2%}, lm {:.2%}, '
                             'ld_ap {:.4f}, ld_an {:.4f}, '
                             'lL {:.4f}'.format(
                                 l_prec_meter.val,
                                 l_m_meter.val,
                                 l_dist_ap_meter.val,
                                 l_dist_an_meter.val,
                                 l_loss_meter.val,
                             ))
                else:
                    l_log = ''

                if cfg.id_loss_weight > 0:
                    id_log = (', idL {:.4f}'.format(id_loss_meter.val))
                else:
                    id_log = ''

                if cfg.sift_loss_weight > 0:
                    sift_log = (', sL {:.4f}'.format(sift_loss_meter.val))
                else:
                    sift_log = ''

                if cfg.c_loss_weight > 0:
                    c_log = (', cL {:.4f}'.format(c_loss_meter.val))
                else:
                    c_log = ''

                # Gated on the adversarial loss weight, matching the epoch
                # log and the meter updates above.
                if cfg.dg_loss_weight > 0:
                    d_g_log = (', d_err {:.4f}, g_err {:.4f}'.format(
                        d_err_meter.val, g_err_meter.val))
                else:
                    d_g_log = ''

                total_loss_log = ', loss {:.4f}'.format(loss_meter.val)

                log = time_log + \
                      g_log + l_log + id_log + \
                      d_g_log + sift_log + c_log + \
                      total_loss_log
                print(log)

        #############
        # Epoch Log #
        #############

        time_log = 'Ep {}, {:.2f}s'.format(
            ep + 1,
            time.time() - ep_st,
        )

        if cfg.g_loss_weight > 0:
            g_log = (', gp {:.2%}, gm {:.2%}, '
                     'gd_ap {:.4f}, gd_an {:.4f}, '
                     'gL {:.4f}'.format(
                         g_prec_meter.avg,
                         g_m_meter.avg,
                         g_dist_ap_meter.avg,
                         g_dist_an_meter.avg,
                         g_loss_meter.avg,
                     ))
        else:
            g_log = ''

        if cfg.l_loss_weight > 0:
            l_log = (', lp {:.2%}, lm {:.2%}, '
                     'ld_ap {:.4f}, ld_an {:.4f}, '
                     'lL {:.4f}'.format(
                         l_prec_meter.avg,
                         l_m_meter.avg,
                         l_dist_ap_meter.avg,
                         l_dist_an_meter.avg,
                         l_loss_meter.avg,
                     ))
        else:
            l_log = ''

        if cfg.id_loss_weight > 0:
            id_log = (', idL {:.4f}'.format(id_loss_meter.avg))
        else:
            id_log = ''

        if cfg.sift_loss_weight > 0:
            sift_log = (', sL {:.4f}'.format(sift_loss_meter.avg))
        else:
            sift_log = ''

        if cfg.c_loss_weight > 0:
            c_log = (', cL {:.4f}'.format(c_loss_meter.avg))
        else:
            c_log = ''

        if cfg.dg_loss_weight > 0:
            d_g_log = (', d_err {:.4f}, g_err {:.4f}'.format(
                d_err_meter.avg, g_err_meter.avg))
        else:
            d_g_log = ''

        total_loss_log = ', loss {:.4f}'.format(loss_meter.avg)

        log = time_log + \
              g_log + l_log + id_log + \
              d_g_log + sift_log + c_log + \
              total_loss_log
        print(log)

        # Log to TensorBoard

        if cfg.log_to_file:
            if writer is None:
                writer = SummaryWriter(
                    log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
            writer.add_scalars(
                'loss',
                dict(
                    global_loss=g_loss_meter.avg,
                    local_loss=l_loss_meter.avg,
                    id_loss=id_loss_meter.avg,
                    d_err=d_err_meter.avg,
                    g_err=g_err_meter.avg,
                    sift_loss=sift_loss_meter.avg,
                    c_loss=c_loss_meter.avg,
                    loss=loss_meter.avg,
                ), ep)
            writer.add_scalars(
                'tri_precision',
                dict(
                    global_precision=g_prec_meter.avg,
                    local_precision=l_prec_meter.avg,
                ), ep)
            writer.add_scalars(
                'satisfy_margin',
                dict(
                    global_satisfy_margin=g_m_meter.avg,
                    local_satisfy_margin=l_m_meter.avg,
                ), ep)
            writer.add_scalars(
                'global_dist',
                dict(
                    global_dist_ap=g_dist_ap_meter.avg,
                    global_dist_an=g_dist_an_meter.avg,
                ), ep)
            writer.add_scalars(
                'local_dist',
                dict(
                    local_dist_ap=l_dist_ap_meter.avg,
                    local_dist_an=l_dist_an_meter.avg,
                ), ep)

        # save ckpt
        if cfg.log_to_file:
            save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

    # Flush and release the TensorBoard event file once training is over.
    if writer is not None:
        writer.close()

    ########
    # Test #
    ########

    test(load_model_weight=False)
コード例 #3
0
ファイル: train_cen.py プロジェクト: fountaindream/VCFL
def main():
    """Train the re-ID model with triplet, identity and centre losses.

    Everything is driven by a ``Config()`` instance:

    * builds the train set and one or more test sets;
    * builds ``Model`` (wrapped in ``DataParallel``), its Adam optimizer and
      a separate SGD optimizer for the ``CenterLoss`` parameters;
    * optionally resumes from ``cfg.ckpt_file``;
    * if ``cfg.only_test`` is set, evaluates and returns;
    * otherwise runs ``cfg.total_epochs`` epochs, combining the global/local
      triplet, identity and centre losses according to their
      ``cfg.*_loss_weight`` values, logging per step and per epoch, and (when
      ``cfg.log_to_file`` is set) writing TensorBoard scalars and a
      checkpoint after every epoch;
    * finally evaluates on the test sets.
    """
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    # Lazily create SummaryWriter
    writer = None

    TVT, TMO = set_devices(cfg.sys_device_ids)

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    ###########
    # Dataset #
    ###########

    train_set = create_dataset(**cfg.train_set_kwargs)

    test_sets = []
    test_set_names = []
    if cfg.dataset == 'combined':
        for name in ['market1501', 'cuhk03', 'duke']:
            cfg.test_set_kwargs['name'] = name
            test_sets.append(create_dataset(**cfg.test_set_kwargs))
            test_set_names.append(name)
    else:
        test_sets.append(create_dataset(**cfg.test_set_kwargs))
        test_set_names.append(cfg.dataset)

    ###########
    # Models  #
    ###########

    model = Model(local_conv_out_channels=cfg.local_conv_out_channels,
                  num_classes=len(train_set.ids2labels))
    # Model wrapper
    model_w = DataParallel(model)

    #############################
    # Criteria and Optimizers   #
    #############################

    id_criterion = nn.CrossEntropyLoss()
    g_tri_loss = TripletLoss(margin=cfg.global_margin)
    l_tri_loss = TripletLoss(margin=cfg.local_margin)
    center_loss = CenterLoss(num_classes=len(train_set.ids2labels),
                             feat_dim=2048,
                             use_gpu=True)

    optimizer = optim.Adam(model.parameters(),
                           lr=cfg.base_lr,
                           weight_decay=cfg.weight_decay)
    optimizer_centloss = torch.optim.SGD(center_loss.parameters(), lr=0.001)

    # Bind them together just to save some codes in the following usage.
    modules_optims = [model, optimizer]

    ################################
    # May Resume Models and Optims #
    ################################

    if cfg.resume:
        resume_ep, _ = load_ckpt(modules_optims, cfg.ckpt_file)

    # May Transfer Models and Optims to Specified Device. Transferring optimizer
    # is to cope with the case when you load the checkpoint to a new device.
    TMO(modules_optims)

    ########
    # Test #
    ########

    def test(load_model_weight=False):
        """Evaluate on all test sets, optionally loading weights first.

        With ``load_model_weight`` set, weights come from
        ``cfg.model_weight_file`` when that path is non-empty, otherwise from
        the checkpoint ``cfg.ckpt_file``.
        """
        if load_model_weight:
            if cfg.model_weight_file != '':
                map_location = (lambda storage, loc: storage)
                sd = torch.load(cfg.model_weight_file,
                                map_location=map_location)
                load_state_dict(model, sd)
                print('Loaded model weights from {}'.format(
                    cfg.model_weight_file))
            else:
                load_ckpt(modules_optims, cfg.ckpt_file)

        use_local_distance = (cfg.l_loss_weight > 0) \
                             and cfg.local_dist_own_hard_sample

        for test_set, name in zip(test_sets, test_set_names):
            test_set.set_feat_func(ExtractFeature(model_w, TVT))
            print('\n=========> Test on dataset: {} <=========\n'.format(name))
            test_set.eval(normalize_feat=cfg.normalize_feature,
                          use_local_distance=use_local_distance)

    if cfg.only_test:
        test(load_model_weight=True)
        return

    ############
    # Training #
    ############

    start_ep = resume_ep if cfg.resume else 0
    for ep in range(start_ep, cfg.total_epochs):

        # Adjust Learning Rate (same schedule for both optimizers).
        # NOTE(review): this schedules optimizer_centloss from cfg.base_lr,
        # not from the 0.001 it was created with -- confirm that is intended.
        if cfg.lr_decay_type == 'exp':
            adjust_lr_exp(optimizer, cfg.base_lr, ep + 1, cfg.total_epochs,
                          cfg.exp_decay_at_epoch)
            adjust_lr_exp(optimizer_centloss, cfg.base_lr, ep + 1,
                          cfg.total_epochs, cfg.exp_decay_at_epoch)
        else:
            adjust_lr_staircase(optimizer, cfg.base_lr, ep + 1,
                                cfg.staircase_decay_at_epochs,
                                cfg.staircase_decay_multiply_factor)
            adjust_lr_staircase(optimizer_centloss, cfg.base_lr, ep + 1,
                                cfg.staircase_decay_at_epochs,
                                cfg.staircase_decay_multiply_factor)

        may_set_mode(modules_optims, 'train')

        g_prec_meter = AverageMeter()
        g_m_meter = AverageMeter()
        g_dist_ap_meter = AverageMeter()
        g_dist_an_meter = AverageMeter()
        g_loss_meter = AverageMeter()

        l_prec_meter = AverageMeter()
        l_m_meter = AverageMeter()
        l_dist_ap_meter = AverageMeter()
        l_dist_an_meter = AverageMeter()
        l_loss_meter = AverageMeter()

        id_loss_meter = AverageMeter()

        c_loss_meter = AverageMeter()

        loss_meter = AverageMeter()

        ep_st = time.time()
        step = 0
        epoch_done = False
        while not epoch_done:

            step += 1
            step_st = time.time()

            ims, im_names, labels, mirrored, epoch_done = \
                train_set.next_batch()

            ims_var = Variable(TVT(torch.from_numpy(ims).float()))
            labels_t = TVT(torch.from_numpy(labels).long())
            labels_var = Variable(labels_t)

            # The first output is not used by any loss below.
            feat, global_feat, local_feat, logits = model_w(ims_var)

            g_loss, p_inds, n_inds, g_dist_ap, g_dist_an, g_dist_mat = global_loss(
                g_tri_loss,
                global_feat,
                labels_t,
                normalize_feature=cfg.normalize_feature)

            if cfg.l_loss_weight == 0:
                l_loss = 0
            elif cfg.local_dist_own_hard_sample:
                # Let local distance find its own hard samples.
                l_loss, l_dist_ap, l_dist_an, _ = local_loss(
                    l_tri_loss,
                    local_feat,
                    None,
                    None,
                    labels_t,
                    normalize_feature=cfg.normalize_feature)
            else:
                # Reuse the hard samples mined by the global distance.
                l_loss, l_dist_ap, l_dist_an = local_loss(
                    l_tri_loss,
                    local_feat,
                    p_inds,
                    n_inds,
                    labels_t,
                    normalize_feature=cfg.normalize_feature)

            id_loss = 0
            if cfg.id_loss_weight > 0:
                id_loss = id_criterion(logits, labels_var)

            c_loss = 0
            if cfg.c_loss_weight > 0:
                c_loss = center_loss(normalize(global_feat, axis=-1),
                                     labels_var)

            loss = g_loss * cfg.g_loss_weight \
                   + l_loss * cfg.l_loss_weight \
                   + id_loss * cfg.id_loss_weight \
                   + c_loss * cfg.c_loss_weight

            optimizer.zero_grad()
            optimizer_centloss.zero_grad()
            loss.backward()
            optimizer.step()
            if cfg.c_loss_weight > 0:
                # Undo the c_loss_weight factor on the centre gradients so the
                # centre update does not scale with the loss weight. Skipped
                # when the centre loss is disabled (no gradients exist and the
                # division would be by zero).
                for param in center_loss.parameters():
                    param.grad.data *= (1 / cfg.c_loss_weight)
                optimizer_centloss.step()

            ############
            # Step Log #
            ############

            # precision
            g_prec = (g_dist_an > g_dist_ap).data.float().mean()
            # the proportion of triplets that satisfy margin
            g_m = (g_dist_an >
                   g_dist_ap + cfg.global_margin).data.float().mean()
            g_d_ap = g_dist_ap.data.mean()
            g_d_an = g_dist_an.data.mean()

            g_prec_meter.update(g_prec)
            g_m_meter.update(g_m)
            g_dist_ap_meter.update(g_d_ap)
            g_dist_an_meter.update(g_d_an)
            g_loss_meter.update(to_scalar(g_loss))

            if cfg.l_loss_weight > 0:
                # precision
                l_prec = (l_dist_an > l_dist_ap).data.float().mean()
                # the proportion of triplets that satisfy margin
                l_m = (l_dist_an >
                       l_dist_ap + cfg.local_margin).data.float().mean()
                l_d_ap = l_dist_ap.data.mean()
                l_d_an = l_dist_an.data.mean()

                l_prec_meter.update(l_prec)
                l_m_meter.update(l_m)
                l_dist_ap_meter.update(l_d_ap)
                l_dist_an_meter.update(l_d_an)
                l_loss_meter.update(to_scalar(l_loss))

            if cfg.id_loss_weight > 0:
                id_loss_meter.update(to_scalar(id_loss))

            if cfg.c_loss_weight > 0:
                c_loss_meter.update(to_scalar(c_loss))

            loss_meter.update(to_scalar(loss))

            if step % cfg.log_steps == 0:
                time_log = '\tStep {}/Ep {}, {:.2f}s'.format(
                    step,
                    ep + 1,
                    time.time() - step_st,
                )

                if cfg.g_loss_weight > 0:
                    g_log = (', gp {:.2%}, gm {:.2%}, '
                             'gd_ap {:.4f}, gd_an {:.4f}, '
                             'gL {:.4f}'.format(
                                 g_prec_meter.val,
                                 g_m_meter.val,
                                 g_dist_ap_meter.val,
                                 g_dist_an_meter.val,
                                 g_loss_meter.val,
                             ))
                else:
                    g_log = ''

                if cfg.l_loss_weight > 0:
                    l_log = (', lp {:.2%}, lm {:.2%}, '
                             'ld_ap {:.4f}, ld_an {:.4f}, '
                             'lL {:.4f}'.format(
                                 l_prec_meter.val,
                                 l_m_meter.val,
                                 l_dist_ap_meter.val,
                                 l_dist_an_meter.val,
                                 l_loss_meter.val,
                             ))
                else:
                    l_log = ''

                if cfg.id_loss_weight > 0:
                    id_log = (', idL {:.4f}'.format(id_loss_meter.val))
                else:
                    id_log = ''

                if cfg.c_loss_weight > 0:
                    c_log = (', cL {:.4f}'.format(c_loss_meter.val))
                else:
                    c_log = ''

                total_loss_log = ', loss {:.4f}'.format(loss_meter.val)

                # Include the centre-loss term, as the epoch log does.
                log = time_log + \
                      g_log + l_log + id_log + \
                      c_log + total_loss_log
                print(log)

        #############
        # Epoch Log #
        #############

        time_log = 'Ep {}, {:.2f}s'.format(
            ep + 1,
            time.time() - ep_st,
        )

        if cfg.g_loss_weight > 0:
            g_log = (', gp {:.2%}, gm {:.2%}, '
                     'gd_ap {:.4f}, gd_an {:.4f}, '
                     'gL {:.4f}'.format(
                         g_prec_meter.avg,
                         g_m_meter.avg,
                         g_dist_ap_meter.avg,
                         g_dist_an_meter.avg,
                         g_loss_meter.avg,
                     ))
        else:
            g_log = ''

        if cfg.l_loss_weight > 0:
            l_log = (', lp {:.2%}, lm {:.2%}, '
                     'ld_ap {:.4f}, ld_an {:.4f}, '
                     'lL {:.4f}'.format(
                         l_prec_meter.avg,
                         l_m_meter.avg,
                         l_dist_ap_meter.avg,
                         l_dist_an_meter.avg,
                         l_loss_meter.avg,
                     ))
        else:
            l_log = ''

        if cfg.id_loss_weight > 0:
            id_log = (', idL {:.4f}'.format(id_loss_meter.avg))
        else:
            id_log = ''

        if cfg.c_loss_weight > 0:
            c_log = (', cL {:.4f}'.format(c_loss_meter.avg))
        else:
            c_log = ''

        total_loss_log = ', loss {:.4f}'.format(loss_meter.avg)

        log = time_log + \
              g_log + l_log + id_log + \
              c_log + total_loss_log
        print(log)

        # Log to TensorBoard

        if cfg.log_to_file:
            if writer is None:
                writer = SummaryWriter(
                    log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
            writer.add_scalars(
                'loss',
                dict(
                    global_loss=g_loss_meter.avg,
                    local_loss=l_loss_meter.avg,
                    id_loss=id_loss_meter.avg,
                    c_loss=c_loss_meter.avg,
                    loss=loss_meter.avg,
                ), ep)
            writer.add_scalars(
                'tri_precision',
                dict(
                    global_precision=g_prec_meter.avg,
                    local_precision=l_prec_meter.avg,
                ), ep)
            writer.add_scalars(
                'satisfy_margin',
                dict(
                    global_satisfy_margin=g_m_meter.avg,
                    local_satisfy_margin=l_m_meter.avg,
                ), ep)
            writer.add_scalars(
                'global_dist',
                dict(
                    global_dist_ap=g_dist_ap_meter.avg,
                    global_dist_an=g_dist_an_meter.avg,
                ), ep)
            writer.add_scalars(
                'local_dist',
                dict(
                    local_dist_ap=l_dist_ap_meter.avg,
                    local_dist_an=l_dist_an_meter.avg,
                ), ep)

        # save ckpt
        if cfg.log_to_file:
            save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

    # Flush and release the TensorBoard event file once training is over.
    if writer is not None:
        writer.close()

    ########
    # Test #
    ########

    test(load_model_weight=False)