Example #1
    def __init__(self, args, loop, num_classes_dict):
        gpu_ids = args['gpus']
        only_test = args['only_test']
        feat_dims_per_part = args['stripe_feats_dims']
        # Parameters
        self.num_classes_dict = num_classes_dict
        self.dataset_nm = args['dataset']
        self.update_trainingset_path = args['update_trainingset_path']
        root_save_dir = args['root_save_dir']
        save_dir = osp.join(root_save_dir, self.dataset_nm)
        self.save_dir = save_dir
        self.loop = loop
        self.num_clusters = int(args['num_clusters'][self.dataset_nm])
        # For uReID
        self.similarity_thresh = float(args['lambda'])
        cfg = Config(gpu_ids, self.dataset_nm, feat_dims_per_part,
                     only_test, loop, save_dir)
        
        # Redirect logs to both console and file.
        if cfg.log_to_file:
            ReDirectSTD(cfg.stdout_file, 'stdout', False)
            ReDirectSTD(cfg.stderr_file, 'stderr', False)
        
        # Dump the configurations to log.
        self.cfg = cfg
        import pprint
        print('-' * 60)
        print('cfg.__dict__')
        pprint.pprint(cfg.__dict__)
        print('-' * 60)

        # Load data
        self._loadDataset()
        print('INFO: 1. Datasets are loaded ...')

        # Update the number of classes
        self.num_classes_dict[self.dataset_nm] = len(self.train_set.ids2labels)

        # Initialize Model
        self._initModel()
        print('INFO: 2. Model is initialized ...')
 
        print('INFO: ==== Project initialized successfully!')
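
For context, a minimal usage sketch of the constructor above. The enclosing class name (`Trainer`) and the concrete config values are assumptions for illustration; the expected keys of `args` are taken from the code itself.

# Hypothetical usage sketch; `Trainer` stands in for the enclosing class.
args = {
    'gpus': [0],                          # gpu_ids
    'only_test': False,
    'stripe_feats_dims': 256,             # feat_dims_per_part (illustrative)
    'dataset': 'market1501',
    'update_trainingset_path': './data/update_train',  # illustrative path
    'root_save_dir': './exp',
    'num_clusters': {'market1501': 500},  # illustrative value
    'lambda': 0.85,                       # similarity threshold for uReID
}
trainer = Trainer(args, loop=0, num_classes_dict={'market1501': 0})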
Example #2
def main():
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    # Lazily create SummaryWriter
    writer = None

    TVT, TMO = set_devices(cfg.sys_device_ids)

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    ###########
    # Dataset #
    ###########

    train_set = create_dataset(**cfg.train_set_kwargs)
    num_classes = len(train_set.ids2labels)
    # The combined dataset does not currently provide a val set.
    val_set = None if cfg.dataset == 'combined' else create_dataset(
        **cfg.val_set_kwargs)

    test_sets = []
    test_set_names = []
    if cfg.dataset == 'combined':
        for name in ['market1501', 'cuhk03', 'duke']:
            cfg.test_set_kwargs['name'] = name
            test_sets.append(create_dataset(**cfg.test_set_kwargs))
            test_set_names.append(name)
    else:
        test_sets.append(create_dataset(**cfg.test_set_kwargs))
        test_set_names.append(cfg.dataset)

    ###########
    # Models  #
    ###########

    model = Model(last_conv_stride=cfg.last_conv_stride,
                  num_stripes=cfg.num_stripes,
                  local_conv_out_channels=cfg.local_conv_out_channels,
                  num_classes=num_classes)
    # Model wrapper
    model_w = DataParallel(model)

    #############################
    # Criteria and Optimizers   #
    #############################

    criterion = torch.nn.CrossEntropyLoss()

    # To finetune from ImageNet weights
    finetuned_params = list(model.base.parameters())
    # To train from scratch
    new_params = [
        p for n, p in model.named_parameters() if not n.startswith('base.')
    ]
    param_groups = [{
        'params': finetuned_params,
        'lr': cfg.finetuned_params_lr
    }, {
        'params': new_params,
        'lr': cfg.new_params_lr
    }]
    optimizer = optim.SGD(param_groups,
                          momentum=cfg.momentum,
                          weight_decay=cfg.weight_decay)

    # Bind model and optimizer together to simplify the calls below.
    modules_optims = [model, optimizer]

    ################################
    # May Resume Models and Optims #
    ################################

    if cfg.resume:
        resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

    # May transfer models and optims to the specified device. The optimizer
    # is transferred too, to handle loading a checkpoint onto a new device.
    TMO(modules_optims)

    ########
    # Test #
    ########

    def test(load_model_weight=False):
        if load_model_weight:
            if cfg.model_weight_file != '':
                map_location = (lambda storage, loc: storage)
                sd = torch.load(cfg.model_weight_file,
                                map_location=map_location)
                load_state_dict(model, sd)
                print('Loaded model weights from {}'.format(
                    cfg.model_weight_file))
            else:
                load_ckpt(modules_optims, cfg.ckpt_file)

        for test_set, name in zip(test_sets, test_set_names):
            test_set.set_feat_func(ExtractFeature(model_w, TVT))
            print('\n=========> Test on dataset: {} <=========\n'.format(name))
            test_set.eval(normalize_feat=True, verbose=True)

    def validate():
        if val_set.extract_feat_func is None:
            val_set.set_feat_func(ExtractFeature(model_w, TVT))
        print('\n===== Test on validation set =====\n')
        mAP, cmc_scores, _, _ = val_set.eval(normalize_feat=True,
                                             to_re_rank=False,
                                             verbose=True)
        print()
        return mAP, cmc_scores[0]

    if cfg.only_test:
        test(load_model_weight=True)
        return

    ############
    # Training #
    ############

    start_ep = resume_ep if cfg.resume else 0
    for ep in range(start_ep, cfg.total_epochs):

        # Adjust Learning Rate
        adjust_lr_staircase(optimizer.param_groups,
                            [cfg.finetuned_params_lr, cfg.new_params_lr],
                            ep + 1, cfg.staircase_decay_at_epochs,
                            cfg.staircase_decay_multiply_factor)

        may_set_mode(modules_optims, 'train')

        # For recording loss
        loss_meter = AverageMeter()

        ep_st = time.time()
        step = 0
        epoch_done = False
        while not epoch_done:

            step += 1
            step_st = time.time()

            ims, im_names, labels, mirrored, epoch_done = \
                train_set.next_batch()

            ims_var = Variable(TVT(torch.from_numpy(ims).float()))
            labels_var = Variable(TVT(torch.from_numpy(labels).long()))

            _, logits_list = model_w(ims_var)
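            # logits_list holds one logits tensor per stripe; the ID loss
            # below sums the per-stripe cross-entropy terms.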
            loss = torch.sum(
                torch.cat(
                    [criterion(logits, labels_var) for logits in logits_list]))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ############
            # Step Log #
            ############

            loss_meter.update(to_scalar(loss))

            if step % cfg.steps_per_log == 0:
                log = '\tStep {}/Ep {}, {:.2f}s, loss {:.4f}'.format(
                    step, ep + 1,
                    time.time() - step_st, loss_meter.val)
                print(log)

        #############
        # Epoch Log #
        #############

        log = 'Ep {}, {:.2f}s, loss {:.4f}'.format(ep + 1,
                                                   time.time() - ep_st,
                                                   loss_meter.avg)
        print(log)

        ##########################
        # Test on Validation Set #
        ##########################

        mAP, Rank1 = 0, 0
        if ((ep + 1) % cfg.epochs_per_val == 0) and (val_set is not None):
            mAP, Rank1 = validate()

        # Log to TensorBoard

        if cfg.log_to_file:
            if writer is None:
                writer = SummaryWriter(
                    log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
            writer.add_scalars('val scores', dict(mAP=mAP, Rank1=Rank1), ep)
            writer.add_scalars('loss', dict(loss=loss_meter.avg, ), ep)

        # save ckpt
        if cfg.log_to_file:
            save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

    ########
    # Test #
    ########

    test(load_model_weight=False)
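
For reference, a minimal sketch of what a staircase LR helper like `adjust_lr_staircase` might look like, matching the call signature used above. This is an assumption for illustration, not the project's actual implementation:

def adjust_lr_staircase(param_groups, base_lrs, ep, decay_at_epochs, factor):
    # Multiply each base LR by factor**k, where k is the number of
    # milestones in decay_at_epochs that epoch `ep` has already passed.
    k = sum(1 for e in decay_at_epochs if ep >= e)
    for group, base_lr in zip(param_groups, base_lrs):
        group['lr'] = base_lr * (factor ** k)

# Example: with milestone (41,) and factor 0.1, epoch 45 uses 0.1x the base LRs.
groups = [{'lr': 0.0}, {'lr': 0.0}]
adjust_lr_staircase(groups, [0.01, 0.1], 45, (41,), 0.1)
print([g['lr'] for g in groups])  # ~[0.001, 0.01]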
Example #3
def main():
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    TVT, TMO = set_devices(cfg.sys_device_ids)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    ###########
    # Dataset #
    ###########

    test_set = create_dataset(**cfg.test_set_kwargs)

    #########
    # Model #
    #########

    model = Model(last_conv_stride=cfg.last_conv_stride,
                  num_stripes=cfg.num_stripes,
                  local_conv_out_channels=cfg.local_conv_out_channels,
                  num_classes=0)
    # Model wrapper
    model_w = DataParallel(model)

    # May Transfer Model to Specified Device.
    TMO([model])

    #####################
    # Load Model Weight #
    #####################

    # Load weights onto the CPU first
    map_location = (lambda storage, loc: storage)
    used_file = cfg.model_weight_file or cfg.ckpt_file
    loaded = torch.load(used_file, map_location=map_location)
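    # A checkpoint bundles state dicts for [model, optimizer]; index 0 is the
    # model's. A bare weight file is already a single state dict.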
    if cfg.model_weight_file == '':
        loaded = loaded['state_dicts'][0]
    load_state_dict(model, loaded)
    print('Loaded model weights from {}'.format(used_file))

    ###################
    # Extract Feature #
    ###################

    test_set.set_feat_func(ExtractFeature(model_w, TVT))

    with measure_time('Extracting feature...', verbose=True):
        feat, ids, cams, im_names, marks = test_set.extract_feat(True,
                                                                 verbose=True)

    #######################
    # Select Query Images #
    #######################

    # Fix some query images, so that the visualization for different models can
    # be compared.

    # Sort in the order of image names
    inds = np.argsort(im_names)
    feat, ids, cams, im_names, marks = \
      feat[inds], ids[inds], cams[inds], im_names[inds], marks[inds]

    # query, gallery index mask
    is_q = marks == 0
    is_g = marks == 1

    prng = np.random.RandomState(1)
    # selected query indices
    sel_q_inds = prng.permutation(range(np.sum(is_q)))[:cfg.num_queries]

    q_ids = ids[is_q][sel_q_inds]
    q_cams = cams[is_q][sel_q_inds]
    q_feat = feat[is_q][sel_q_inds]
    q_im_names = im_names[is_q][sel_q_inds]

    ####################
    # Compute Distance #
    ####################

    # query-gallery distance
    q_g_dist = compute_dist(q_feat, feat[is_g], type='euclidean')

    ###########################
    # Save Rank List as Image #
    ###########################

    q_im_paths = [ospj(test_set.im_dir, n) for n in q_im_names]
    save_paths = [ospj(cfg.exp_dir, 'rank_lists', n) for n in q_im_names]
    g_im_paths = [ospj(test_set.im_dir, n) for n in im_names[is_g]]

    for dist_vec, q_id, q_cam, q_im_path, save_path in zip(
            q_g_dist, q_ids, q_cams, q_im_paths, save_paths):

        rank_list, same_id = get_rank_list(dist_vec, q_id, q_cam, ids[is_g],
                                           cams[is_g], cfg.rank_list_size)

        save_rank_list_to_im(rank_list, same_id, q_im_path, g_im_paths,
                             save_path)
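
The rank-list script above relies on `compute_dist(..., type='euclidean')`. A self-contained numpy sketch of that pairwise distance, under the assumption that it computes plain euclidean distances between all query/gallery pairs:

import numpy as np

def compute_dist_sketch(x, y):
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, evaluated for all pairs
    sq_x = np.sum(x ** 2, axis=1, keepdims=True)    # [m, 1]
    sq_y = np.sum(y ** 2, axis=1, keepdims=True).T  # [1, n]
    d2 = sq_x + sq_y - 2.0 * x @ y.T
    return np.sqrt(np.maximum(d2, 0.0))             # [m, n]

q = np.random.rand(4, 8)   # 4 query features
g = np.random.rand(10, 8)  # 10 gallery features
print(compute_dist_sketch(q, g).shape)  # (4, 10)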
Example #4
def main():
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    # Lazily create SummaryWriter
    writer = None

    TVT, TMO = set_devices(cfg.sys_device_ids)

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    #####################################################
    # Dataset #
    # create the train set for the PCB (ID) loss        #
    #####################################################
    print('Start creating the PCB dataset...')
    train_set = create_dataset(**cfg.pcb_train_set_kwargs)
    print('train_set size: {}'.format(len(train_set.im_names)))
    num_classes = len(train_set.ids2labels)

    #####################################################
    # Dataset #
    # create val_set / test_sets for the PCB loss       #
    #####################################################
    # The combined dataset does not currently provide a val set.
    val_set = None if cfg.pcb_dataset == 'combined' else create_dataset(
        **cfg.val_set_kwargs)

    test_sets = []
    test_set_names = []
    if cfg.pcb_dataset == 'combined':
        for name in ['market1501', 'cuhk03', 'duke']:
            cfg.test_set_kwargs['name'] = name
            test_sets.append(create_dataset(**cfg.test_set_kwargs))
            test_set_names.append(name)
    else:
        test_sets.append(create_dataset(**cfg.test_set_kwargs))
        test_set_names.append(cfg.pcb_dataset)

    ########
    # Test #
    ########

    def test(load_model_weight=False):
        if load_model_weight:
            if cfg.model_weight_file != '':
                map_location = (lambda storage, loc: storage)
                sd = torch.load(cfg.model_weight_file,
                                map_location=map_location)
                load_state_dict(model, sd)
                print('Loaded model weights from {}'.format(
                    cfg.model_weight_file))
            else:
                load_ckpt(modules_optims, cfg.ckpt_file)

        for test_set, name in zip(test_sets, test_set_names):
            test_set.set_feat_func(ExtractFeature(model_w, TVT))
            print('\n=========> Test on dataset: {} <=========\n'.format(name))
            test_set.eval(normalize_feat=True, verbose=True)

    def validate():
        if val_set.extract_feat_func is None:
            val_set.set_feat_func(ExtractFeature(model_w, TVT))
        print('\n===== Test on validation set =====\n')
        mAP, cmc_scores, _, _ = val_set.eval(normalize_feat=True,
                                             to_re_rank=False,
                                             verbose=True)
        print()
        return mAP, cmc_scores[0]

    #############################################################################
    # Train the PCB model (cfg.pcb_epochs epochs)                               #
    #############################################################################

    if not cfg.only_triplet:
        model = Model(last_conv_stride=cfg.last_conv_stride,
                      num_stripes=cfg.num_stripes,
                      num_cols=cfg.num_cols,
                      local_conv_out_channels=cfg.local_conv_out_channels,
                      num_classes=num_classes)

        # Model wrapper
        model_w = DataParallel(model)

        #############################
        # Criteria and Optimizers   #
        #############################

        criterion = torch.nn.CrossEntropyLoss()

        # To finetune from ImageNet weights
        finetuned_params = list(model.base.parameters())
        # To train from scratch
        new_params = [
            p for n, p in model.named_parameters() if not n.startswith('base.')
        ]
        param_groups = [{
            'params': finetuned_params,
            'lr': cfg.finetuned_params_lr
        }, {
            'params': new_params,
            'lr': cfg.new_params_lr
        }]
        optimizer = optim.SGD(param_groups,
                              momentum=cfg.momentum,
                              weight_decay=cfg.weight_decay)

        # Bind model and optimizer together to simplify the calls below.

        modules_optims = [model, optimizer]

        ################################
        # May Resume Models and Optims #
        ################################

        if cfg.resume:
            resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

        # May transfer models and optims to the specified device. The optimizer
        # is transferred too, to handle loading a checkpoint onto a new device.
        TMO(modules_optims)

        if cfg.only_test:
            test(load_model_weight=True)
            return

        print('===== Begin training the PCB model =====')
        ############
        # Training #
        ############

        start_ep = resume_ep if cfg.resume else 0
        for ep in range(start_ep, cfg.pcb_epochs):

            # Adjust Learning Rate
            adjust_lr_staircase(optimizer.param_groups,
                                [cfg.finetuned_params_lr, cfg.new_params_lr],
                                ep + 1, cfg.staircase_decay_at_epochs,
                                cfg.staircase_decay_multiply_factor)
            # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
            may_set_mode(modules_optims, 'train')

            # For recording loss
            loss_meter = AverageMeter()

            ep_st = time.time()
            step = 0
            epoch_done = False
            while not epoch_done:

                step += 1
                step_st = time.time()

                ims, im_names, labels, mirrored, epoch_done = \
                    train_set.next_batch()

                ims_var = Variable(TVT(torch.from_numpy(ims).float()))
                labels_var = Variable(TVT(torch.from_numpy(labels).long()))

                _, logits_list = model_w(ims_var)
                loss = torch.sum(
                    torch.cat([
                        criterion(logits, labels_var) for logits in logits_list
                    ]))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                del logits_list, ims, im_names, labels

                ############
                # Step Log #
                ############

                loss_meter.update(to_scalar(loss))

                if step % cfg.steps_per_log == 0:
                    log = '\tStep {}/Ep {}, {:.2f}s, loss {:.4f}'.format(
                        step, ep + 1,
                        time.time() - step_st, loss_meter.val)
                    print(log)

            # Optionally adjust the LR with the (commented-out) scheduler:
            # scheduler.step(loss_meter.avg)
            #############
            # Epoch Log #
            #############

            log = 'Ep {}, {:.2f}s, loss {:.4f}'.format(ep + 1,
                                                       time.time() - ep_st,
                                                       loss_meter.avg)
            print(log)

            ##########################
            # Test on Validation Set #
            ##########################

            mAP, Rank1 = 0, 0
            if ((ep + 1) % cfg.epochs_per_val == 0) and (val_set is not None):
                mAP, Rank1 = validate()

            # Log to TensorBoard

            if cfg.log_to_file:
                if writer is None:
                    writer = SummaryWriter(
                        log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
                writer.add_scalars('val scores', dict(mAP=mAP, Rank1=Rank1),
                                   ep)
                writer.add_scalars('loss', dict(loss=loss_meter.avg, ), ep)

            # save ckpt
            if cfg.log_to_file:
                save_ckpt(modules_optims, ep + 1, 0, cfg.pcb_ckpt_file)

        ################
        # Test the pcb #
        ################

        test(load_model_weight=False)

    #####################################################
    # Dataset #
    # create anchor/positive/negative sets for the triplet loss
    # todo #
    #####################################################
    print('Start creating the triplet datasets...')
    train_set_anchor = create_dataset_tri(**cfg.train_set_anchor_kwargs)
    print('train_anchor_im_names:{}'.format(
        len(train_set_anchor.get_im_names())))
    train_set_positive = create_dataset_tri(**cfg.train_set_positive_kwargs)
    print('train_positive_im_names:{}'.format(
        len(train_set_positive.get_im_names())))
    train_set_negative = create_dataset_tri(**cfg.train_set_negative_kwargs)
    print('train_negative_im_names:{}'.format(
        len(train_set_negative.get_im_names())))
    num_classes = len(train_set_anchor.ids2labels)
    print('Finished creating triplet datasets; num_classes: {}\n'.format(num_classes))

    if not cfg.only_all:

        print('===== Triplet-loss stage =====')

        ##################################################################################
        # Model for the triplet stage (cfg.triplet_epochs epochs)                        #
        ##################################################################################
        model = Model(last_conv_stride=cfg.last_conv_stride,
                      num_stripes=cfg.num_stripes,
                      num_cols=cfg.num_cols,
                      local_conv_out_channels=cfg.local_conv_out_channels,
                      num_classes=num_classes)

        #############################
        # Criteria and Optimizers   #
        #############################

        criterion = torch.nn.CrossEntropyLoss()

        # Load the PCB-stage checkpoint and keep only keys present in this model
        if osp.isfile(cfg.pcb_ckpt_file):
            map_location = (lambda storage, loc: storage)
            sd = torch.load(cfg.pcb_ckpt_file, map_location=map_location)
            model_dict = model.state_dict()
            sd_load = {
                k: v
                for k, v in (sd['state_dicts'][0]).items() if k in model_dict
            }
            model_dict.update(sd_load)
            model.load_state_dict(model_dict)

        # # Optimizer
        #   if hasattr(model.module, 'base'):
        #       base_param_ids = set(map(id, model.base.parameters()))
        #       conv_list_ids = set(map(id,model.local_conv_list.parameters()))
        #       fc_list_ids = set(map(id,model.fc_list.parameters()))
        #       new_params = [p for p in model.parameters() if
        #                     id(p) not in base_param_ids and id(p) not in fc_list_ids]
        #       param_groups = [{'params': new_params, 'lr': 1.0}]
        #   else:
        #       param_groups = model.parameters()

        # Only params outside `base.` and `fc_list.` are optimized here; the
        # pretrained base and the stripe classifiers are left untouched.
        new_params = [
            p for n, p in model.named_parameters()
            if not n.startswith('base.') and not n.startswith('fc_list.')
        ]
        param_groups = [{'params': new_params, 'lr': cfg.new_params_lr}]
        optimizer = optim.SGD(param_groups,
                              momentum=cfg.momentum,
                              weight_decay=cfg.weight_decay)

        # Model wrapper
        model_w = DataParallel(model)
        # Bind model and optimizer together to simplify the calls below.

        modules_optims = [model, optimizer]

        TMO(modules_optims)

        ############
        # Training #
        ############
        print('===== Begin training the triplet model =====')
        start_ep = resume_ep if cfg.resume else 0
        triplet_st = time.time()
        for ep in range(start_ep, cfg.triplet_epochs):

            # Adjust Learning Rate
            adjust_lr_staircase(optimizer.param_groups,
                                [cfg.triplet_finetuned_params_lr], ep + 1,
                                cfg.staircase_decay_at_epochs,
                                cfg.triplet_staircase_decay_multiply_factor)
            # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
            may_set_mode(modules_optims, 'train')

            # For recording loss
            loss_meter = AverageMeter()

            ep_st = time.time()
            step = 0
            epoch_done = False
            while not epoch_done:

                step += 1
                step_st = time.time()

                ims_a, im_names_a, labels_a, mirrored_a, epoch_done = \
                    train_set_anchor.next_batch()
                ims_p, im_names_p, labels_p, mirrored_p, epoch_done = \
                    train_set_positive.next_batch()
                ims_n, im_names_n, labels_n, mirrored_n, epoch_done = \
                    train_set_negative.next_batch()

                ims_var_a = Variable(TVT(torch.from_numpy(ims_a).float()))
                ims_var_p = Variable(TVT(torch.from_numpy(ims_p).float()))
                ims_var_n = Variable(TVT(torch.from_numpy(ims_n).float()))
                labels_a_var = Variable(TVT(torch.from_numpy(labels_a).long()))
                labels_n_var = Variable(TVT(torch.from_numpy(labels_n).long()))

                local_feat_list_a, logits_list_a = model_w(ims_var_a)
                local_feat_list_p, logits_list_p = model_w(ims_var_p)
                local_feat_list_n, logits_list_n = model_w(ims_var_n)

                # One triplet margin loss per part (stripe) feature
                loss_triplet = []
                for i in range(cfg.parts_num):
                    loss_triplet.append(
                        TripletMarginLoss(cfg.margin)(local_feat_list_a[i],
                                                      local_feat_list_p[i],
                                                      local_feat_list_n[i]))

                # loss_triplet has exactly cfg.parts_num entries, so summing
                # generalizes the original if/elif chain over 3/6/12/18 parts.
                loss = torch.div(sum(loss_triplet), cfg.parts_num)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # del local_feat_list_a,local_feat_list_p,local_feat_list_n,logits_list_a,logits_list_p,logits_list_n
                ############
                # Step Log #
                ############

                loss_meter.update(to_scalar(loss))

                if step % cfg.steps_per_log == 0:
                    log = '\tStep {}/Ep {}, {:.2f}s, loss {:.4f}'.format(
                        step, ep + 1,
                        time.time() - step_st, loss_meter.val)
                    print(log)

            # Optionally adjust the LR with the (commented-out) scheduler:
            # scheduler.step(loss_meter.avg)
            #############
            # Epoch Log #
            #############

            log = 'Ep {}, {:.2f}s, loss {:.4f}'.format(ep + 1,
                                                       time.time() - ep_st,
                                                       loss_meter.avg)
            print(log)

            ##############
            # RPP module #
            ##############

            ##########################
            # Test on Validation Set #
            ##########################

            mAP, Rank1 = 0, 0
            if ((ep + 1) % cfg.epochs_per_val == 0) and (val_set is not None):
                mAP, Rank1 = validate()

            # Log to TensorBoard

            if cfg.log_to_file:
                if writer is None:
                    writer = SummaryWriter(
                        log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
                writer.add_scalars('val scores', dict(mAP=mAP, Rank1=Rank1),
                                   cfg.pcb_epochs + ep)
                writer.add_scalars('loss', dict(loss=loss_meter.avg, ),
                                   ep + cfg.pcb_epochs)

            # save ckpt
            if cfg.log_to_file:
                save_ckpt(modules_optims, ep + cfg.pcb_epochs + 1, 0,
                          cfg.triplet_ckpt_file)

        triplet_time = time.time() - triplet_st

        ########
        # Test #
        ########

        test(load_model_weight=False)

    print('===== Full (all) training stage =====')
    ##########################################################################
    # Train the full model (cfg.total_epochs epochs)                         #
    ##########################################################################
    model = Model(last_conv_stride=cfg.last_conv_stride,
                  num_stripes=cfg.num_stripes,
                  num_cols=cfg.num_cols,
                  local_conv_out_channels=cfg.local_conv_out_channels,
                  num_classes=num_classes)

    #####################################################
    # Dataset #
    # create the train set for the full (all) stage     #
    #####################################################
    print('Start creating the combined (all) dataset...')
    all_train_set = create_dataset(**cfg.all_train_set_kwargs)
    print('all_train_set size: {}'.format(len(all_train_set.im_names)))
    num_classes = len(all_train_set.ids2labels)
    #############################
    # Criteria and Optimizers   #
    #############################

    criterion = torch.nn.CrossEntropyLoss()

    # Load the triplet-stage checkpoint and keep only keys present in this model
    if osp.isfile(cfg.triplet_ckpt_file):
        map_location = (lambda storage, loc: storage)
        sd = torch.load(cfg.triplet_ckpt_file, map_location=map_location)
        model_dict = model.state_dict()
        sd_load = {
            k: v
            for k, v in (sd['state_dicts'][0]).items() if k in model_dict
        }
        model_dict.update(sd_load)
        model.load_state_dict(model_dict)

    # To finetune from ImageNet weights
    finetuned_params = list(model.base.parameters())
    # To train from scratch (local_conv_list is excluded from this group)
    new_params = [
        p for n, p in model.named_parameters()
        if not n.startswith('base.') and not n.startswith('local_conv_list.')
    ]
    param_groups = [{
        'params': finetuned_params,
        'lr': cfg.finetuned_params_lr * 0.1
    }, {
        'params': new_params,
        'lr': cfg.finetuned_params_lr
    }]
    optimizer = optim.SGD(param_groups,
                          momentum=cfg.momentum,
                          weight_decay=cfg.weight_decay)

    # Model wrapper
    model_w = DataParallel(model)
    # Bind model and optimizer together to simplify the calls below.

    modules_optims = [model, optimizer]

    TMO(modules_optims)

    ############
    # Training #
    ############
    print('===== Begin training the full model =====')
    start_ep = resume_ep if cfg.resume else 0
    for ep in range(start_ep, cfg.total_epochs):

        # Adjust Learning Rate
        adjust_lr_staircase(
            optimizer.param_groups,
            [cfg.all_base_finetuned_params_lr, cfg.all_new_finetuned_params_lr],
            ep + 1, cfg.all_staircase_decay_at_epochs,
            cfg.staircase_decay_multiply_factor)
        # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
        may_set_mode(modules_optims, 'train')

        # For recording loss
        loss_meter = AverageMeter()

        ep_st = time.time()
        step = 0
        epoch_done = False
        while not epoch_done:

            # (An earlier variant combined the averaged per-part triplet loss
            # with an ID loss as loss = loss_id + 5 * loss_local.)

            step += 1
            step_st = time.time()

            ims, im_names, labels, mirrored, epoch_done = \
                all_train_set.next_batch()

            ims_var = Variable(TVT(torch.from_numpy(ims).float()))
            labels_var = Variable(TVT(torch.from_numpy(labels).long()))

            _, logits_list = model_w(ims_var)
            loss = torch.sum(
                torch.cat(
                    [criterion(logits, labels_var) for logits in logits_list]))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ############
            # Step Log #
            ############

            loss_meter.update(to_scalar(loss))

            if step % cfg.steps_per_log == 0:
                log = '\tStep {}/Ep {}, {:.2f}s, loss {:.4f}'.format(
                    step, ep + 1,
                    time.time() - step_st, loss_meter.val)
                print(log)

        # Optionally adjust the LR with the (commented-out) scheduler:
        # scheduler.step(loss_meter.avg)
        #############
        # Epoch Log #
        #############

        log = 'Ep {}, {:.2f}s, loss {:.4f}'.format(ep + 1,
                                                   time.time() - ep_st,
                                                   loss_meter.avg)
        print(log)

        ##############
        # RPP module #
        ##############

        ##########################
        # Test on Validation Set #
        ##########################

        mAP, Rank1 = 0, 0
        if ((ep + 1) % cfg.epochs_per_val == 0) and (val_set is not None):
            mAP, Rank1 = validate()

        # Log to TensorBoard

        if cfg.log_to_file:
            if writer is None:
                writer = SummaryWriter(
                    log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
            writer.add_scalars('val scores', dict(mAP=mAP, Rank1=Rank1),
                               cfg.pcb_epochs + cfg.triplet_epochs + ep)
            writer.add_scalars('loss', dict(loss=loss_meter.avg, ),
                               ep + cfg.pcb_epochs + cfg.triplet_epochs)

        # save ckpt
        if cfg.log_to_file:
            save_ckpt(modules_optims,
                      ep + cfg.pcb_epochs + 1 + cfg.triplet_epochs, 0,
                      cfg.ckpt_file)

    ########
    # Test #
    ########

    test(load_model_weight=False)
    print('===== Training finished =====')
    if not cfg.only_all:
        # triplet_time is only measured when the triplet stage runs.
        print('Triplet stage on {} took {:.2f} s'.format(
            cfg.triplet_dataset, triplet_time))
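
As a standalone illustration of the per-part triplet loss used in the triplet stage above (shapes and margin are assumptions for the sketch):

import torch
from torch.nn import TripletMarginLoss

P, B, C = 6, 8, 256  # parts, batch size, per-part feature dim (illustrative)
anchor = [torch.randn(B, C) for _ in range(P)]
positive = [torch.randn(B, C) for _ in range(P)]
negative = [torch.randn(B, C) for _ in range(P)]

crit = TripletMarginLoss(margin=0.3)
# One margin loss per part, averaged over the P parts, as in the loop above.
loss = sum(crit(a, p, n) for a, p, n in zip(anchor, positive, negative)) / P
print(loss.item())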