示例#1
0
def start():
    """Build the image preprocessor and feature extractor for inference.

    Reads module-level settings (``last_conv_stride``, ``num_stripes``,
    ``local_conv_out_channels``, ``resize_h_w``, ``scale_im``, ``im_mean``,
    ``im_std``, ``model_weight_file``) and the device helpers ``TMO`` / ``TVT``,
    all of which must be defined elsewhere in the module.

    Returns:
        A ``(preprocessor, extractor)`` pair: a ``PreProcessIm`` instance and
        an ``ExtractFeature`` wrapping the ``DataParallel`` model.
    """
    # ---- network ----
    net = Model(
        last_conv_stride=last_conv_stride,
        num_stripes=num_stripes,
        local_conv_out_channels=local_conv_out_channels,
    )
    net_wrapped = DataParallel(net)
    # Move the bare module to the configured device(s).
    TMO([net])

    # ---- input preprocessing ----
    preprocess = PreProcessIm(
        resize_h_w=resize_h_w,
        scale=scale_im,
        im_mean=im_mean,
        im_std=im_std,
    )

    # ---- weights ----
    # Keep deserialized storages on CPU, whatever device they were saved from.
    def _keep_on_cpu(storage, loc):
        return storage

    checkpoint = torch.load(model_weight_file, map_location=_keep_on_cpu)
    # The checkpoint holds a list of state dicts; the model's is at index 0.
    load_state_dict(net, checkpoint['state_dicts'][0])
    print('Loaded model weight from {}'.format(model_weight_file))

    return preprocess, ExtractFeature(net_wrapped, TVT)
示例#2
0
    def _initModel(self):
        """Create the ReID model, loss and optimizer for the current dataset.

        Reads ``self.dataset_nm``, ``self.num_classes_dict`` and ``self.cfg``.
        Stores the results on ``self.model_w`` (DataParallel wrapper),
        ``self.modules_optims`` ([model, optimizer]), ``self.criterion``
        (cross-entropy loss) and ``self.TVT`` (device-transfer helper).
        """
        n_cls = self.num_classes_dict[self.dataset_nm]
        cfg = self.cfg

        TVT, TMO = set_devices(cfg.sys_device_ids)
        # Seed before building the model so weight init is reproducible.
        if cfg.seed is not None:
            set_seed(cfg.seed)

        net = Model(
            last_conv_stride=cfg.last_conv_stride,
            num_stripes=cfg.num_stripes,
            local_conv_out_channels=cfg.local_conv_out_channels,
            num_classes=n_cls,
        )
        net_wrapped = DataParallel(net)

        loss_fn = torch.nn.CrossEntropyLoss()

        # Two learning-rate groups: the backbone ('base.*', finetuned from
        # ImageNet weights) and every other parameter (trained from scratch).
        backbone = list(net.base.parameters())
        fresh = []
        for pname, param in net.named_parameters():
            if not pname.startswith('base.'):
                fresh.append(param)
        sgd = optim.SGD(
            [
                dict(params=backbone, lr=cfg.finetuned_params_lr),
                dict(params=fresh, lr=cfg.new_params_lr),
            ],
            momentum=cfg.momentum,
            weight_decay=cfg.weight_decay,
        )

        bundle = [net, sgd]
        # The optimizer is transferred too, so state restored from a
        # checkpoint ends up on the specified device.
        TMO(bundle)

        self.model_w = net_wrapped
        self.modules_optims = bundle
        self.criterion = loss_fn
        self.TVT = TVT
示例#3
0
def main():
    """Train the ReID model, validate periodically, then test.

    All settings come from ``Config()``. If ``cfg.only_test`` is set, the
    function loads weights, evaluates on the test set(s) and returns without
    training. When ``cfg.log_to_file`` is set, it also redirects stdout and
    stderr to files, writes TensorBoard scalars under
    ``<cfg.exp_dir>/tensorboard`` and saves a checkpoint after every epoch.
    """
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    # Lazily create SummaryWriter
    writer = None

    TVT, TMO = set_devices(cfg.sys_device_ids)

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    ###########
    # Dataset #
    ###########

    train_set = create_dataset(**cfg.train_set_kwargs)
    num_classes = len(train_set.ids2labels)
    # The combined dataset does not provide val set currently.
    val_set = None if cfg.dataset == 'combined' else create_dataset(
        **cfg.val_set_kwargs)

    test_sets = []
    test_set_names = []
    if cfg.dataset == 'combined':
        for name in ['market1501', 'cuhk03', 'duke']:
            # Use a per-dataset copy rather than overwriting
            # cfg.test_set_kwargs['name'], so the config is left unmodified.
            test_kwargs = dict(cfg.test_set_kwargs, name=name)
            test_sets.append(create_dataset(**test_kwargs))
            test_set_names.append(name)
    else:
        test_sets.append(create_dataset(**cfg.test_set_kwargs))
        test_set_names.append(cfg.dataset)

    ###########
    # Models  #
    ###########

    model = Model(last_conv_stride=cfg.last_conv_stride,
                  num_stripes=cfg.num_stripes,
                  local_conv_out_channels=cfg.local_conv_out_channels,
                  num_classes=num_classes)
    # Model wrapper
    model_w = DataParallel(model)

    #############################
    # Criteria and Optimizers   #
    #############################

    criterion = torch.nn.CrossEntropyLoss()

    # To finetune from ImageNet weights
    finetuned_params = list(model.base.parameters())
    # To train from scratch
    new_params = [
        p for n, p in model.named_parameters() if not n.startswith('base.')
    ]
    param_groups = [{
        'params': finetuned_params,
        'lr': cfg.finetuned_params_lr
    }, {
        'params': new_params,
        'lr': cfg.new_params_lr
    }]
    optimizer = optim.SGD(param_groups,
                          momentum=cfg.momentum,
                          weight_decay=cfg.weight_decay)

    # Bind them together just to save some codes in the following usage.
    modules_optims = [model, optimizer]

    ################################
    # May Resume Models and Optims #
    ################################

    if cfg.resume:
        resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

    # May Transfer Models and Optims to Specified Device. Transferring optimizer
    # is to cope with the case when you load the checkpoint to a new device.
    TMO(modules_optims)

    ########
    # Test #
    ########

    def test(load_model_weight=False):
        """Evaluate on every test set, optionally reloading weights first.

        With ``load_model_weight=True``, ``cfg.model_weight_file`` (a bare
        state dict) is loaded if it is non-empty; otherwise the checkpoint
        ``cfg.ckpt_file`` is loaded via ``load_ckpt``.
        """
        if load_model_weight:
            if cfg.model_weight_file != '':
                map_location = (lambda storage, loc: storage)
                sd = torch.load(cfg.model_weight_file,
                                map_location=map_location)
                load_state_dict(model, sd)
                print('Loaded model weights from {}'.format(
                    cfg.model_weight_file))
            else:
                load_ckpt(modules_optims, cfg.ckpt_file)

        for test_set, name in zip(test_sets, test_set_names):
            test_set.set_feat_func(ExtractFeature(model_w, TVT))
            print('\n=========> Test on dataset: {} <=========\n'.format(name))
            test_set.eval(normalize_feat=True, verbose=True)

    def validate():
        """Evaluate on the validation set; return ``(mAP, rank-1 score)``."""
        if val_set.extract_feat_func is None:
            val_set.set_feat_func(ExtractFeature(model_w, TVT))
        print('\n===== Test on validation set =====\n')
        mAP, cmc_scores, _, _ = val_set.eval(normalize_feat=True,
                                             to_re_rank=False,
                                             verbose=True)
        print()
        return mAP, cmc_scores[0]

    if cfg.only_test:
        test(load_model_weight=True)
        return

    ############
    # Training #
    ############

    start_ep = resume_ep if cfg.resume else 0
    for ep in range(start_ep, cfg.total_epochs):

        # Adjust Learning Rate
        adjust_lr_staircase(optimizer.param_groups,
                            [cfg.finetuned_params_lr, cfg.new_params_lr],
                            ep + 1, cfg.staircase_decay_at_epochs,
                            cfg.staircase_decay_multiply_factor)

        may_set_mode(modules_optims, 'train')

        # For recording loss
        loss_meter = AverageMeter()

        ep_st = time.time()
        step = 0
        epoch_done = False
        while not epoch_done:

            step += 1
            step_st = time.time()

            ims, im_names, labels, mirrored, epoch_done = train_set.next_batch()

            ims_var = Variable(TVT(torch.from_numpy(ims).float()))
            labels_var = Variable(TVT(torch.from_numpy(labels).long()))

            _, logits_list = model_w(ims_var)
            # One cross-entropy term per entry of logits_list (presumably one
            # per stripe classifier), summed. torch.stack is used instead of
            # torch.cat: since PyTorch 0.4 each loss is a 0-dim tensor, and
            # torch.cat rejects 0-dim inputs.
            loss = torch.sum(
                torch.stack(
                    [criterion(logits, labels_var) for logits in logits_list]))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ############
            # Step Log #
            ############

            loss_meter.update(to_scalar(loss))

            if step % cfg.steps_per_log == 0:
                log = '\tStep {}/Ep {}, {:.2f}s, loss {:.4f}'.format(
                    step, ep + 1,
                    time.time() - step_st, loss_meter.val)
                print(log)

        #############
        # Epoch Log #
        #############

        log = 'Ep {}, {:.2f}s, loss {:.4f}'.format(ep + 1,
                                                   time.time() - ep_st,
                                                   loss_meter.avg)
        print(log)

        ##########################
        # Test on Validation Set #
        ##########################

        val_scores = None
        if ((ep + 1) % cfg.epochs_per_val == 0) and (val_set is not None):
            mAP, Rank1 = validate()
            val_scores = dict(mAP=mAP, Rank1=Rank1)

        # Log to TensorBoard

        if cfg.log_to_file:
            if writer is None:
                writer = SummaryWriter(
                    log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
            # Record val scores only for epochs that were actually validated;
            # writing placeholder zeros would distort the curves.
            if val_scores is not None:
                writer.add_scalars('val scores', val_scores, ep)
            writer.add_scalars('loss', dict(loss=loss_meter.avg, ), ep)

        # save ckpt
        if cfg.log_to_file:
            save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

    ########
    # Test #
    ########

    test(load_model_weight=False)
def main():
    """Extract test-set features and save rank-list visualizations.

    Weights come from ``cfg.model_weight_file`` (loaded as a bare state dict)
    when it is set (truthy). Otherwise they come from ``cfg.ckpt_file``, a
    checkpoint whose ``'state_dicts'[0]`` entry holds the model weights. A
    fixed, seeded subset of ``cfg.num_queries`` query images is ranked against
    the gallery. Each rank list is saved under ``<cfg.exp_dir>/rank_lists/``.
    """
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    TVT, TMO = set_devices(cfg.sys_device_ids)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    ###########
    # Dataset #
    ###########

    test_set = create_dataset(**cfg.test_set_kwargs)

    #########
    # Model #
    #########

    model = Model(last_conv_stride=cfg.last_conv_stride,
                  num_stripes=cfg.num_stripes,
                  local_conv_out_channels=cfg.local_conv_out_channels,
                  num_classes=0)
    # Model wrapper
    model_w = DataParallel(model)

    # May Transfer Model to Specified Device.
    TMO([model])

    #####################
    # Load Model Weight #
    #####################

    # To first load weights to CPU
    map_location = (lambda storage, loc: storage)
    # One truthiness test drives both the file choice and the unwrap step.
    # Previously the file was picked with `or` but the unwrap used `== ''`,
    # so a None model_weight_file loaded the checkpoint without unwrapping it.
    use_weight_file = bool(cfg.model_weight_file)
    used_file = cfg.model_weight_file if use_weight_file else cfg.ckpt_file
    loaded = torch.load(used_file, map_location=map_location)
    if not use_weight_file:
        # In a checkpoint, the model's state dict is at index 0.
        loaded = loaded['state_dicts'][0]
    load_state_dict(model, loaded)
    print('Loaded model weights from {}'.format(used_file))

    ###################
    # Extract Feature #
    ###################

    test_set.set_feat_func(ExtractFeature(model_w, TVT))

    with measure_time('Extracting feature...', verbose=True):
        feat, ids, cams, im_names, marks = test_set.extract_feat(True,
                                                                 verbose=True)

    #######################
    # Select Query Images #
    #######################

    # Fix some query images, so that the visualization for different models can
    # be compared.

    # Sort in the order of image names
    inds = np.argsort(im_names)
    feat, ids, cams, im_names, marks = (
        feat[inds], ids[inds], cams[inds], im_names[inds], marks[inds])

    # query, gallery index mask
    is_q = marks == 0
    is_g = marks == 1

    # Fixed seed => the same queries are picked on every run.
    prng = np.random.RandomState(1)
    # selected query indices
    sel_q_inds = prng.permutation(range(np.sum(is_q)))[:cfg.num_queries]

    q_ids = ids[is_q][sel_q_inds]
    q_cams = cams[is_q][sel_q_inds]
    q_feat = feat[is_q][sel_q_inds]
    q_im_names = im_names[is_q][sel_q_inds]

    ####################
    # Compute Distance #
    ####################

    # query-gallery distance
    q_g_dist = compute_dist(q_feat, feat[is_g], type='euclidean')

    ###########################
    # Save Rank List as Image #
    ###########################

    q_im_paths = [ospj(test_set.im_dir, n) for n in q_im_names]
    save_paths = [ospj(cfg.exp_dir, 'rank_lists', n) for n in q_im_names]
    g_im_paths = [ospj(test_set.im_dir, n) for n in im_names[is_g]]

    for dist_vec, q_id, q_cam, q_im_path, save_path in zip(
            q_g_dist, q_ids, q_cams, q_im_paths, save_paths):

        rank_list, same_id = get_rank_list(dist_vec, q_id, q_cam, ids[is_g],
                                           cams[is_g], cfg.rank_list_size)

        save_rank_list_to_im(rank_list, same_id, q_im_path, g_im_paths,
                             save_path)