Example #1
def main():
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    # Lazily create SummaryWriter
    writer = None

    TVT, TMO = set_devices(cfg.sys_device_ids)

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    ###########
    # Dataset #
    ###########

    train_set = create_dataset(**cfg.train_set_kwargs)
    num_classes = len(train_set.ids2labels)
    # The combined dataset does not currently provide a validation set.
    val_set = None if cfg.dataset == 'combined' else create_dataset(
        **cfg.val_set_kwargs)

    test_sets = []
    test_set_names = []
    if cfg.dataset == 'combined':
        for name in ['market1501', 'cuhk03', 'duke']:
            cfg.test_set_kwargs['name'] = name
            test_sets.append(create_dataset(**cfg.test_set_kwargs))
            test_set_names.append(name)
    else:
        test_sets.append(create_dataset(**cfg.test_set_kwargs))
        test_set_names.append(cfg.dataset)

    ###########
    # Models  #
    ###########

    model = Model(last_conv_stride=cfg.last_conv_stride,
                  num_stripes=cfg.num_stripes,
                  local_conv_out_channels=cfg.local_conv_out_channels,
                  num_classes=num_classes)
    # Model wrapper
    model_w = DataParallel(model)

    #############################
    # Criteria and Optimizers   #
    #############################

    criterion = torch.nn.CrossEntropyLoss()

    # Backbone parameters, fine-tuned from ImageNet weights
    finetuned_params = list(model.base.parameters())
    # Newly added parameters, trained from scratch
    new_params = [
        p for n, p in model.named_parameters() if not n.startswith('base.')
    ]
    param_groups = [{
        'params': finetuned_params,
        'lr': cfg.finetuned_params_lr
    }, {
        'params': new_params,
        'lr': cfg.new_params_lr
    }]
    optimizer = optim.SGD(param_groups,
                          momentum=cfg.momentum,
                          weight_decay=cfg.weight_decay)

    # Bind them together to simplify the following usage.
    modules_optims = [model, optimizer]

    ################################
    # May Resume Models and Optims #
    ################################

    if cfg.resume:
        resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

    # Optionally transfer models and optimizers to the specified device.
    # Transferring the optimizer handles the case where a checkpoint is loaded
    # onto a different device.
    TMO(modules_optims)

    ########
    # Test #
    ########

    def test(load_model_weight=False):
        if load_model_weight:
            if cfg.model_weight_file != '':
                map_location = (lambda storage, loc: storage)
                sd = torch.load(cfg.model_weight_file,
                                map_location=map_location)
                load_state_dict(model, sd)
                print('Loaded model weights from {}'.format(
                    cfg.model_weight_file))
            else:
                load_ckpt(modules_optims, cfg.ckpt_file)

        for test_set, name in zip(test_sets, test_set_names):
            test_set.set_feat_func(ExtractFeature(model_w, TVT))
            print('\n=========> Test on dataset: {} <=========\n'.format(name))
            test_set.eval(normalize_feat=True, verbose=True)

    def validate():
        if val_set.extract_feat_func is None:
            val_set.set_feat_func(ExtractFeature(model_w, TVT))
        print('\n===== Test on validation set =====\n')
        mAP, cmc_scores, _, _ = val_set.eval(normalize_feat=True,
                                             to_re_rank=False,
                                             verbose=True)
        print()
        return mAP, cmc_scores[0]

    if cfg.only_test:
        test(load_model_weight=True)
        return

    ############
    # Training #
    ############

    start_ep = resume_ep if cfg.resume else 0
    for ep in range(start_ep, cfg.total_epochs):

        # Adjust Learning Rate
        adjust_lr_staircase(optimizer.param_groups,
                            [cfg.finetuned_params_lr, cfg.new_params_lr],
                            ep + 1, cfg.staircase_decay_at_epochs,
                            cfg.staircase_decay_multiply_factor)

        may_set_mode(modules_optims, 'train')

        # For recording loss
        loss_meter = AverageMeter()

        ep_st = time.time()
        step = 0
        epoch_done = False
        while not epoch_done:

            step += 1
            step_st = time.time()

            ims, im_names, labels, mirrored, epoch_done = train_set.next_batch()

            ims_var = Variable(TVT(torch.from_numpy(ims).float()))
            labels_var = Variable(TVT(torch.from_numpy(labels).long()))

            _, logits_list = model_w(ims_var)
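            # The PCB model returns one classifier output per stripe; the ID
            # loss below sums the per-stripe cross-entropy losses.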
            loss = torch.sum(
                torch.cat(
                    [criterion(logits, labels_var) for logits in logits_list]))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ############
            # Step Log #
            ############

            loss_meter.update(to_scalar(loss))

            if step % cfg.steps_per_log == 0:
                log = '\tStep {}/Ep {}, {:.2f}s, loss {:.4f}'.format(
                    step, ep + 1,
                    time.time() - step_st, loss_meter.val)
                print(log)

        #############
        # Epoch Log #
        #############

        log = 'Ep {}, {:.2f}s, loss {:.4f}'.format(ep + 1,
                                                   time.time() - ep_st,
                                                   loss_meter.avg)
        print(log)

        ##########################
        # Test on Validation Set #
        ##########################

        mAP, Rank1 = 0, 0
        if ((ep + 1) % cfg.epochs_per_val == 0) and (val_set is not None):
            mAP, Rank1 = validate()

        # Log to TensorBoard

        if cfg.log_to_file:
            if writer is None:
                writer = SummaryWriter(
                    log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
            writer.add_scalars('val scores', dict(mAP=mAP, Rank1=Rank1), ep)
            writer.add_scalars('loss', dict(loss=loss_meter.avg, ), ep)

        # save ckpt
        if cfg.log_to_file:
            save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

    ########
    # Test #
    ########

    test(load_model_weight=False)
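
The training loop above decays the two learning rates with adjust_lr_staircase, a helper from the repo's utilities that is not shown here. As a rough sketch only (an assumption about its behavior, not the repo's exact code), a staircase schedule with the same call signature could work like this: at every epoch listed in decay_at_epochs, each param group's lr is reset to its base lr multiplied by the decay factor once per decay point reached so far.

def staircase_lr(param_groups, base_lrs, ep, decay_at_epochs, factor):
    # Hypothetical stand-in for adjust_lr_staircase; epochs are 1-based.
    if ep not in decay_at_epochs:
        return
    num_decays = sorted(decay_at_epochs).index(ep) + 1
    for group, base_lr in zip(param_groups, base_lrs):
        group['lr'] = base_lr * (factor ** num_decays)

# Toy usage with made-up values for the config fields
# (finetuned_params_lr=0.01, new_params_lr=0.1, decay at epochs 41 and 51 by 0.1).
groups = [{'lr': 0.01}, {'lr': 0.1}]
staircase_lr(groups, [0.01, 0.1], 41, (41, 51), 0.1)
print(groups)  # lr becomes 0.001 and ~0.01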
Example #2
    def _trainPCB(self):
        cfg = self.cfg
        TVT = self.TVT
        model_w = self.model_w
        criterion = self.criterion
        modules_optims = self.modules_optims
        optimizer = modules_optims[1]
        #train_set = create_dataset(**cfg.train_set_kwargs)
        train_set = self.train_set
        start_ep = 0
        for ep in range(start_ep, cfg.total_epochs):
            # Adjust Learning Rate
            adjust_lr_staircase(
                optimizer.param_groups,
                [cfg.finetuned_params_lr, cfg.new_params_lr],
                ep + 1,
                cfg.staircase_decay_at_epochs,
                cfg.staircase_decay_multiply_factor)

            may_set_mode(modules_optims, 'train')

            # For recording loss
            loss_meter = AverageMeter()

            ep_st = time.time()
            step = 0
            epoch_done = False
            while not epoch_done:
                step += 1
                step_st = time.time()
                ims, im_names, labels, mirrored, epoch_done = \
                    train_set.next_batch()
                ims_var = Variable(TVT(torch.from_numpy(ims).float()))
                labels_var = Variable(TVT(torch.from_numpy(labels).long()))
                _, logits_list = model_w(ims_var)
                loss = torch.sum(torch.cat(
                    [criterion(logits, labels_var) for logits in logits_list]))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Step log
                loss_meter.update(to_scalar(loss))
                if step % cfg.steps_per_log == 0:
                    log = '\tStep {}/Ep {}, {:.2f}s, loss {:.4f}'.format(
                        step, ep + 1, time.time() - step_st, loss_meter.val)
                    print(log)

            # Epoch Log
            log = 'Ep {}, {:.2f}s, loss {:.4f}'.format(
                ep + 1, time.time() - ep_st, loss_meter.avg)
            print(log)

            # Test on Validation Set
            mAP, Rank1 = 0, 0
            val_set = self.val_set
            if ((ep + 1) % cfg.epochs_per_val == 0) and (val_set is not None):
                mAP, Rank1 = self._validatePCB()

            # Save ckpt
            if cfg.log_to_file:
                save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)
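
Both examples record the loss with AverageMeter and convert the loss tensor to a number with to_scalar; these come from the repo's utilities and are not shown here. Below is a minimal sketch of equivalent helpers, written under the assumption that they behave as a running-average meter and a tensor-to-float conversion (the names RunningAverage and scalar are made up for illustration).

class RunningAverage(object):
    """Tracks the latest value (.val) and the running mean (.avg)."""

    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def scalar(x):
    # Convert a one-element tensor (or a plain number) to a Python float.
    try:
        return float(x.item())
    except AttributeError:
        return float(x)

meter = RunningAverage()
for v in (0.9, 0.7, 0.5):
    meter.update(scalar(v))
print('last {:.2f}, mean {:.2f}'.format(meter.val, meter.avg))  # last 0.50, mean 0.70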
Example #3
def main():
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    # Lazily create SummaryWriter
    writer = None

    TVT, TMO = set_devices(cfg.sys_device_ids)

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)
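
    # Overview: this script trains in three stages, reusing checkpoints between
    # them: (1) PCB classification training on the pcb_train_set data,
    # (2) part-wise triplet fine-tuning on anchor/positive/negative sets, and
    # (3) a final fine-tuning pass on the 'all' training set.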

    #####################################################
    # Dataset: train set for the PCB (classification) stage
    #####################################################
    print('start to create pcb dataset.....')
    train_set = create_dataset(**cfg.pcb_train_set_kwargs)
    print('train_set shape:{}'.format(len(train_set.im_names)))
    num_classes = len(train_set.ids2labels)

    #####################################################
    # Dataset: val/test sets for the PCB stage
    #####################################################
    # The combined dataset does not currently provide a validation set.
    val_set = None if cfg.pcb_dataset == 'combined' else create_dataset(
        **cfg.val_set_kwargs)

    test_sets = []
    test_set_names = []
    if cfg.pcb_dataset == 'combined':
        for name in ['market1501', 'cuhk03', 'duke']:
            cfg.test_set_kwargs['name'] = name
            test_sets.append(create_dataset(**cfg.test_set_kwargs))
            test_set_names.append(name)
    else:
        test_sets.append(create_dataset(**cfg.test_set_kwargs))
        test_set_names.append(cfg.pcb_dataset)

    ########
    # Test #
    ########

    def test(load_model_weight=False):
        if load_model_weight:
            if cfg.model_weight_file != '':
                map_location = (lambda storage, loc: storage)
                sd = torch.load(cfg.model_weight_file,
                                map_location=map_location)
                load_state_dict(model, sd)
                print('Loaded model weights from {}'.format(
                    cfg.model_weight_file))
            else:
                load_ckpt(modules_optims, cfg.ckpt_file)

        for test_set, name in zip(test_sets, test_set_names):
            test_set.set_feat_func(ExtractFeature(model_w, TVT))
            print('\n=========> Test on dataset: {} <=========\n'.format(name))
            test_set.eval(normalize_feat=True, verbose=True)

    def validate():
        if val_set.extract_feat_func is None:
            val_set.set_feat_func(ExtractFeature(model_w, TVT))
        print('\n===== Test on validation set =====\n')
        mAP, cmc_scores, _, _ = val_set.eval(normalize_feat=True,
                                             to_re_rank=False,
                                             verbose=True)
        print()
        return mAP, cmc_scores[0]

    #############################################################################
    # Stage 1: train the PCB model (60 epochs)
    #############################################################################

    if not cfg.only_triplet:
        model = Model(last_conv_stride=cfg.last_conv_stride,
                      num_stripes=cfg.num_stripes,
                      num_cols=cfg.num_cols,
                      local_conv_out_channels=cfg.local_conv_out_channels,
                      num_classes=num_classes)

        # Model wrapper
        model_w = DataParallel(model)

        #############################
        # Criteria and Optimizers   #
        #############################

        criterion = torch.nn.CrossEntropyLoss()

        # Backbone parameters, fine-tuned from ImageNet weights
        finetuned_params = list(model.base.parameters())
        # Newly added parameters, trained from scratch
        new_params = [
            p for n, p in model.named_parameters() if not n.startswith('base.')
        ]
        param_groups = [{
            'params': finetuned_params,
            'lr': cfg.finetuned_params_lr
        }, {
            'params': new_params,
            'lr': cfg.new_params_lr
        }]
        optimizer = optim.SGD(param_groups,
                              momentum=cfg.momentum,
                              weight_decay=cfg.weight_decay)

        # Bind them together to simplify the following usage.

        modules_optims = [model, optimizer]

        ################################
        # May Resume Models and Optims #
        ################################

        if cfg.resume:
            resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

        # Optionally transfer models and optimizers to the specified device.
        # Transferring the optimizer handles the case where a checkpoint is
        # loaded onto a different device.
        TMO(modules_optims)

        if cfg.only_test:
            test(load_model_weight=True)
            return

        print('#################### Begin training the PCB model ####################')
        ############
        # Training #
        ############

        start_ep = resume_ep if cfg.resume else 0
        for ep in range(start_ep, cfg.pcb_epochs):

            # Adjust Learning Rate
            adjust_lr_staircase(optimizer.param_groups,
                                [cfg.finetuned_params_lr, cfg.new_params_lr],
                                ep + 1, cfg.staircase_decay_at_epochs,
                                cfg.staircase_decay_multiply_factor)
            # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
            may_set_mode(modules_optims, 'train')

            # For recording loss
            loss_meter = AverageMeter()

            ep_st = time.time()
            step = 0
            epoch_done = False
            while not epoch_done:

                step += 1
                step_st = time.time()

                ims, im_names, labels, mirrored, epoch_done = train_set.next_batch()

                ims_var = Variable(TVT(torch.from_numpy(ims).float()))
                labels_var = Variable(TVT(torch.from_numpy(labels).long()))

                _, logits_list = model_w(ims_var)
                loss = torch.sum(
                    torch.cat([
                        criterion(logits, labels_var) for logits in logits_list
                    ]))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                del logits_list, ims, im_names, labels

                ############
                # Step Log #
                ############

                loss_meter.update(to_scalar(loss))

                if step % cfg.steps_per_log == 0:
                    log = '\tStep {}/Ep {}, {:.2f}s, loss {:.4f}'.format(
                        step, ep + 1,
                        time.time() - step_st, loss_meter.val)
                    print(log)

            #############
            # Epoch Log #
            #############

            log = 'Ep {}, {:.2f}s, loss {:.4f}'.format(ep + 1,
                                                       time.time() - ep_st,
                                                       loss_meter.avg)
            print(log)

            ##########################
            # Test on Validation Set #
            ##########################

            mAP, Rank1 = 0, 0
            if ((ep + 1) % cfg.epochs_per_val == 0) and (val_set is not None):
                mAP, Rank1 = validate()

            # Log to TensorBoard

            if cfg.log_to_file:
                if writer is None:
                    writer = SummaryWriter(
                        log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
                writer.add_scalars('val scores', dict(mAP=mAP, Rank1=Rank1),
                                   ep)
                writer.add_scalars('loss', dict(loss=loss_meter.avg, ), ep)

            # save ckpt
            if cfg.log_to_file:
                save_ckpt(modules_optims, ep + 1, 0, cfg.pcb_ckpt_file)

        ################
        # Test the pcb #
        ################

        test(load_model_weight=False)

    #####################################################
    # Dataset: anchor/positive/negative sets for the triplet loss
    # TODO
    #####################################################
    print('start to create triplet_all dataset.....')
    train_set_anchor = create_dataset_tri(**cfg.train_set_anchor_kwargs)
    print('train_anchor_im_names:{}'.format(
        len(train_set_anchor.get_im_names())))
    train_set_positive = create_dataset_tri(**cfg.train_set_positive_kwargs)
    print('train_positive_im_names:{}'.format(
        len(train_set_positive.get_im_names())))
    train_set_negative = create_dataset_tri(**cfg.train_set_negative_kwargs)
    print('train_negative_im_names:{}'.format(
        len(train_set_negative.get_im_names())))
    num_classes = len(train_set_anchor.ids2labels)
    print('finish creating....num_classes:{}\n '.format(num_classes))

    if not cfg.only_all:

        print('Entering the triplet training stage *********************************')

        ##################################################################################
        # Stage 2: model for triplet fine-tuning (5 epochs)
        ##################################################################################
        model = Model(last_conv_stride=cfg.last_conv_stride,
                      num_stripes=cfg.num_stripes,
                      num_cols=cfg.num_cols,
                      local_conv_out_channels=cfg.local_conv_out_channels,
                      num_classes=num_classes)

        #############################
        # Criteria and Optimizers   #
        #############################

        criterion = torch.nn.CrossEntropyLoss()

        # Load the stage-1 (PCB) checkpoint if it exists
        if osp.isfile(cfg.pcb_ckpt_file):
            map_location = (lambda storage, loc: storage)
            sd = torch.load(cfg.pcb_ckpt_file, map_location=map_location)
            model_dict = model.state_dict()
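            # Copy over only the checkpoint entries whose keys exist in the
            # current model; parameters without a matching key keep their
            # fresh initialization.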
            sd_load = {
                k: v
                for k, v in (sd['state_dicts'][0]).items() if k in model_dict
            }
            model_dict.update(sd_load)
            model.load_state_dict(model_dict)

        # During the triplet stage only the new, non-backbone parameters are
        # optimized; the ResNet base and the per-stripe fc classifiers are left
        # out of the param groups and therefore receive no updates.
        new_params = [
            p for n, p in model.named_parameters()
            if not n.startswith('base.') and not n.startswith('fc_list.')
        ]
        param_groups = [{'params': new_params, 'lr': cfg.new_params_lr}]
        optimizer = optim.SGD(param_groups,
                              momentum=cfg.momentum,
                              weight_decay=cfg.weight_decay)

        # Model wrapper
        model_w = DataParallel(model)
        # Bind them together to simplify the following usage.

        modules_optims = [model, optimizer]

        TMO(modules_optims)

        ############
        # Training #
        ############
        print('#################### Begin training the triplet model ####################')
        start_ep = resume_ep if cfg.resume else 0
        triplet_st = time.time()
        for ep in range(start_ep, cfg.triplet_epochs):

            # Adjust Learning Rate
            adjust_lr_staircase(optimizer.param_groups,
                                [cfg.triplet_finetuned_params_lr], ep + 1,
                                cfg.staircase_decay_at_epochs,
                                cfg.triplet_staircase_decay_multiply_factor)
            # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
            may_set_mode(modules_optims, 'train')

            # For recording loss
            loss_meter = AverageMeter()

            ep_st = time.time()
            step = 0
            epoch_done = False
            while not epoch_done:

                step += 1
                step_st = time.time()

                ims_a, im_names_a, labels_a, mirrored_a, epoch_done = \
                    train_set_anchor.next_batch()
                ims_p, im_names_p, labels_p, mirrored_p, epoch_done = \
                    train_set_positive.next_batch()
                ims_n, im_names_n, labels_n, mirrored_n, epoch_done = \
                    train_set_negative.next_batch()

                ims_var_a = Variable(TVT(torch.from_numpy(ims_a).float()))
                ims_var_p = Variable(TVT(torch.from_numpy(ims_p).float()))
                ims_var_n = Variable(TVT(torch.from_numpy(ims_n).float()))
                labels_a_var = Variable(TVT(torch.from_numpy(labels_a).long()))
                labels_n_var = Variable(TVT(torch.from_numpy(labels_n).long()))

                local_feat_list_a, logits_list_a = model_w(ims_var_a)
                local_feat_list_p, logits_list_p = model_w(ims_var_p)
                local_feat_list_n, logits_list_n = model_w(ims_var_n)
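                # Each forward pass returns per-stripe local features (used for
                # the triplet loss below) and per-stripe classifier logits.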

                # Compute one triplet loss per stripe between the anchor,
                # positive and negative local features.
                loss_triplet = []
                for i in range(cfg.parts_num):
                    loss_triplet.append(
                        TripletMarginLoss(cfg.margin).forward(
                            local_feat_list_a[i], local_feat_list_p[i],
                            local_feat_list_n[i]))

                # Sum the per-part triplet losses and average over the parts.
                loss = torch.div(sum(loss_triplet), cfg.parts_num)
                print('loss:{}'.format(loss))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # del local_feat_list_a,local_feat_list_p,local_feat_list_n,logits_list_a,logits_list_p,logits_list_n
                ############
                # Step Log #
                ############

                loss_meter.update(to_scalar(loss))

                if step % cfg.steps_per_log == 0:
                    log = '\tStep {}/Ep {}, {:.2f}s, loss {:.4f}'.format(
                        step, ep + 1,
                        time.time() - step_st, loss_meter.val)
                    print(log)

            #############
            # Epoch Log #
            #############

            log = 'Ep {}, {:.2f}s, loss {:.4f}'.format(ep + 1,
                                                       time.time() - ep_st,
                                                       loss_meter.avg)
            print(log)

            ##############
            # RPP module #
            ##############

            ##########################
            # Test on Validation Set #
            ##########################

            mAP, Rank1 = 0, 0
            if ((ep + 1) % cfg.epochs_per_val == 0) and (val_set is not None):
                mAP, Rank1 = validate()

            # Log to TensorBoard

            if cfg.log_to_file:
                if writer is None:
                    writer = SummaryWriter(
                        log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
                writer.add_scalars('val scores', dict(mAP=mAP, Rank1=Rank1),
                                   cfg.pcb_epochs + ep)
                writer.add_scalars('loss', dict(loss=loss_meter.avg, ),
                                   ep + cfg.pcb_epochs)

            # save ckpt
            if cfg.log_to_file:
                save_ckpt(modules_optims, ep + cfg.pcb_epochs + 1, 0,
                          cfg.triplet_ckpt_file)

        triplet_time = time.time() - triplet_st

        ########
        # Test #
        ########

        test(load_model_weight=False)

    print('Entering the final (all) training stage ********************************')
    ##########################################################################
    # Stage 3: fine-tune the whole model (5 epochs)
    ##########################################################################
    model = Model(last_conv_stride=cfg.last_conv_stride,
                  num_stripes=cfg.num_stripes,
                  num_cols=cfg.num_cols,
                  local_conv_out_channels=cfg.local_conv_out_channels,
                  num_classes=num_classes)

    #####################################################
    # Dataset: train set for the final (all) stage
    #####################################################
    print('start to create all dataset.....')
    all_train_set = create_dataset(**cfg.all_train_set_kwargs)
    print('train_set shape:{}'.format(len(all_train_set.im_names)))
    num_classes = len(train_set.ids2labels)
    #############################
    # Criteria and Optimizers   #
    #############################

    criterion = torch.nn.CrossEntropyLoss()

    # Load the stage-2 (triplet) checkpoint if it exists
    if osp.isfile(cfg.triplet_ckpt_file):
        map_location = (lambda storage, loc: storage)
        sd = torch.load(cfg.triplet_ckpt_file, map_location=map_location)
        model_dict = model.state_dict()
        sd_load = {
            k: v
            for k, v in (sd['state_dicts'][0]).items() if k in model_dict
        }
        model_dict.update(sd_load)
        model.load_state_dict(model_dict)

    # Backbone parameters, fine-tuned from the previous stage's weights
    finetuned_params = list(model.base.parameters())
    # Remaining parameters, excluding the backbone and the local conv layers
    new_params = [
        p for n, p in model.named_parameters()
        if not n.startswith('base.') and not n.startswith('local_conv_list.')
    ]
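    # The backbone is fine-tuned at one tenth of the learning rate used for the
    # remaining parameters in this final stage.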
    param_groups = [{
        'params': finetuned_params,
        'lr': cfg.finetuned_params_lr * 0.1
    }, {
        'params': new_params,
        'lr': cfg.finetuned_params_lr
    }]
    optimizer = optim.SGD(param_groups,
                          momentum=cfg.momentum,
                          weight_decay=cfg.weight_decay)

    # Model wrapper
    model_w = DataParallel(model)
    # Bind them together to simplify the following usage.

    modules_optims = [model, optimizer]

    TMO(modules_optims)

    ############
    # Training #
    ############
    print('#################### Begin training the whole model ####################')
    start_ep = resume_ep if cfg.resume else 0
    for ep in range(start_ep, cfg.total_epochs):

        # Adjust Learning Rate
        adjust_lr_staircase(
            optimizer.param_groups,
            [cfg.all_base_finetuned_params_lr, cfg.all_new_finetuned_params_lr],
            ep + 1,
            cfg.all_staircase_decay_at_epochs,
            cfg.staircase_decay_multiply_factor)
        # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
        may_set_mode(modules_optims, 'train')

        # For recording loss
        loss_meter = AverageMeter()

        ep_st = time.time()
        step = 0
        epoch_done = False
        while not epoch_done:

            step += 1
            step_st = time.time()

            ims, im_names, labels, mirrored, epoch_done = all_train_set.next_batch()

            ims_var = Variable(TVT(torch.from_numpy(ims).float()))
            labels_var = Variable(TVT(torch.from_numpy(labels).long()))

            _, logits_list = model_w(ims_var)
            loss = torch.sum(
                torch.cat(
                    [criterion(logits, labels_var) for logits in logits_list]))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ############
            # Step Log #
            ############

            loss_meter.update(to_scalar(loss))

            if step % cfg.steps_per_log == 0:
                log = '\tStep {}/Ep {}, {:.2f}s, loss {:.4f}'.format(
                    step, ep + 1,
                    time.time() - step_st, loss_meter.val)
                print(log)

        #############
        # Epoch Log #
        #############

        log = 'Ep {}, {:.2f}s, loss {:.4f}'.format(ep + 1,
                                                   time.time() - ep_st,
                                                   loss_meter.avg)
        print(log)

        ##############
        # RPP module #
        ##############

        ##########################
        # Test on Validation Set #
        ##########################

        mAP, Rank1 = 0, 0
        if ((ep + 1) % cfg.epochs_per_val == 0) and (val_set is not None):
            mAP, Rank1 = validate()

        # Log to TensorBoard

        if cfg.log_to_file:
            if writer is None:
                writer = SummaryWriter(
                    log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
            writer.add_scalars('val scores', dict(mAP=mAP, Rank1=Rank1),
                               cfg.pcb_epochs + cfg.triplet_epochs + ep)
            writer.add_scalars('loss', dict(loss=loss_meter.avg, ),
                               ep + cfg.pcb_epochs + cfg.triplet_epochs)

        # save ckpt
        if cfg.log_to_file:
            save_ckpt(modules_optims,
                      ep + cfg.pcb_epochs + 1 + cfg.triplet_epochs, 0,
                      cfg.ckpt_file)

    ########
    # Test #
    ########

    test(load_model_weight=False)
    print('Training finished **************************************************')
    print('{} spends {} s'.format(cfg.triplet_dataset, triplet_time))
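
For reference, the triplet stage above averages a margin-based triplet loss over the per-stripe local features produced by the model. Below is a small self-contained sketch of that computation, using torch.nn.TripletMarginLoss in place of the repo's TripletMarginLoss class, with made-up sizes (6 stripes, batch of 4, 256-dim features) and an arbitrary margin.

import torch
import torch.nn as nn

def part_averaged_triplet_loss(feats_a, feats_p, feats_n, margin=0.3):
    """Average a triplet loss over lists of per-part (per-stripe) features."""
    triplet = nn.TripletMarginLoss(margin=margin)
    losses = [triplet(a, p, n) for a, p, n in zip(feats_a, feats_p, feats_n)]
    return sum(losses) / len(losses)

parts, batch, dim = 6, 4, 256
anchor = [torch.randn(batch, dim) for _ in range(parts)]
positive = [torch.randn(batch, dim) for _ in range(parts)]
negative = [torch.randn(batch, dim) for _ in range(parts)]
print(part_averaged_triplet_loss(anchor, positive, negative))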