def main_worker(args):
    global start_epoch, best_mAP

    cudnn.benchmark = True

    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
    else:
        log_dir = osp.dirname(args.resume)
        sys.stdout = Logger(osp.join(log_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    # Create data loaders
    iters = args.iters if (args.iters > 0) else None
    dataset_source, num_classes, train_loader_source, test_loader_source = \
        get_data(args.dataset_source, args.data_dir, args.height,
                 args.width, args.batch_size, args.workers, args.num_instances, iters)

    dataset_target, _, train_loader_target, test_loader_target = \
        get_data(args.dataset_target, args.data_dir, args.height,
                 args.width, args.batch_size, args.workers, 0, iters)
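    # num_instances > 0 enables identity-balanced (PK) sampling for the source
    # loader; the target loader is built with 0, i.e. presumably plain random
    # sampling, since target identities are unknown at this point.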

    # Create model
    model = models.create(args.arch,
                          num_features=args.features,
                          dropout=args.dropout,
                          num_classes=[num_classes])
    model.cuda()
    model = nn.DataParallel(model)
    print(model)
    # Load from checkpoint
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        copy_state_dict(checkpoint['state_dict'], model)
        start_epoch = checkpoint['epoch']
        best_mAP = checkpoint['best_mAP']
        print("=> Start epoch {}  best mAP {:.1%}".format(
            start_epoch, best_mAP))

    # Evaluator
    evaluator = Evaluator(model)
    if args.evaluate:
        print("Test on source domain:")
        evaluator.evaluate(test_loader_source,
                           dataset_source.query,
                           dataset_source.gallery,
                           cmc_flag=True,
                           rerank=args.rerank)
        print("Test on target domain:")
        evaluator.evaluate(test_loader_target,
                           dataset_target.query,
                           dataset_target.gallery,
                           cmc_flag=True,
                           rerank=args.rerank)
        return

    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        params += [{
            "params": [value],
            "lr": args.lr,
            "weight_decay": args.weight_decay
        }]
    optimizer = torch.optim.Adam(params)
    lr_scheduler = WarmupMultiStepLR(optimizer,
                                     args.milestones,
                                     gamma=0.1,
                                     warmup_factor=0.01,
                                     warmup_iters=args.warmup_step)
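    # WarmupMultiStepLR: the learning rate starts at warmup_factor * lr, ramps
    # up over warmup_step scheduler steps, and is then multiplied by
    # gamma = 0.1 at each milestone epoch.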

    # Trainer
    trainer = PreTrainer(model, num_classes, margin=args.margin)
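    # PreTrainer presumably optimises the standard supervised re-ID objective
    # on the labelled source data: an ID classification loss plus a triplet
    # loss with the given margin.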

    # Start training
    for epoch in range(start_epoch, args.epochs):
        train_loader_source.new_epoch()
        train_loader_target.new_epoch()

        trainer.train(epoch,
                      train_loader_source,
                      train_loader_target,
                      optimizer,
                      train_iters=len(train_loader_source),
                      print_freq=args.print_freq)

        # Step the LR schedule after the epoch's optimizer updates
        # (optimizer.step() is expected before lr_scheduler.step()).
        lr_scheduler.step()

        if ((epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1)):

            _, mAP = evaluator.evaluate(test_loader_source,
                                        dataset_source.query,
                                        dataset_source.gallery,
                                        cmc_flag=True)

            is_best = mAP > best_mAP
            best_mAP = max(mAP, best_mAP)
            save_checkpoint(
                {
                    'state_dict': model.state_dict(),
                    'epoch': epoch + 1,
                    'best_mAP': best_mAP,
                },
                is_best,
                fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))

            print(
                '\n * Finished epoch {:3d}  source mAP: {:5.1%}  best: {:5.1%}{}\n'
                .format(epoch, mAP, best_mAP, ' *' if is_best else ''))

    print("Test on target domain:")
    evaluator.evaluate(test_loader_target,
                       dataset_target.query,
                       dataset_target.gallery,
                       cmc_flag=True,
                       rerank=args.rerank)
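
# ---------------------------------------------------------------------------
# Target-domain adaptation stage: DBSCAN pseudo labels computed on Jaccard
# distances, a hybrid memory holding source class centroids and target
# instance features, and an EMA (mean-teacher style) model.
# ---------------------------------------------------------------------------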
def main_worker(args):
    global start_epoch, best_mAP

    cudnn.benchmark = True

    sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    # Create data loaders
    iters = args.iters if (args.iters > 0) else None
    ncs = [int(x) for x in args.ncs.split(',')]

    dataset_target, label_dict = get_data(args.dataset_target, args.data_dir,
                                          len(ncs), True)
    test_loader_target = get_test_loader(dataset_target, args.height,
                                         args.width, args.batch_size,
                                         args.workers)
    tar_cluster_loader = get_test_loader(dataset_target,
                                         args.height,
                                         args.width,
                                         args.batch_size,
                                         args.workers,
                                         testset=dataset_target.train)

    dataset_source, _ = get_data(args.dataset_source, args.data_dir, len(ncs))
    sour_cluster_loader = get_test_loader(dataset_source,
                                          args.height,
                                          args.width,
                                          args.batch_size,
                                          args.workers,
                                          testset=dataset_source.train)
    train_loader_source = get_train_loader(dataset_source, args.height,
                                           args.width, 0, args.batch_size,
                                           args.workers, args.num_instances,
                                           args.iters, dataset_source.train)

    fc_len = 3500
    model_1, _, model_1_ema, _ = create_model(
        args, [fc_len for _ in range(len(ncs))])
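    # create_model presumably returns the online network (model_1) together
    # with its exponential-moving-average copy (model_1_ema); the classifier
    # heads are over-allocated to fc_len outputs, presumably so the number of
    # pseudo classes can change between epochs without rebuilding the model.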

    epoch = 0
    target_features_dict, _, _ = extract_features(model_1_ema,
                                                  tar_cluster_loader,
                                                  print_freq=100)
    target_features = F.normalize(
        torch.stack(list(target_features_dict.values())), dim=1)

    # Calculate distance
    print('==> Create pseudo labels for unlabeled target domain')

    rerank_dist = compute_jaccard_distance(target_features,
                                           k1=args.k1,
                                           k2=args.k2)
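    # rerank_dist is a pairwise Jaccard distance matrix built from
    # k-reciprocal nearest-neighbour sets (k1/k2 control the neighbourhood
    # sizes), the same construction used in k-reciprocal re-ranking.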
    del target_features
    # DBSCAN clustering on the precomputed Jaccard distance matrix
    eps = 0.6
    print('Clustering criterion: eps: {:.3f}'.format(eps))
    cluster = DBSCAN(eps=eps,
                     min_samples=4,
                     metric='precomputed',
                     n_jobs=-1)

    # Select the clustered samples as the training set for this epoch
    pseudo_labels = cluster.fit_predict(rerank_dist)

    # num_ids = len(set(pseudo_labels)) - (1 if -1 in pseudo_labels else 0)
    plabel = []
    new_dataset = []
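    # DBSCAN labels outliers as -1; those samples are dropped for this epoch.
    # Each kept entry is rebuilt as (item[0], pseudo label, item[-1]), i.e.
    # presumably (image path, label, camera id).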
    for i, (item, label) in enumerate(zip(dataset_target.train,
                                          pseudo_labels)):
        if label == -1:
            continue
        plabel.append(label)
        new_dataset.append((item[0], label, item[-1]))

    target_label = [plabel]
    ncs = [len(set(plabel)) + 1]
    print('New classes: {}, new dataset length: {}'.format(
        ncs, len(new_dataset)))

    # Initialize source-domain class centroids
    print("==> Initialize source-domain class centroids in the hybrid memory")
    source_features, _, _ = extract_features(model_1,
                                             sour_cluster_loader,
                                             print_freq=50)
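    # Average each source identity's features and L2-normalise the per-class
    # mean; these centroids seed the source part of the hybrid memory below.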
    sour_fea_dict = collections.defaultdict(list)
    for f, pid, _ in sorted(dataset_source.train):
        sour_fea_dict[pid].append(source_features[f].unsqueeze(0))
    source_centers = [
        torch.cat(sour_fea_dict[pid], 0).mean(0)
        for pid in sorted(sour_fea_dict.keys())
    ]
    source_centers = torch.stack(source_centers, 0)
    source_centers = F.normalize(source_centers, dim=1)
    print("==> Finished initializing source-domain class centroids")
    del sour_fea_dict, source_features, sour_cluster_loader

    # Evaluator
    evaluator_1 = Evaluator(model_1)
    evaluator_1_ema = Evaluator(model_1_ema)

    clusters = [args.num_clusters] * args.epochs  # TODO: dropout clusters
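    # 'clusters' is only used for its length (the number of training epochs);
    # the number of target classes is re-estimated by DBSCAN during training.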

    source_classes = dataset_source.num_train_pids
    k_memory = 8192
    contrast = onlinememory(2048,
                            len(new_dataset),
                            sour_numclass=source_classes,
                            K=k_memory + source_classes,
                            index2label=target_label,
                            choice_c=args.choice_c,
                            T=0.07,
                            use_softmax=True).cuda()
    contrast.index_memory = torch.cat(
        (torch.arange(source_classes), -1 * torch.ones(k_memory).long()),
        dim=0).cuda()
    contrast.memory = torch.cat((source_centers, torch.rand(k_memory, 2048)),
                                dim=0).cuda()
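    # Hybrid memory layout (an assumption based on how it is filled here): the
    # first source_classes rows hold the fixed source centroids, while the
    # remaining k_memory rows act as a queue for target instance features,
    # initialised randomly with index -1 until they are written during
    # training; T = 0.07 is the softmax temperature.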

    tar_selflabel_loader = get_test_loader(dataset_target,
                                           args.height,
                                           args.width,
                                           args.batch_size,
                                           args.workers,
                                           testset=new_dataset)

    o = Optimizer(target_label,
                  dis_gt=None,
                  m=model_1,
                  ncl=ncs,
                  t_loader=tar_selflabel_loader,
                  N=len(new_dataset),
                  fc_len=fc_len)
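    # 'o' appears to be a self-labelling optimiser that refines the DBSCAN
    # pseudo labels; the current assignment is read back from o.L inside the
    # training loop below.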

    uncertainty = collections.defaultdict(list)
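    # Per-sample uncertainty statistics, accumulated across epochs and passed
    # to DbscanBaseTrainer (how they are used depends on args.uncer_mode).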
    print("Training begining~~~~~~!!!!!!!!!")
    for epoch in range(len(clusters)):

        iters_ = 300  # 'epoch % 1 == 0' is always true, so this is fixed at 300 iterations per epoch
        if epoch % 6 == 0 and epoch != 0:
            target_features_dict, _, prob = extract_features(
                model_1_ema, tar_cluster_loader, print_freq=50)

            target_features = torch.stack(list(target_features_dict.values()))
            target_features = F.normalize(target_features, dim=1)

            print('==> Create pseudo labels for unlabeled target domain')
            rerank_dist = compute_jaccard_distance(target_features,
                                                   k1=args.k1,
                                                   k2=args.k2)

            # Select the clustered samples as the training set for this epoch
            pseudo_labels = cluster.fit_predict(rerank_dist)
            num_ids = len(
                set(pseudo_labels)) - (1 if -1 in pseudo_labels else 0)
            plabel = []

            new_dataset = []

            for i, (item, label) in enumerate(
                    zip(dataset_target.train, pseudo_labels)):
                if label == -1:
                    continue
                plabel.append(label)
                new_dataset.append((item[0], label, item[-1]))

            target_label = [plabel]
            ncs = [len(set(plabel)) + 1]

            tar_selflabel_loader = get_test_loader(dataset_target,
                                                   args.height,
                                                   args.width,
                                                   args.batch_size,
                                                   args.workers,
                                                   testset=new_dataset)
            o = Optimizer(target_label,
                          dis_gt=None,
                          m=model_1,
                          ncl=ncs,
                          t_loader=tar_selflabel_loader,
                          N=len(new_dataset),
                          fc_len=fc_len)

        target_label_o = o.L
        target_label = [
            list(np.asarray(target_label_o[0].data.cpu()) + source_classes)
        ]
        contrast.index2label = [[i for i in range(source_classes)] +
                                target_label[0]]
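        # Target pseudo labels are shifted by source_classes so that source
        # identities and target clusters occupy disjoint label ranges in the
        # shared memory; index2label maps every memory slot to its label.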

        for i in range(len(new_dataset)):
            new_dataset[i] = list(new_dataset[i])
            for j in range(len(ncs)):
                new_dataset[i][j + 1] = int(target_label[j][i])
            new_dataset[i] = tuple(new_dataset[i])

        cc = args.choice_c  #(args.choice_c+1)%len(ncs)
        train_loader_target = get_train_loader(dataset_target, args.height,
                                               args.width, cc, args.batch_size,
                                               args.workers,
                                               args.num_instances, iters_,
                                               new_dataset)

        # Optimizer
        params = []
        flag = 1.0
        # if 20<epoch<=40 or 60<epoch<=80 or 120<epoch:
        #     flag=0.1
        # else:
        #     flag=1.0

        for key, value in model_1.named_parameters():
            if not value.requires_grad:
                print(key)
                continue
            params += [{
                "params": [value],
                "lr": args.lr * flag,
                "weight_decay": args.weight_decay
            }]

        optimizer = torch.optim.Adam(params)

        # Trainer
        trainer = DbscanBaseTrainer(model_1,
                                    model_1_ema,
                                    contrast,
                                    None,
                                    None,
                                    num_cluster=ncs,
                                    c_name=ncs,
                                    alpha=args.alpha,
                                    fc_len=fc_len,
                                    source_classes=source_classes,
                                    uncer_mode=args.uncer_mode)

        train_loader_target.new_epoch()
        train_loader_source.new_epoch()

        trainer.train(epoch,
                      train_loader_target,
                      train_loader_source,
                      optimizer,
                      args.choice_c,
                      lambda_tri=args.lambda_tri,
                      lambda_ct=args.lambda_ct,
                      lambda_reg=args.lambda_reg,
                      print_freq=args.print_freq,
                      train_iters=iters_,
                      uncertainty_d=uncertainty)

        def save_model(model_ema, is_best, best_mAP, mid):
            save_checkpoint(
                {
                    'state_dict': model_ema.state_dict(),
                    'epoch': epoch + 1,
                    'best_mAP': best_mAP,
                },
                is_best,
                fpath=osp.join(args.logs_dir,
                               'model' + str(mid) + '_checkpoint.pth.tar'))

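        # Evaluate more often late in training: every 2 epochs after epoch 20
        # and every epoch after epoch 50.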
        if epoch == 20:
            args.eval_step = 2
        elif epoch == 50:
            args.eval_step = 1
        if ((epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1)):
            # model_1 is not evaluated here; only the EMA model's mAP is used.
            mAP_1 = 0

            mAP_2 = evaluator_1_ema.evaluate(test_loader_target,
                                             dataset_target.query,
                                             dataset_target.gallery,
                                             cmc_flag=False)
            is_best = (mAP_1 > best_mAP) or (mAP_2 > best_mAP)
            best_mAP = max(mAP_1, mAP_2, best_mAP)
            save_model(model_1, (is_best), best_mAP, 1)
            save_model(model_1_ema, (is_best and (mAP_1 <= mAP_2)), best_mAP,
                       2)

            print(
                '\n * Finished epoch {:3d}  model no.1 mAP: {:5.1%} model no.2 mAP: {:5.1%}  best: {:5.1%}{}\n'
                .format(epoch, mAP_1, mAP_2, best_mAP,
                        ' *' if is_best else ''))

    print('Test on the best model.')
    checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar'))
    model_1.load_state_dict(checkpoint['state_dict'])
    evaluator_1.evaluate(test_loader_target,
                         dataset_target.query,
                         dataset_target.gallery,
                         cmc_flag=True)
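
# ---------------------------------------------------------------------------
# Clustering-only baseline: the same DBSCAN pseudo-labelling pipeline with an
# EMA (mean-teacher style) model, but without the hybrid contrastive memory or
# the self-labelling optimiser used in the stage above.
# ---------------------------------------------------------------------------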
def main_worker(args):
    global start_epoch, best_mAP

    cudnn.benchmark = True

    sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    # Create data loaders
    iters = args.iters if (args.iters > 0) else None
    ncs = [int(x) for x in args.ncs.split(',')]
    # ncs_dbscan=ncs.copy()
    dataset_target, label_dict = get_data(args.dataset_target, args.data_dir,
                                          len(ncs))
    dataset_source, _ = get_data(args.dataset_source, args.data_dir, len(ncs))

    test_loader_target = get_test_loader(dataset_target, args.height,
                                         args.width, args.batch_size,
                                         args.workers)

    tar_cluster_loader = get_test_loader(dataset_target,
                                         args.height,
                                         args.width,
                                         args.batch_size,
                                         args.workers,
                                         testset=dataset_target.train)

    fc_len = 3500
    model_1, _, model_1_ema, _ = create_model(
        args, [fc_len for _ in range(len(ncs))])
    print(model_1)

    epoch = 0
    target_features_dict, _ = extract_features(model_1_ema,
                                               tar_cluster_loader,
                                               print_freq=100)

    target_features = torch.stack(list(target_features_dict.values()))
    target_features = F.normalize(target_features, dim=1)
    # Calculate distance
    print('==> Create pseudo labels for unlabeled target domain')

    rerank_dist = compute_jaccard_distance(target_features,
                                           k1=args.k1,
                                           k2=args.k2)
    del target_features
    # DBSCAN clustering on the precomputed Jaccard distance matrix
    eps = 0.6
    print('Clustering criterion: eps: {:.3f}'.format(eps))
    cluster = DBSCAN(eps=eps,
                     min_samples=4,
                     metric='precomputed',
                     n_jobs=-1)

    # Select the clustered samples as the training set for this epoch
    pseudo_labels = cluster.fit_predict(rerank_dist)

    num_ids = len(set(pseudo_labels)) - (1 if -1 in pseudo_labels else 0)

    p1 = []

    new_dataset = []
    for i, (item, label) in enumerate(zip(dataset_target.train,
                                          pseudo_labels)):
        if label == -1:
            continue
        p1.append(label)
        new_dataset.append((item[0], label, item[-1]))
    target_label = [p1]
    ncs = [len(set(p1)) + 1]

    print('New classes: {}, new dataset length: {}'.format(
        ncs, len(new_dataset)))

    # Evaluator
    evaluator_1 = Evaluator(model_1)
    evaluator_1_ema = Evaluator(model_1_ema)

    clusters = [args.num_clusters] * args.epochs  # TODO: dropout clusters

    print("Training begining~~~~~~!!!!!!!!!")
    for epoch in range(len(clusters)):

        iters_ = 300  # 'epoch % 1 == 0' is always true, so this is fixed at 300 iterations per epoch
        if epoch % 6 == 0 and epoch != 0:
            target_features_dict, _ = extract_features(model_1_ema,
                                                       tar_cluster_loader,
                                                       print_freq=50)

            target_features = torch.stack(list(target_features_dict.values()))
            target_features = F.normalize(target_features, dim=1)
            # Calculate distance
            print('==> Create pseudo labels for unlabeled target domain')
            rerank_dist = compute_jaccard_distance(target_features,
                                                   k1=args.k1,
                                                   k2=args.k2)

            # Select the clustered samples as the training set for this epoch
            pseudo_labels = cluster.fit_predict(rerank_dist)
            num_ids = len(
                set(pseudo_labels)) - (1 if -1 in pseudo_labels else 0)

            p1 = []

            new_dataset = []

            for i, (item, label) in enumerate(
                    zip(dataset_target.train, pseudo_labels)):
                if label == -1:
                    continue
                p1.append(label)
                new_dataset.append((item[0], label, item[-1]))
            target_label = [p1]
            ncs = [len(set(p1)) + 1]

            print('New classes: {}, new dataset length: {}'.format(
                ncs, len(new_dataset)))

            obj = collections.Counter(pseudo_labels)
            print("Cluster size distribution (label: count): {}".format(obj))

        target_label = [target_label[0]]

        # change pseudo labels
        for i in range(len(new_dataset)):
            new_dataset[i] = list(new_dataset[i])
            for j in range(len(ncs)):
                new_dataset[i][j + 1] = int(target_label[j][i])
            new_dataset[i] = tuple(new_dataset[i])

        cc = args.choice_c  #(args.choice_c+1)%len(ncs)
        train_loader_target = get_train_loader(dataset_target, args.height,
                                               args.width, cc, args.batch_size,
                                               args.workers,
                                               args.num_instances, iters_,
                                               new_dataset)

        # Optimizer
        params = []
        flag = 1.0
        # if 20<epoch<=40 or 60<epoch<=80 or 120<epoch:
        #     flag=0.1
        # else:
        #     flag=1.0

        for key, value in model_1.named_parameters():
            if not value.requires_grad:
                print(key)
                continue
            params += [{
                "params": [value],
                "lr": args.lr * flag,
                "weight_decay": args.weight_decay
            }]

        optimizer = torch.optim.Adam(params)

        # Trainer
        trainer = DbscanBaseTrainer(model_1,
                                    model_1_ema,
                                    num_cluster=ncs,
                                    c_name=ncs,
                                    alpha=args.alpha,
                                    fc_len=fc_len)
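        # This trainer variant takes no contrastive memory; the soft
        # cross-entropy / soft triplet weights passed below are presumably the
        # mean-teacher consistency-loss weights.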

        train_loader_target.new_epoch()

        trainer.train(epoch,
                      train_loader_target,
                      optimizer,
                      args.choice_c,
                      ce_soft_weight=args.soft_ce_weight,
                      tri_soft_weight=args.soft_tri_weight,
                      print_freq=args.print_freq,
                      train_iters=iters_)

        def save_model(model_ema, is_best, best_mAP, mid):
            save_checkpoint(
                {
                    'state_dict': model_ema.state_dict(),
                    'epoch': epoch + 1,
                    'best_mAP': best_mAP,
                },
                is_best,
                fpath=osp.join(args.logs_dir,
                               'model' + str(mid) + '_checkpoint.pth.tar'))

        if epoch == 20:
            args.eval_step = 2
        elif epoch == 50:
            args.eval_step = 1
        if ((epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1)):
            mAP_1 = evaluator_1.evaluate(test_loader_target,
                                         dataset_target.query,
                                         dataset_target.gallery,
                                         cmc_flag=False)

            mAP_2 = evaluator_1_ema.evaluate(test_loader_target,
                                             dataset_target.query,
                                             dataset_target.gallery,
                                             cmc_flag=False)
            is_best = (mAP_1 > best_mAP) or (mAP_2 > best_mAP)
            best_mAP = max(mAP_1, mAP_2, best_mAP)
            save_model(model_1, (is_best), best_mAP, 1)
            save_model(model_1_ema, (is_best and (mAP_1 <= mAP_2)), best_mAP,
                       2)

            print(
                '\n * Finished epoch {:3d}  model no.1 mAP: {:5.1%} model no.2 mAP: {:5.1%}  best: {:5.1%}{}\n'
                .format(epoch, mAP_1, mAP_2, best_mAP,
                        ' *' if is_best else ''))

    print('Test on the best model.')
    checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar'))
    model_1.load_state_dict(checkpoint['state_dict'])
    evaluator_1.evaluate(test_loader_target,
                         dataset_target.query,
                         dataset_target.gallery,
                         cmc_flag=True)