Пример #1
0
def train(cfg, cfg_hr):
    """Assemble the training pipeline from ``cfg``/``cfg_hr`` and launch it.

    The model built by ``build_model`` is wrapped in ``nn.DataParallel``.
    ``cfg.MODEL.IF_WITH_CENTER`` picks the trainer: 'no' runs ``do_train``,
    'yes' runs ``do_train_with_center``; any other value only prints a
    message and returns.
    """
    train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
    model = nn.DataParallel(build_model(cfg, cfg_hr, num_classes))

    center_mode = cfg.MODEL.IF_WITH_CENTER
    if center_mode != 'no' and center_mode != 'yes':
        print("Unsupported value for cfg.MODEL.IF_WITH_CENTER {}, only support yes or no!\n".format(center_mode))
        return

    def warmup_scheduler(opt):
        # Both trainers use the same warmup multi-step schedule settings.
        solver = cfg.SOLVER
        return WarmupMultiStepLR(opt, solver.STEPS, solver.GAMMA,
                                 solver.WARMUP_FACTOR, solver.WARMUP_ITERS,
                                 solver.WARMUP_METHOD)

    if center_mode == 'yes':
        print('Train with center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE)
        loss_func, center_criterion = make_loss_with_center(cfg, num_classes)
        optimizer, optimizer_center = make_optimizer_with_center(cfg, model, center_criterion)
        scheduler = warmup_scheduler(optimizer)
        do_train_with_center(cfg, model, center_criterion, train_loader,
                             val_loader, optimizer, optimizer_center,
                             scheduler, loss_func, num_query)
        return

    # center_mode == 'no'
    print('Train without center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE)
    optimizer = make_optimizer(cfg, model)
    print(cfg.SOLVER.MARGIN)
    scheduler = warmup_scheduler(optimizer)
    loss_func = make_loss(cfg, num_classes)
    do_train(cfg, model, train_loader, val_loader, optimizer, scheduler,
             loss_func, num_query)
Пример #2
0
def train(cfg):
    """Build loaders, model, loss and optimizer(s) from ``cfg`` and run training.

    ``cfg.MODEL.IF_WITH_CENTER`` selects the plain ('no') or center-loss
    ('yes') trainer; any other value only prints a message and returns.
    ``cfg.MODEL.PRETRAIN_CHOICE`` selects a fresh run ('imagenet') or a
    resume ('self') from the checkpoint files derived from
    ``cfg.MODEL.PRETRAIN_PATH``.

    Raises:
        ValueError: if IF_WITH_CENTER is 'no'/'yes' and PRETRAIN_CHOICE is
            neither 'imagenet' nor 'self'.
    """
    with_center = cfg.MODEL.IF_WITH_CENTER
    pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE

    # Fail fast. This combination used to print the message and then crash on
    # an unbound `scheduler`/`start_epoch` after the data and model were built.
    if with_center in ('no', 'yes') and pretrain_choice not in ('self', 'imagenet'):
        raise ValueError(
            'Only support pretrain_choice for imagenet and self, but got {}'
            .format(pretrain_choice))

    # prepare dataset
    train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)

    # prepare model
    model = build_model(cfg, num_classes)

    if with_center not in ('no', 'yes'):
        print(
            "Unsupported value for cfg.MODEL.IF_WITH_CENTER {}, only support yes or no!\n"
            .format(with_center))
        return

    use_center = with_center == 'yes'
    if use_center:
        print('Train with center loss, the loss type is',
              cfg.MODEL.METRIC_LOSS_TYPE)
        loss_func, center_criterion = make_loss_with_center(cfg, num_classes)
        optimizer, optimizer_center = make_optimizer_with_center(
            cfg, model, center_criterion)
    else:
        print('Train without center loss, the loss type is',
              cfg.MODEL.METRIC_LOSS_TYPE)
        optimizer = make_optimizer(cfg, model)
        loss_func = make_loss(cfg, num_classes)

    scheduler_args = [optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA,
                      cfg.SOLVER.WARMUP_FACTOR, cfg.SOLVER.WARMUP_ITERS,
                      cfg.SOLVER.WARMUP_METHOD]

    if pretrain_choice == 'self':
        checkpoint_path = cfg.MODEL.PRETRAIN_PATH
        # The epoch is the trailing "_<N>" of the checkpoint file stem.
        # int() replaces eval(): same result for digits, no code execution.
        start_epoch = int(
            checkpoint_path.split('/')[-1].split('.')[0].split('_')[-1])
        print('Start epoch:', start_epoch)

        # NOTE: str.replace rewrites every 'model' occurrence in the path,
        # directory components included.
        path_to_optimizer = checkpoint_path.replace('model', 'optimizer')
        print('Path to the checkpoint of optimizer:', path_to_optimizer)
        if use_center:
            path_to_center_param = checkpoint_path.replace(
                'model', 'center_param')
            print('Path to the checkpoint of center_param:',
                  path_to_center_param)
            path_to_optimizer_center = checkpoint_path.replace(
                'model', 'optimizer_center')
            print('Path to the checkpoint of optimizer_center:',
                  path_to_optimizer_center)

        # NOTE: torch.load unpickles the files; use trusted checkpoints only.
        model.load_state_dict(torch.load(checkpoint_path))
        optimizer.load_state_dict(torch.load(path_to_optimizer))
        if use_center:
            center_criterion.load_state_dict(torch.load(path_to_center_param))
            optimizer_center.load_state_dict(
                torch.load(path_to_optimizer_center))

        # Only a resumed run passes the epoch on to the scheduler.
        scheduler_args.append(start_epoch)
    else:  # 'imagenet'
        start_epoch = 0

    scheduler = WarmupMultiStepLR(*scheduler_args)

    if use_center:
        do_train_with_center(
            cfg,
            model,
            center_criterion,
            train_loader,
            val_loader,
            optimizer,
            optimizer_center,
            scheduler,
            loss_func,
            num_query,
            start_epoch,
        )
    else:
        do_train(
            cfg,
            model,
            train_loader,
            val_loader,
            optimizer,
            scheduler,
            loss_func,
            num_query,
            start_epoch,
        )
Пример #3
0
def train(cfg):
    """Build loaders, model, loss and optimizer(s) from ``cfg`` and run training.

    ``cfg.MODEL.IF_WITH_CENTER == 'on'`` selects ``do_train_with_center``
    with three center criteria (part, global, fore); any other value runs
    ``do_train``. Training always starts at epoch 0.

    Raises:
        ValueError: if ``cfg.MODEL.PRETRAIN_CHOICE`` is not 'imagenet'.
    """
    pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE
    # Fail fast. This case used to print the message and then crash on an
    # unbound `scheduler`/`start_epoch` once the trainer was called.
    if pretrain_choice != 'imagenet':
        raise ValueError(
            'Only support pretrain_choice for imagenet, but got {}'.format(
                pretrain_choice))

    # prepare dataset
    train_loader, val_loader, num_query, num_classes, clustering_loader = make_data_loader(
        cfg)

    # prepare model
    model = build_model(cfg, num_classes)

    use_center = cfg.MODEL.IF_WITH_CENTER == 'on'
    if use_center:
        (loss_func, center_criterion_part, center_criterion_global,
         center_criterion_fore) = make_loss_with_center(cfg, num_classes)
        optimizer, optimizer_center = make_optimizer_with_center(
            cfg, model, center_criterion_part, center_criterion_global,
            center_criterion_fore)
    else:
        loss_func = make_loss(cfg, num_classes)
        optimizer = make_optimizer(cfg, model)

    start_epoch = 0
    scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS,
                                  cfg.SOLVER.GAMMA,
                                  cfg.SOLVER.WARMUP_FACTOR,
                                  cfg.SOLVER.WARMUP_ITERS,
                                  cfg.SOLVER.WARMUP_METHOD)

    if use_center:
        do_train_with_center(
            cfg,
            model,
            center_criterion_part,
            center_criterion_global,
            center_criterion_fore,
            train_loader,
            val_loader,
            optimizer,
            optimizer_center,
            scheduler,
            loss_func,
            num_query,
            start_epoch,
            clustering_loader)
    else:
        do_train(
            cfg,
            model,
            train_loader,
            val_loader,
            optimizer,
            scheduler,
            loss_func,
            num_query,
            start_epoch,
            clustering_loader)
Пример #4
0
def train(cfg):
    """Build loaders, model, loss and optimizer(s) from ``cfg`` and run training.

    ``cfg.MODEL.IF_WITH_CENTER`` selects the plain ('no') or center-loss
    ('yes') trainer; any other value only prints a message.
    ``cfg.MODEL.PRETRAIN_CHOICE`` is 'imagenet' for a fresh run or 'self' to
    start from ``cfg.MODEL.PRETRAIN_PATH``:

    * without center loss, the pickled model's weights are merged into the
      new model and training starts at epoch 0; if
      ``cfg.MODEL.WHOLE_MODEL_TRAIN == "no"`` only parameters whose name
      contains 'Query_Guided_Attention', 'non_local' or
      'classifier_attention' stay trainable;
    * with center loss, model, optimizer, center criterion and center
      optimizer state are restored and training resumes at the epoch
      encoded in the checkpoint file name.

    Raises:
        ValueError: if IF_WITH_CENTER is 'no'/'yes' and PRETRAIN_CHOICE is
            neither 'imagenet' nor 'self'.
    """
    with_center = cfg.MODEL.IF_WITH_CENTER
    pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE

    # Fail fast. This combination used to print the message and then crash on
    # an unbound `scheduler`/`start_epoch` after the data and model were built.
    if with_center in ('no', 'yes') and pretrain_choice not in ('self', 'imagenet'):
        raise ValueError(
            'Only support pretrain_choice for imagenet and self, but got {}'
            .format(pretrain_choice))

    # prepare dataset
    train_loader, val_loader, num_query, num_classes = make_data_loader_train(
        cfg)

    # Fixed class counts override the loader's value; the first matching
    # dataset key wins, otherwise the loader's num_classes is kept.
    for dataset_key, class_count in (('prw', 483), ('market1501', 751),
                                     ('duke', 702), ('cuhk', 5532)):
        if dataset_key in cfg.DATASETS.NAMES:
            num_classes = class_count
            break

    # prepare model
    model = build_model(cfg, num_classes)

    def build_scheduler(opt, *last_epoch):
        # `last_epoch` is passed only when resuming from a known epoch.
        return WarmupMultiStepLR(opt, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA,
                                 cfg.SOLVER.WARMUP_FACTOR,
                                 cfg.SOLVER.WARMUP_ITERS,
                                 cfg.SOLVER.WARMUP_METHOD, *last_epoch)

    if with_center == 'no':
        print('Train without center loss, the loss type is',
              cfg.MODEL.METRIC_LOSS_TYPE)
        optimizer = make_optimizer(cfg, model)
        loss_func = make_loss(cfg, num_classes)

        # This branch always starts counting from epoch 0, even for 'self'.
        start_epoch = 0
        if pretrain_choice == 'self':
            print('Start epoch:', start_epoch)
            path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace(
                'model', 'optimizer')
            # Only reported: the optimizer state is not restored here.
            print('Path to the checkpoint of optimizer:', path_to_optimizer)

            # The checkpoint holds a whole pickled module; merge its weights
            # over the freshly built model's state dict.
            # NOTE: torch.load unpickles the file; use trusted checkpoints only.
            pretrained_dic = torch.load(cfg.MODEL.PRETRAIN_PATH).state_dict()
            model_dict = model.state_dict()
            model_dict.update(pretrained_dic)
            model.load_state_dict(model_dict)

            if cfg.MODEL.WHOLE_MODEL_TRAIN == "no":
                trainable_markers = ("Query_Guided_Attention", "non_local",
                                     "classifier_attention")
                for name, value in model.named_parameters():
                    if not any(marker in name for marker in trainable_markers):
                        value.requires_grad = False
                # Rebuild the optimizer after freezing parameters.
                optimizer = make_optimizer(cfg, model)

        scheduler = build_scheduler(optimizer)

        do_train(
            cfg,
            model,
            train_loader,
            val_loader,
            optimizer,
            scheduler,
            loss_func,
            num_query,
            start_epoch,
        )
    elif with_center == 'yes':
        print('Train with center loss, the loss type is',
              cfg.MODEL.METRIC_LOSS_TYPE)
        loss_func, center_criterion = make_loss_with_center(cfg, num_classes)
        optimizer, optimizer_center = make_optimizer_with_center(
            cfg, model, center_criterion)

        if pretrain_choice == 'self':
            checkpoint_path = cfg.MODEL.PRETRAIN_PATH
            # The epoch is the trailing "_<N>" of the checkpoint file stem.
            # int() replaces eval(): same result for digits, no code execution.
            start_epoch = int(
                checkpoint_path.split('/')[-1].split('.')[0].split('_')[-1])
            print('Start epoch:', start_epoch)
            path_to_optimizer = checkpoint_path.replace('model', 'optimizer')
            print('Path to the checkpoint of optimizer:', path_to_optimizer)
            path_to_center_param = checkpoint_path.replace(
                'model', 'center_param')
            print('Path to the checkpoint of center_param:',
                  path_to_center_param)
            path_to_optimizer_center = checkpoint_path.replace(
                'model', 'optimizer_center')
            print('Path to the checkpoint of optimizer_center:',
                  path_to_optimizer_center)

            # NOTE: torch.load unpickles the files; use trusted checkpoints only.
            model.load_state_dict(torch.load(checkpoint_path))
            optimizer.load_state_dict(torch.load(path_to_optimizer))
            # Move the restored optimizer state tensors onto the GPU.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()
            center_criterion.load_state_dict(torch.load(path_to_center_param))
            optimizer_center.load_state_dict(
                torch.load(path_to_optimizer_center))
            scheduler = build_scheduler(optimizer, start_epoch)
        else:  # 'imagenet'
            start_epoch = 0
            scheduler = build_scheduler(optimizer)

        do_train_with_center(
            cfg,
            model,
            center_criterion,
            train_loader,
            val_loader,
            optimizer,
            optimizer_center,
            scheduler,
            loss_func,
            num_query,
            start_epoch,
        )
    else:
        print(
            "Unsupported value for cfg.MODEL.IF_WITH_CENTER {}, only support yes or no!\n"
            .format(with_center))
Пример #5
0
def train(cfg):
    """Build loaders, model, loss and optimizer(s) from ``cfg`` and run training.

    ``cfg.MODEL.IF_WITH_CENTER`` selects the plain ("no") or center-loss
    ("yes") trainer; any other value only prints a message and returns.
    With ``cfg.MODEL.PRETRAIN_CHOICE == "self"`` the model/optimizer (and,
    with center loss, the center criterion and its optimizer) are restored
    from files derived from ``cfg.MODEL.PRETRAIN_PATH`` and training resumes
    at the epoch encoded in the checkpoint file name; every other choice
    starts at epoch 0.
    """
    # prepare dataset
    train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)

    # prepare model
    model = build_model(cfg, num_classes)

    with_center = cfg.MODEL.IF_WITH_CENTER
    if with_center not in ("no", "yes"):
        print(
            "Unsupported value for cfg.MODEL.IF_WITH_CENTER {}, only support yes or no!\n".format(
                with_center
            )
        )
        return

    use_center = with_center == "yes"
    if use_center:
        print("Train with center loss, the loss type is", cfg.MODEL.METRIC_LOSS_TYPE)
        loss_func, center_criterion = make_loss_with_center(cfg, num_classes)
        optimizer, optimizer_center = make_optimizer_with_center(
            cfg, model, center_criterion
        )
    else:
        print("Train without center loss, the loss type is", cfg.MODEL.METRIC_LOSS_TYPE)
        optimizer = make_optimizer(cfg, model)
        loss_func = make_loss(cfg, num_classes)

    scheduler_args = [
        optimizer,
        cfg.SOLVER.STEPS,
        cfg.SOLVER.GAMMA,
        cfg.SOLVER.WARMUP_FACTOR,
        cfg.SOLVER.WARMUP_ITERS,
        cfg.SOLVER.WARMUP_METHOD,
        cfg.SOLVER.MODE,
        cfg.SOLVER.MAX_EPOCHS,
    ]

    if cfg.MODEL.PRETRAIN_CHOICE == "self":
        checkpoint_path = cfg.MODEL.PRETRAIN_PATH
        # The epoch is the trailing "_<N>" of the checkpoint file stem.
        # int() replaces eval(): same result for digits, no code execution.
        start_epoch = int(
            checkpoint_path.split("/")[-1].split(".")[0].split("_")[-1]
        )
        print("Start epoch:", start_epoch)

        # NOTE: str.replace rewrites every "model" occurrence in the path,
        # directory components included.
        path_to_optimizer = checkpoint_path.replace("model", "optimizer")
        print("Path to the checkpoint of optimizer:", path_to_optimizer)
        if use_center:
            path_to_center_param = checkpoint_path.replace("model", "center_param")
            print("Path to the checkpoint of center_param:", path_to_center_param)
            path_to_optimizer_center = checkpoint_path.replace(
                "model", "optimizer_center"
            )
            print(
                "Path to the checkpoint of optimizer_center:", path_to_optimizer_center
            )

        # NOTE: torch.load unpickles the files; use trusted checkpoints only.
        model.load_state_dict(torch.load(checkpoint_path))
        optimizer.load_state_dict(torch.load(path_to_optimizer))
        if use_center:
            center_criterion.load_state_dict(torch.load(path_to_center_param))
            optimizer_center.load_state_dict(torch.load(path_to_optimizer_center))

        # Only a resumed run passes the epoch on to the scheduler.
        scheduler_args.append(start_epoch)
    else:
        start_epoch = 0

    scheduler = WarmupMultiStepLR(*scheduler_args)

    if use_center:
        do_train_with_center(
            cfg,
            model,
            center_criterion,
            train_loader,
            val_loader,
            optimizer,
            optimizer_center,
            scheduler,
            loss_func,
            num_query,
            start_epoch,
        )
    else:
        do_train(
            cfg,
            model,
            train_loader,
            val_loader,
            optimizer,
            scheduler,
            loss_func,
            num_query,
            start_epoch,
        )
Пример #6
0
def train(cfg):
    """Build source/target loaders, model, losses and optimizers, then train.

    Always trains with center loss plus a cluster loss via
    ``do_train_with_center2``. ``cfg.MODEL.PRETRAIN_CHOICE`` is 'imagenet'
    for a fresh run or 'self' to resume: the objects returned by
    ``torch.load`` for the paths derived from ``cfg.MODEL.PRETRAIN_PATH``
    replace the model, optimizers and center criterion directly. The cluster
    criterion and its optimizer are replaced only when the resumed epoch is
    at least ``cfg.SOLVER.MY_START_EPOCH``.

    Raises:
        ValueError: if PRETRAIN_CHOICE is neither 'imagenet' nor 'self'.
    """
    pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE
    # Fail fast. This case used to print the message and then crash on an
    # unbound `scheduler`/`start_epoch` once the trainer was called.
    if pretrain_choice not in ('self', 'imagenet'):
        raise ValueError(
            'Only support pretrain_choice for imagenet and self, but got {}'.
            format(pretrain_choice))

    # prepare dataset
    train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
    target_train_loader, target_val_loader = make_data_loader2(cfg)

    # prepare model
    model = build_model(cfg, num_classes)

    print('Train with center loss, the loss type is',
          cfg.MODEL.METRIC_LOSS_TYPE)
    loss_func, center_criterion = make_loss_with_center(cfg, num_classes)
    cluster_num_classes = cfg.INPUT.CLUSTER_NUMBER
    loss_cluster_func, cluster_criterion = make_loss_with_cluster(
        cfg, cluster_num_classes)

    optimizer, optimizer_center, optimizer_cluster = make_optimizer_with_center2(
        cfg, model, center_criterion, cluster_criterion)

    def build_scheduler(opt, *last_epoch):
        # `last_epoch` is passed only when resuming from a known epoch.
        return WarmupMultiStepLR(opt, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA,
                                 cfg.SOLVER.WARMUP_FACTOR,
                                 cfg.SOLVER.WARMUP_ITERS,
                                 cfg.SOLVER.WARMUP_METHOD, *last_epoch)

    if pretrain_choice == 'self':
        checkpoint_path = cfg.MODEL.PRETRAIN_PATH
        # The epoch is the trailing "_<N>" of the checkpoint file stem.
        # int() replaces eval(): same result for digits, no code execution.
        start_epoch = int(
            checkpoint_path.split('/')[-1].split('.')[0].split('_')[-1])
        print('Start epoch:', start_epoch)
        path_to_optimizer = checkpoint_path.replace('model', 'optimizer')
        print('Path to the checkpoint of optimizer:', path_to_optimizer)
        path_to_center_param = checkpoint_path.replace(
            'model', 'center_param')
        print('Path to the checkpoint of center_param:', path_to_center_param)
        path_to_optimizer_center = checkpoint_path.replace(
            'model', 'optimizer_center')
        print('Path to the checkpoint of optimizer_center:',
              path_to_optimizer_center)

        # The loaded objects themselves replace the freshly built ones.
        # NOTE: torch.load unpickles the files; use trusted checkpoints only.
        model = torch.load(checkpoint_path)
        optimizer = torch.load(path_to_optimizer)
        center_criterion = torch.load(path_to_center_param)
        optimizer_center = torch.load(path_to_optimizer_center)

        # Cluster checkpoints are read only from MY_START_EPOCH onwards.
        if start_epoch >= cfg.SOLVER.MY_START_EPOCH:
            path_to_cluster_param = checkpoint_path.replace(
                'model', 'cluster_param')
            print('Path to the checkpoint of cluster_param:',
                  path_to_cluster_param)
            path_to_optimizer_cluster = checkpoint_path.replace(
                'model', 'optimizer_cluster')
            print('Path to the checkpoint of optimizer_cluster:',
                  path_to_optimizer_cluster)
            cluster_criterion = torch.load(path_to_cluster_param)
            optimizer_cluster = torch.load(path_to_optimizer_cluster)

        # Built after loading so the scheduler wraps the restored optimizer.
        scheduler = build_scheduler(optimizer, start_epoch)
    else:  # 'imagenet'
        start_epoch = 0
        scheduler = build_scheduler(optimizer)

    do_train_with_center2(
        cfg,
        model,
        center_criterion,
        cluster_criterion,
        train_loader,
        val_loader,
        target_train_loader,
        target_val_loader,
        optimizer,
        optimizer_center,
        optimizer_cluster,
        scheduler,
        loss_func,
        loss_cluster_func,
        num_query,
        start_epoch,
        cfg.SOLVER.MY_START_EPOCH  # epoch at which the clustering loss starts
    )