Example #1
def main():
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    # Lazily create SummaryWriter
    writer = None

    TVT, TMO = set_devices(cfg.sys_device_ids)
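    # set_devices returns two transfer helpers: TVT moves variables/tensors to
    # the chosen device(s), and TMO does the same for modules and optimizers.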

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    ###########
    # Dataset #
    ###########

    if not cfg.only_test:
        train_set = create_dataset(**cfg.train_set_kwargs)
        # The combined dataset currently provides no validation set.
        val_set = None if cfg.dataset == 'combined' else create_dataset(
            **cfg.val_set_kwargs)

    test_sets = []
    test_set_names = []
    if cfg.dataset == 'combined':
        for name in ['market1501', 'cuhk03', 'duke']:
            cfg.test_set_kwargs['name'] = name
            test_sets.append(create_dataset(**cfg.test_set_kwargs))
            test_set_names.append(name)
    else:
        test_sets.append(create_dataset(**cfg.test_set_kwargs))
        test_set_names.append(cfg.dataset)

    ###########
    # Models  #
    ###########

    model = Model(last_conv_stride=cfg.last_conv_stride)
    # Model wrapper
    model_w = DataParallel(model)

    #############################
    # Criteria and Optimizers   #
    #############################

    tri_loss = TripletLoss(margin=cfg.margin)

    optimizer = optim.Adam(model.parameters(),
                           lr=cfg.base_lr,
                           weight_decay=cfg.weight_decay)

    # Bind them together to simplify the code that follows.
    modules_optims = [model, optimizer]

    ################################
    # May Resume Models and Optims #
    ################################

    if cfg.resume:
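        # NOTE: resumes from the epoch-200 checkpoint specifically; checkpoints
        # are written every 50 epochs below via cfg.ckpt_file.format(ep + 1).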
        resume_ep, scores = load_ckpt(modules_optims,
                                      cfg.ckpt_file.format(200))

    # Transfer models and optimizers to the specified device. The optimizer is
    # transferred as well to handle resuming a checkpoint on a different device.
    TMO(modules_optims)

    ########
    # Test #
    ########

    def test(load_model_weight=False):
        if load_model_weight:
            if cfg.model_weight_file != '':
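                # torch.load maps the saved tensors to CPU first;
                # load_state_dict then copies them into the (possibly
                # GPU-resident) model parameters.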
                map_location = (lambda storage, loc: storage)
                sd = torch.load(cfg.model_weight_file,
                                map_location=map_location)
                load_state_dict(model, sd)
                print('Loaded model weights from {}'.format(
                    cfg.model_weight_file))
            else:
                load_ckpt(modules_optims, cfg.ckpt_file)

        for test_set, name in zip(test_sets, test_set_names):
            test_set.set_feat_func(ExtractFeature(model_w, TVT))
            print('\n=========> Test on dataset: {} <=========\n'.format(name))
            test_set.eval(normalize_feat=cfg.normalize_feature, verbose=True)

    def validate():
        if val_set.extract_feat_func is None:
            val_set.set_feat_func(ExtractFeature(model_w, TVT))
        print('\n=========> Test on validation set <=========\n')
        mAP, cmc_scores, _, _ = val_set.eval(
            normalize_feat=cfg.normalize_feature,
            to_re_rank=False,
            verbose=False)
        print()
        return mAP, cmc_scores[0]

    if cfg.only_test:
        test(load_model_weight=True)
        return

    ############
    # Training #
    ############

    start_ep = resume_ep if cfg.resume else 0
    for ep in range(start_ep, cfg.total_epochs):

        # Adjust Learning Rate
        if cfg.lr_decay_type == 'exp':
            adjust_lr_exp(optimizer, cfg.base_lr, ep + 1, cfg.total_epochs,
                          cfg.exp_decay_at_epoch)
        else:
            adjust_lr_staircase(optimizer, cfg.base_lr, ep + 1,
                                cfg.staircase_decay_at_epochs,
                                cfg.staircase_decay_multiply_factor)

        may_set_mode(modules_optims, 'train')

        # Meters for recording precision, the fraction of triplets satisfying
        # the margin, the mined distances, and the loss.
        prec_meter = AverageMeter()
        sm_meter = AverageMeter()
        dist_ap_meter = AverageMeter()
        dist_an_meter = AverageMeter()
        dist_rec_meter = AverageMeter()
        loss_meter = AverageMeter()

        ep_st = time.time()
        step = 0
        epoch_done = False
        while not epoch_done:

            step += 1
            step_st = time.time()

            (ims, masks, im_names, labels, mirrored,
             epoch_done) = train_set.next_batch()

            ims_var = Variable(TVT(torch.from_numpy(ims).float()))
            masks_var = Variable(TVT(torch.from_numpy(masks).float()))
            labels_t = TVT(torch.from_numpy(labels).long())

            feat = model_w(ims_var, masks_var)

            # global_loss performs batch-hard triplet mining: for each anchor
            # it picks the hardest positive and negative in the batch and
            # returns the loss plus the mined indices/distances for logging.
            loss, p_inds, n_inds, rec_inds, dist_ap, dist_an, dist_rec, dist_mat = global_loss(
                tri_loss,
                feat,
                labels_t,
                normalize_feature=cfg.normalize_feature)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ############
            # Step Log #
            ############

            # precision
            prec = (dist_an > dist_ap).data.float().mean()
            # the proportion of triplets that satisfy margin
            sm = (dist_an > dist_ap + cfg.margin).data.float().mean()
            # average (anchor, positive) distance
            d_ap = dist_ap.data.mean()
            # average (anchor, negative) distance
            d_an = dist_an.data.mean()
            # average rec distance
            d_rec = dist_rec.data.mean()

            prec_meter.update(prec)
            sm_meter.update(sm)
            dist_ap_meter.update(d_ap)
            dist_an_meter.update(d_an)
            dist_rec_meter.update(d_rec)
            loss_meter.update(to_scalar(loss))

            if step % cfg.steps_per_log == 0:
                time_log = '\tStep {}/Ep {}, {:.2f}s'.format(
                    step,
                    ep + 1,
                    time.time() - step_st,
                )

                tri_log = (', prec {:.2%}, sm {:.2%}, '
                           'd_ap {:.4f}, d_an {:.4f}, d_rec {:.4f}, '
                           'loss {:.4f}'.format(
                               prec_meter.val,
                               sm_meter.val,
                               dist_ap_meter.val,
                               dist_an_meter.val,
                               dist_rec_meter.val,
                               loss_meter.val,
                           ))

                log = time_log + tri_log
                print(log)

        #############
        # Epoch Log #
        #############

        time_log = 'Ep {}, {:.2f}s'.format(ep + 1, time.time() - ep_st)

        tri_log = (', prec {:.2%}, sm {:.2%}, '
                   'd_ap {:.4f}, d_an {:.4f}, d_rec {:.4f}, '
                   'loss {:.4f}'.format(
                       prec_meter.avg,
                       sm_meter.avg,
                       dist_ap_meter.avg,
                       dist_an_meter.avg,
                       dist_rec_meter.avg,
                       loss_meter.avg,
                   ))

        log = time_log + tri_log
        print(log)

        ##########################
        # Test on Validation Set #
        ##########################

        mAP, Rank1 = 0, 0
        if ((ep + 1) % cfg.epochs_per_val == 0) and (val_set is not None):
            mAP, Rank1 = validate()

        # Save a checkpoint every 50 epochs.
        if cfg.log_to_file and ((ep + 1) % 50 == 0):
            print('Saving model for epoch {}'.format(ep + 1))
            save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file.format(ep + 1))

        # Log to TensorBoard

        if cfg.log_to_file:
            if writer is None:
                writer = SummaryWriter(
                    log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
            writer.add_scalars('val scores', dict(mAP=mAP, Rank1=Rank1), ep)
            writer.add_scalars('loss', dict(loss=loss_meter.avg, ), ep)
            writer.add_scalars('precision', dict(precision=prec_meter.avg, ),
                               ep)
            writer.add_scalars('satisfy_margin',
                               dict(satisfy_margin=sm_meter.avg, ), ep)
            writer.add_scalars(
                'average_distance',
                dict(
                    dist_ap=dist_ap_meter.avg,
                    dist_an=dist_an_meter.avg,
                ), ep)

    ########
    # Test #
    ########

    test(load_model_weight=False)
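
Both listings rely on helpers that are not shown here (AverageMeter, to_scalar, set_devices, the adjust_lr_* schedulers). A minimal sketch of AverageMeter, written to match how the listings use it (.update(v), then .val for the latest value and .avg for the running mean), could look like the following; the actual helper in the source repository may differ:

class AverageMeter(object):
    """Tracks the latest value and the running average of a metric."""

    def __init__(self):
        self.val = 0.0   # most recent value passed to update()
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0   # running mean over all updates

    def update(self, val, n=1):
        # n allows weighting an update by the number of samples it covers.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count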
Example #2
def main():
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    # Lazily create SummaryWriter
    writer = None

    TVT, TMO = set_devices(cfg.sys_device_ids)

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    ###########
    # Dataset #
    ###########

    if not cfg.only_test:
        train_set = create_dataset(**cfg.train_set_kwargs)
        # The combined dataset currently provides no validation set.
        val_set = None if cfg.dataset == 'combined' else create_dataset(
            **cfg.val_set_kwargs)

    test_sets = []
    test_set_names = []
    if cfg.dataset == 'combined':
        for name in ['market1501', 'cuhk03', 'duke']:
            cfg.test_set_kwargs['name'] = name
            test_sets.append(create_dataset(**cfg.test_set_kwargs))
            test_set_names.append(name)
    else:
        test_sets.append(create_dataset(**cfg.test_set_kwargs))
        test_set_names.append(cfg.dataset)

    ###########
    # Models  #
    ###########

    # Unlike Example #1, this variant uses an SCP model whose classification
    # head is sized to the number of training identities.
    if cfg.dataset == 'market1501':
        nr_class = 751
    elif cfg.dataset == 'duke':
        nr_class = 702
    elif cfg.dataset == 'cuhk03':
        nr_class = 767
    elif cfg.dataset == 'combined':
        nr_class = 2220
    else:
        raise ValueError('Unsupported dataset: {}'.format(cfg.dataset))
    model = get_scp_model(nr_class)

    # load pretrained ImageNet weights
    model_dict = model.state_dict()  # original
    pretrained_dict = torch.load('models/resnet50-19c8e357.pth')  # pretrained
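    # Keep only parameters that also exist in the SCP model with matching
    # shapes; layers absent from ResNet-50 (e.g. the classifier head) keep
    # their random initialization.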
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    # Model wrapper
    model_w = DataParallel(model)

    #############################
    # Criteria and Optimizers   #
    #############################

    criterion_cls = torch.nn.CrossEntropyLoss()
    criterion_feature = torch.nn.MSELoss()

    optimizer = optim.Adam(model.parameters(),
                           lr=cfg.base_lr,
                           weight_decay=cfg.weight_decay)

    # Bind them together to simplify the code that follows.
    modules_optims = [model, optimizer]

    ################################
    # May Resume Models and Optims #
    ################################

    if cfg.resume:
        resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

    # Transfer models and optimizers to the specified device. The optimizer is
    # transferred as well to handle resuming a checkpoint on a different device.
    TMO(modules_optims)

    ########
    # Test #
    ########

    def test(load_model_weight=False):
        if load_model_weight:
            if cfg.model_weight_file != '':
                map_location = (lambda storage, loc: storage)
                sd = torch.load(cfg.model_weight_file,
                                map_location=map_location)
                load_state_dict(model, sd)
                print('Loaded model weights from {}'.format(
                    cfg.model_weight_file))
            else:
                load_ckpt(modules_optims, cfg.ckpt_file)

        # Evaluate every test set; return the scores of the last one evaluated.
        # (The original returned inside the loop, so only the first test set
        # was ever evaluated.)
        mAP = cmc_scores = mq_mAP = mq_cmc_scores = None
        for test_set, name in zip(test_sets, test_set_names):
            test_set.set_feat_func(ExtractFeature(model_w, TVT))
            print('\n=========> Test on dataset: {} <=========\n'.format(name))
            mAP, cmc_scores, mq_mAP, mq_cmc_scores = test_set.eval(
                normalize_feat=cfg.normalize_feature, verbose=True)
        return mAP, cmc_scores, mq_mAP, mq_cmc_scores

    def validate():
        if val_set.extract_feat_func is None:
            val_set.set_feat_func(ExtractFeature(model_w, TVT))
        print('\n=========> Test on validation set <=========\n')
        mAP, cmc_scores, _, _ = val_set.eval(
            normalize_feat=cfg.normalize_feature,
            to_re_rank=False,
            verbose=False)
        print()
        return mAP, cmc_scores[0]

    if cfg.only_test:
        mAP, cmc_scores, mq_mAP, mq_cmc_scores = test(load_model_weight=True)
        return

    ############
    # Training #
    ############

    start_ep = resume_ep if cfg.resume else 0
    for ep in range(start_ep, cfg.total_epochs):

        # Adjust learning rate: linear warm-up to 1e-3 over the first 20
        # epochs, then staircase decay to 1e-4 after epoch 80 and to 1e-5
        # after epoch 180 (replacing the adjust_lr_* helpers of Example #1).
        if ep < 20:
            lr = 1e-4 * (ep + 1) / 2
        elif ep < 80:
            lr = 1e-3
        elif ep <= 180:
            lr = 1e-4
        else:
            lr = 1e-5
        for g in optimizer.param_groups:
            g['lr'] = lr

        may_set_mode(modules_optims, 'train')

        # Meters for recording precision and loss. scp_loss does not expose
        # the margin/distance statistics that Example #1 logs.
        prec_meter = AverageMeter()
        loss_meter = AverageMeter()

        ep_st = time.time()
        step = 0
        epoch_done = False
        while not epoch_done:

            step += 1
            step_st = time.time()

            (ims, im_names, labels, mirrored,
             epoch_done) = train_set.next_batch()

            ims_var = Variable(TVT(torch.from_numpy(ims).float()))
            labels_t = Variable(TVT(torch.from_numpy(labels).long()))

            feat = model_w(ims_var)

            loss, prec = scp_loss(feat, labels_t, criterion_cls,
                                  criterion_feature, cfg.ids_per_batch,
                                  cfg.ims_per_id)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ############
            # Step Log #
            ############

            prec_meter.update(prec)
            loss_meter.update(to_scalar(loss))

            if step % cfg.steps_per_log == 0:
                time_log = '\tStep {}/Ep {}, {:.2f}s'.format(
                    step,
                    ep + 1,
                    time.time() - step_st,
                )

                log = time_log + ', prec {:.2%}, loss {:.4f}'.format(
                    prec_meter.val, loss_meter.val)
                print(log)

        #############
        # Epoch Log #
        #############

        time_log = 'Ep {}, {:.2f}s'.format(ep + 1, time.time() - ep_st)

        log = time_log + ', prec {:.2%}, loss {:.4f}'.format(
            prec_meter.avg, loss_meter.avg)
        print(log)

        ##########################
        # Test on Validation Set #
        ##########################

        mAP, Rank1 = 0, 0
        if ((ep + 1) % cfg.epochs_per_val == 0) and (val_set is not None):
            mAP, Rank1 = validate()

        # Log to TensorBoard

        if cfg.log_to_file:
            if writer is None:
                writer = SummaryWriter(
                    log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
            writer.add_scalars('val scores', dict(mAP=mAP, Rank1=Rank1), ep)
            writer.add_scalars('loss', dict(loss=loss_meter.avg, ), ep)
            writer.add_scalars('precision', dict(precision=prec_meter.avg, ),
                               ep)

        # Save a checkpoint (overwritten every epoch).
        if cfg.log_to_file:
            save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

    ########
    # Test #
    ########

    test(load_model_weight=False)
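
The hand-rolled schedule in Example #2 (linear warm-up followed by staircase decay) is easy to factor out and spot-check. Below is a small sketch using the same breakpoints as the loop above; the function name is mine, not from the source:

def warmup_staircase_lr(ep):
    """Learning rate for 0-based epoch ep: linear warm-up from 5e-5 to
    1e-3 over epochs 0-19, then 1e-3 until epoch 79, 1e-4 for epochs
    80-180, and 1e-5 afterwards."""
    if ep < 20:
        return 1e-4 * (ep + 1) / 2
    elif ep < 80:
        return 1e-3
    elif ep <= 180:
        return 1e-4
    return 1e-5


# Spot-check the schedule boundaries.
assert warmup_staircase_lr(0) == 5e-5
assert abs(warmup_staircase_lr(19) - 1e-3) < 1e-12  # warm-up reaches plateau
assert warmup_staircase_lr(80) == 1e-4
assert warmup_staircase_lr(181) == 1e-5

# Inside the training loop it would replace the inline if/elif chain:
#     for g in optimizer.param_groups:
#         g['lr'] = warmup_staircase_lr(ep)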