Example #1
def train(
    query_dataloader,
    retrieval_dataloader,
    code_length,
    args,
):
    """
    Training model.

    Args
        query_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        code_length(int): Hashing code length.
        args.device(torch.device): GPU or CPU.
        args.lr(float): Learning rate.
        args.max_iter(int): Number of iterations.
        args.max_epoch(int): Number of epochs.
        args.num_samples(int): Number of sampled training data points.
        args.batch_size(int): Batch size.
        args.root(str): Path of dataset.
        args.dataset(str): Dataset name.
        args.gamma(float): Hyper-parameter.
        args.topk(int): Top-k for mAP computation.

    Returns
        best_mAP(float): Best Mean Average Precision over all iterations.
    """
    # Initialization
    # model = alexnet.load_model(code_length).to(args.device)
    model = resnet.resnet50(pretrained=args.pretrain,
                            num_classes=code_length).to(args.device)
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.wd,
                              momentum=args.momen,
                              nesterov=args.nesterov)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.wd)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, args.lr_step)
    criterion = ADSH_Loss(code_length, args.gamma)

    num_retrieval = len(retrieval_dataloader.dataset)
    U = torch.zeros(args.num_samples, code_length).to(args.device)
    B = torch.randn(num_retrieval, code_length).to(args.device)
    retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets().to(
        args.device)
    cnn_losses = AverageMeter()
    hash_losses = AverageMeter()
    quan_losses = AverageMeter()
    start = time.time()
    best_mAP = 0
    for it in range(args.max_iter):
        iter_start = time.time()
        # Sample training data for cnn learning
        train_dataloader, sample_index = sample_dataloader(
            retrieval_dataloader, args.num_samples, args.batch_size, args.root,
            args.dataset)

        # Create Similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(
            args.device)
        S = (train_targets @ retrieval_targets.t() > 0).float()
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soft similarity matrix, which helps convergence
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(args.max_epoch):
            cnn_losses.reset()
            hash_losses.reset()
            quan_losses.reset()
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(args.device), targets.to(
                    args.device), index.to(args.device)
                optimizer.zero_grad()

                F = model(data)
                U[index, :] = F.data
                cnn_loss, hash_loss, quan_loss = criterion(
                    F, B, S[index, :], sample_index[index])
                cnn_losses.update(cnn_loss.item())
                hash_losses.update(hash_loss.item())
                quan_losses.update(quan_loss.item())
                cnn_loss.backward()
                optimizer.step()
            logger.info(
                '[epoch:{}/{}][cnn_loss:{:.6f}][hash_loss:{:.6f}][quan_loss:{:.6f}]'
                .format(epoch + 1, args.max_epoch, cnn_losses.avg,
                        hash_losses.avg, quan_losses.avg))
        scheduler.step()
        # Update B
        expand_U = torch.zeros(B.shape).to(args.device)
        expand_U[sample_index, :] = U
        B = solve_dcc(B, U, expand_U, S, code_length, args.gamma)

        # Total loss
        iter_loss = calc_loss(U, B, S, code_length, sample_index, args.gamma)
        # logger.debug('[iter:{}/{}][loss:{:.2f}][iter_time:{:.2f}]'.format(it+1, args.max_iter, iter_loss, time.time()-iter_start))
        logger.info('[iter:{}/{}][loss:{:.6f}][iter_time:{:.2f}]'.format(
            it + 1, args.max_iter, iter_loss,
            time.time() - iter_start))

        # Evaluate (the modulus 1 means every iteration; raise it to evaluate less often)
        if (it + 1) % 1 == 0:
            query_code = generate_code(model, query_dataloader, code_length,
                                       args.device)
            mAP = evaluate.mean_average_precision(
                query_code.to(args.device),
                B,
                query_dataloader.dataset.get_onehot_targets().to(args.device),
                retrieval_targets,
                args.device,
                args.topk,
            )
            if mAP > best_mAP:
                best_mAP = mAP
                # Save checkpoints
                ret_path = os.path.join('checkpoints', args.info,
                                        str(code_length))
                # ret_path = 'checkpoints/' + args.info
                if not os.path.exists(ret_path):
                    os.makedirs(ret_path)
                torch.save(query_code.cpu(),
                           os.path.join(ret_path, 'query_code.t'))
                torch.save(B.cpu(), os.path.join(ret_path, 'database_code.t'))
                torch.save(query_dataloader.dataset.get_onehot_targets().cpu(),
                           os.path.join(ret_path, 'query_targets.t'))
                torch.save(retrieval_targets.cpu(),
                           os.path.join(ret_path, 'database_targets.t'))
                torch.save(model.cpu(), os.path.join(ret_path, 'model.t'))
                model = model.to(args.device)
            logger.info(
                '[iter:{}/{}][code_length:{}][mAP:{:.5f}][best_mAP:{:.5f}]'.
                format(it + 1, args.max_iter, code_length, mAP, best_mAP))
    logger.info('[Training time:{:.2f}]'.format(time.time() - start))

    return best_mAP
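
The loops above track losses with an AverageMeter helper that is not shown in the snippet. A minimal sketch with the reset()/update()/avg interface the example relies on (standard running-average bookkeeping, not necessarily the authors' exact class):

class AverageMeter(object):
    # Sketch: interface inferred from its usage above, not the original implementation.
    """Keeps a running average of a scalar value, e.g. a per-batch loss."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0    # last value passed to update()
        self.sum = 0.0    # running sum of values
        self.count = 0    # number of updates
        self.avg = 0.0    # running average, read after each epoch

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
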
Example #2
def train(
    query_dataloader,
    retrieval_dataloader,
    code_length,
    device,
    lr,
    max_iter,
    max_epoch,
    num_samples,
    batch_size,
    root,
    dataset,
    parameters,
    topk,
):
    """
    Training model.

    Args
        query_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        code_length(int): Hashing code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        max_epoch(int): Number of epochs.
        num_samples(int): Number of sampled training data points.
        batch_size(int): Batch size.
        root(str): Path of dataset.
        dataset(str): Dataset name.
        parameters(dict): Hyper-parameters {'eta', 'mu', 'gamma', 'varphi'}.
        topk(int): Top-k for mAP computation.

    Returns
        mAP(float): Mean Average Precision.
    """

    # parameters = {'eta':2, 'mu':0.2, 'gamma':1, 'varphi':200}
    eta = parameters['eta']
    mu = parameters['mu']
    gamma = parameters['gamma']
    varphi = parameters['varphi']

    # Initialization
    model = alexnet.load_model(code_length).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    criterion = DADSH2_Loss(code_length, eta, mu, gamma, varphi, device)

    num_retrieval = len(retrieval_dataloader.dataset)
    U = torch.zeros(num_samples,
                    code_length).to(device)  # U (m*l, l:code_length)
    B = torch.randn(num_retrieval,
                    code_length).to(device)  # V (n*l, l:code_length)
    retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets().to(
        device)  # Y2 (n*c, c:classes)

    start = time.time()
    for it in range(max_iter):
        iter_start = time.time()
        # Sample training data for cnn learning
        train_dataloader, sample_index = sample_dataloader(
            retrieval_dataloader, num_samples, batch_size, root, dataset)

        # Create Similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(
            device)  # Y1 (m*c, c:classes)
        S = (train_targets @ retrieval_targets.t() >
             0).float()  # S (m*n, c:classes)
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soft similarity matrix, which helps convergence
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(max_epoch):
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(device), targets.to(
                    device), index.to(device)
                optimizer.zero_grad()

                F = model(data)  # output (m1*l)
                U[index, :] = F.data
                cnn_loss = criterion(F, B, S[index, :], sample_index[index])

                cnn_loss.backward()
                optimizer.step()

        # Update B
        expand_U = torch.zeros(B.shape).to(device)
        expand_U[sample_index, :] = U
        B = solve_dcc(B, U, expand_U, S, code_length, varphi)

        # Total loss
        # iter_loss = calc_loss(U, B, S, code_length, sample_index, gamma)

        # logger.debug('[iter:{}/{}][loss:{:.2f}][iter_time:{:.2f}]'.format(it+1, max_iter, iter_loss, time.time()-iter_start))
    logger.info('[Training time:{:.2f}]'.format(time.time() - start))

    # Evaluate
    query_code = generate_code(model, query_dataloader, code_length, device)
    mAP = evaluate.mean_average_precision(
        query_code.to(device),
        B,
        query_dataloader.dataset.get_onehot_targets().to(device),
        retrieval_targets,
        device,
        topk,
    )

    # Save checkpoints
    torch.save(query_code.cpu(), os.path.join('checkpoints', 'query_code.t'))
    torch.save(B.cpu(), os.path.join('checkpoints', 'database_code.t'))
    torch.save(query_dataloader.dataset.get_onehot_targets().cpu(),
               os.path.join('checkpoints', 'query_targets.t'))
    torch.save(retrieval_targets.cpu(),
               os.path.join('checkpoints', 'database_targets.t'))
    torch.save(model.cpu(), os.path.join('checkpoints', 'model.t'))

    return mAP
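
Both examples defer the binary-code update to solve_dcc, which neither snippet defines. A sketch of the discrete cyclic coordinate (DCC) update used in ADSH, assuming the usual closed-form bit-by-bit solution of min_B ||code_length * S - U @ B.T||^2 + gamma * ||B - expand_U||^2, with shapes as in the examples (U: m x l, B and expand_U: n x l, S: m x n):

import torch

def solve_dcc(B, U, expand_U, S, code_length, gamma):
    # Sketch: behavior inferred from how the examples call it, not the original source.
    """Update the database codes B one bit at a time, keeping entries in {-1, +1}."""
    # Q collects the terms that do not depend on the bit currently being updated.
    Q = (code_length * S).t() @ U + gamma * expand_U
    for bit in range(code_length):
        q = Q[:, bit]
        u = U[:, bit]
        # All columns of B / U except the bit being solved for.
        B_prime = torch.cat((B[:, :bit], B[:, bit + 1:]), dim=1)
        U_prime = torch.cat((U[:, :bit], U[:, bit + 1:]), dim=1)
        # Closed-form sign update for this bit.
        B[:, bit] = (q - B_prime @ U_prime.t() @ u).sign()
    return B
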
Example #3
def increment(
        query_dataloader,
        unseen_dataloader,
        retrieval_dataloader,
        old_B,
        code_length,
        device,
        lr,
        max_iter,
        max_epoch,
        num_samples,
        batch_size,
        root,
        dataset,
        gamma,
        mu,
        topk,
):
    """
    Increment model.

    Args
        query_dataloader, unseen_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        old_B(torch.Tensor): Old binary hash code.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        max_epoch(int): Number of epochs.
        num_samples(int): Number of sampled training data points.
        batch_size(int): Batch size.
        root(str): Path of dataset.
        dataset(str): Dataset name.
        gamma, mu(float): Hyper-parameters.
        topk(int): Top-k for mAP computation.

    Returns
        mAP(float): Mean Average Precision.
    """
    # Initialization
    model = alexnet.load_model(code_length)
    model.to(device)
    model.train()
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    criterion = DIHN_Loss(code_length, gamma, mu)
    lr_scheduler = ExponentialLR(optimizer, 0.91)

    num_unseen = len(unseen_dataloader.dataset)
    num_seen = len(old_B)
    U = torch.zeros(num_samples, code_length).to(device)
    old_B = old_B.to(device)
    new_B = torch.randn(num_unseen, code_length).sign().to(device)
    B = torch.cat((old_B, new_B), dim=0).to(device)
    retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets().to(device)

    total_time = time.time()
    for it in range(max_iter):
        iter_time = time.time()
        lr_scheduler.step()

        # Sample training data for cnn learning
        train_dataloader, sample_index, unseen_sample_in_unseen_index, unseen_sample_in_sample_index = sample_dataloader(retrieval_dataloader, num_samples, batch_size, root, dataset)

        # Create Similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(device)
        S = (train_targets @ retrieval_targets.t() > 0).float()
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soft similarity matrix, which helps convergence
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(max_epoch):
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(device), targets.to(device), index.to(device)
                optimizer.zero_grad()

                F = model(data)
                U[index, :] = F.data
                cnn_loss = criterion(F, B, S[index, :], index)

                cnn_loss.backward()
                optimizer.step()

        # Update B
        expand_U = torch.zeros(num_unseen, code_length).to(device)
        expand_U[unseen_sample_in_unseen_index, :] = U[unseen_sample_in_sample_index, :]
        new_B = solve_dcc(new_B, U, expand_U, S[:, unseen_dataloader.dataset.UNSEEN_INDEX], code_length, gamma)
        B = torch.cat((old_B, new_B), dim=0).to(device)

        # Total loss
        iter_loss = calc_loss(U, B, S, code_length, sample_index, gamma, mu)
        logger.debug('[iter:{}/{}][loss:{:.2f}][time:{:.2f}]'.format(it + 1, max_iter, iter_loss, time.time() - iter_time))

    logger.info('[DIHN time:{:.2f}]'.format(time.time() - total_time))

    # Evaluate
    query_code = generate_code(model, query_dataloader, code_length, device)
    mAP = evaluate.mean_average_precision(
        query_code.to(device),
        B,
        query_dataloader.dataset.get_onehot_targets().to(device),
        retrieval_targets,
        device,
        topk,
    )

    return mAP
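
The iteration-level objective is computed by calc_loss, also not shown. A sketch of the ADSH form used in Examples #1, #4 and #5 (the DIHN call in Example #3 passes an extra mu and presumably adds a corresponding term); omega is the sample_index array mapping the sampled rows of U onto rows of B:

import torch

def calc_loss(U, B, S, code_length, omega, gamma):
    # Sketch: ADSH objective inferred from the surrounding code, not the original source.
    """||code_length * S - U @ B.T||^2 + gamma * ||U - B[omega]||^2, averaged over pairs."""
    hash_loss = ((code_length * S - U @ B.t()) ** 2).sum()
    quantization_loss = ((U - B[omega, :]) ** 2).sum()
    loss = (hash_loss + gamma * quantization_loss) / (U.shape[0] * B.shape[0])
    return loss.item()
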
Example #4
def train(
    query_dataloader, train_dataloader, retrieval_dataloader, code_length, args
):
    """
    Training model.

    Args
        query_dataloader, train_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        code_length(int): Hashing code length.
        args.device(torch.device): GPU or CPU.
        args.lr(float): Learning rate.
        args.max_iter(int): Number of iterations.
        args.max_epoch(int): Number of epochs.
        args.num_samples(int): Number of sampled training data points.
        args.batch_size(int): Batch size.
        args.root(str): Path of dataset.
        args.dataset(str): Dataset name.
        args.gamma(float): Hyper-parameter.
        args.topk(int): Top-k for mAP computation.

    Returns
        best_mAP(float): Best Mean Average Precision over all iterations.
    """
    # Initialization
    # model = alexnet.load_model(code_length).to(device)
    # model = resnet.resnet50(pretrained=True, num_classes=code_length).to(device)
    num_classes, att_size, feat_size = args.num_classes, 4, 2048
    model = exchnet.exchnet(code_length=code_length,
                            num_classes=num_classes,
                            att_size=att_size,
                            feat_size=feat_size,
                            device=args.device,
                            pretrained=args.pretrain).to(args.device)
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.wd,
                              momentum=args.momen,
                              nesterov=args.nesterov)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.wd)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, args.lr_step)
    criterion = Exch_Loss(code_length, args.device, lambd_sp=1.0, lambd_ch=1.0)

    criterion.quanting = args.quan_loss
    num_retrieval = len(retrieval_dataloader.dataset)
    U = torch.zeros(args.num_samples, code_length).to(args.device)
    B = torch.randn(num_retrieval, code_length).to(args.device)
    # B = torch.zeros(num_retrieval, code_length).to(args.device)
    retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets().to(
        args.device)
    C = torch.zeros((num_classes, att_size, feat_size)).to(args.device)
    start = time.time()
    best_mAP = 0
    for it in range(args.max_iter):
        iter_start = time.time()
        # Sample training data for cnn learning
        train_dataloader, sample_index = sample_dataloader(
            retrieval_dataloader, args.num_samples, args.batch_size, args.root,
            args.dataset)

        # Create Similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(
            args.device)
        S = (train_targets @ retrieval_targets.t() > 0).float()
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soft similarity matrix, which helps convergence
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r
        cnn_losses = AverageMeter()
        hash_losses = AverageMeter()
        quan_losses = AverageMeter()
        sp_losses = AverageMeter()
        ch_losses = AverageMeter()
        align_losses = AverageMeter()

        # Training CNN model
        for epoch in range(args.max_epoch):
            cnn_losses.reset()
            hash_losses.reset()
            quan_losses.reset()
            sp_losses.reset()
            ch_losses.reset()
            align_losses.reset()
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(args.device), targets.to(
                    args.device), index.to(args.device)
                optimizer.zero_grad()
                F, sp_v, ch_v, avg_local_f = model(data, targets)
                U[index, :] = F.data
                batch_anchor_local_f = C[torch.argmax(targets, dim=1)]
                # print(index)
                cnn_loss, hash_loss, quan_loss, sp_loss, ch_loss, align_loss = criterion(
                    F, B, S[index, :], sample_index[index], sp_v, ch_v,
                    avg_local_f, batch_anchor_local_f)
                cnn_losses.update(cnn_loss.item())
                hash_losses.update(hash_loss.item())
                quan_losses.update(quan_loss.item())
                sp_losses.update(sp_loss.item())
                ch_losses.update(ch_loss.item())
                align_losses.update(align_loss.item())
                # print(ch_v)
                cnn_loss.backward()
                optimizer.step()
            logger.info(
                '[epoch:{}/{}][cnn_loss:{:.6f}][h_loss:{:.6f}][q_loss:{:.6f}][s_loss:{:.4f}][c_loss:{:.4f}][a_loss:{:.4f}]'
                .format(epoch + 1, args.max_epoch, cnn_losses.avg,
                        hash_losses.avg, quan_losses.avg, sp_losses.avg,
                        ch_losses.avg, align_losses.avg))
        scheduler.step()
        # Update B
        expand_U = torch.zeros(B.shape).to(args.device)
        expand_U[sample_index, :] = U
        if args.quan_loss:
            B = solve_dcc_adsh(B, U, expand_U, S, code_length, args.gamma)
        else:
            B = solve_dcc_exch(B, U, expand_U, S, code_length, args.gamma)

        # Update C
        if (it + 1) >= args.align_step:
            model.exchanging = True
            criterion.aligning = True
            model.eval()
            with torch.no_grad():
                C = torch.zeros(
                    (num_classes, att_size, feat_size)).to(args.device)
                feat_cnt = torch.zeros((num_classes, 1, 1)).to(args.device)
                for batch, (data, targets,
                            index) in enumerate(retrieval_dataloader):
                    data, targets, index = data.to(args.device), targets.to(
                        args.device), index.to(args.device)
                    _, _, _, avg_local_f = model(data, targets)
                    class_idx = targets.argmax(dim=1)
                    for i in range(targets.shape[0]):
                        C[class_idx[i]] += avg_local_f[i]
                        feat_cnt[class_idx[i]] += 1
                C /= feat_cnt
                model.anchor_local_f = C
            model.train()

        # Total loss
        iter_loss = calc_loss(U, B, S, code_length, sample_index, args.gamma)
        # logger.debug('[iter:{}/{}][loss:{:.2f}][iter_time:{:.2f}]'.format(it+1, args.max_iter, iter_loss, time.time()-iter_start))
        logger.info('[iter:{}/{}][loss:{:.6f}][iter_time:{:.2f}]'.format(
            it + 1, args.max_iter, iter_loss,
            time.time() - iter_start))

        # Evaluate (the modulus 1 means every iteration; raise it to evaluate less often)
        if (it + 1) % 1 == 0:
            query_code = generate_code(model, query_dataloader, code_length,
                                       args.device)
            mAP = evaluate.mean_average_precision(
                query_code.to(args.device),
                B,
                query_dataloader.dataset.get_onehot_targets().to(args.device),
                retrieval_targets,
                args.device,
                args.topk,
            )
            if mAP > best_mAP:
                best_mAP = mAP
                # Save checkpoints
                ret_path = os.path.join('checkpoints', args.info,
                                        str(code_length))
                if not os.path.exists(ret_path):
                    os.makedirs(ret_path)
                torch.save(query_code.cpu(),
                           os.path.join(ret_path, 'query_code.t'))
                torch.save(B.cpu(), os.path.join(ret_path, 'database_code.t'))
                torch.save(query_dataloader.dataset.get_onehot_targets().cpu(),
                           os.path.join(ret_path, 'query_targets.t'))
                torch.save(retrieval_targets.cpu(),
                           os.path.join(ret_path, 'database_targets.t'))
                torch.save(model.cpu(), os.path.join(ret_path, 'model.t'))
                model = model.to(args.device)
            logger.info(
                '[iter:{}/{}][code_length:{}][mAP:{:.5f}][best_mAP:{:.5f}]'.
                format(it + 1, args.max_iter, code_length, mAP, best_mAP))

    logger.info('[Training time:{:.2f}]'.format(time.time() - start))

    return best_mAP
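
Evaluation relies on generate_code to binarize the query set. A minimal sketch for the plain-CNN case (Examples #1, #2, #3 and #5); the ExchNet model in Example #4 also takes targets during training, so its evaluation path may differ:

import torch

def generate_code(model, dataloader, code_length, device):
    # Sketch: behavior inferred from how the examples use it, not the original source.
    """Forward every sample once and binarize the outputs with sign()."""
    model.eval()
    with torch.no_grad():
        code = torch.zeros(len(dataloader.dataset), code_length)
        for data, _, index in dataloader:
            data = data.to(device)
            outputs = model(data)
            code[index, :] = outputs.sign().cpu()
    model.train()
    return code
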
Example #5
def train(
        query_dataloader,
        seen_dataloader,
        retrieval_dataloader,
        code_length,
        device,
        lr,
        max_iter,
        max_epoch,
        num_samples,
        batch_size,
        root,
        dataset,
        gamma,
        topk,
):
    """
    Training model.

    Args
        query_dataloader, seen_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        code_length(int): Hashing code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        max_epoch(int): Number of epochs.
        num_samples(int): Number of sampled training data points.
        batch_size(int): Batch size.
        root(str): Path of dataset.
        dataset(str): Dataset name.
        gamma(float): Hyper-parameter.
        topk(int): Top-k for mAP computation.

    Returns
        None
    """
    # Initialization
    model = alexnet.load_model(code_length).to(device)
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    criterion = ADSH_Loss(code_length, gamma)
    lr_scheduler = ExponentialLR(optimizer, 0.9)

    num_seen = len(seen_dataloader.dataset)
    U = torch.zeros(num_samples, code_length).to(device)
    B = torch.randn(num_seen, code_length).sign().to(device)
    seen_targets = seen_dataloader.dataset.get_onehot_targets().to(device)

    total_time = time.time()
    for it in range(max_iter):
        iter_time = time.time()
        lr_scheduler.step()

        # Sample training data for cnn learning
        train_dataloader, sample_index, _, _ = sample_dataloader(seen_dataloader, num_samples, batch_size, root, dataset)

        # Create Similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(device)
        S = (train_targets @ seen_targets.t() > 0).float()
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soft similarity matrix, which helps convergence
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(max_epoch):
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(device), targets.to(device), index.to(device)
                optimizer.zero_grad()

                F = model(data)
                U[index, :] = F.data
                cnn_loss = criterion(F, B, S[index, :], index)

                cnn_loss.backward()
                optimizer.step()

        # Update B
        expand_U = torch.zeros(B.shape).to(device)
        expand_U[sample_index, :] = U
        B = solve_dcc(B, U, expand_U, S, code_length, gamma)

        # Total loss
        iter_loss = calc_loss(U, B, S, code_length, sample_index, gamma)
        logger.debug('[iter:{}/{}][loss:{:.2f}][time:{:.2f}]'.format(it + 1, max_iter, iter_loss, time.time() - iter_time))

    logger.info('[ADSH training time:{:.2f}]'.format(time.time() - total_time))

    # Save checkpoints
    torch.save(B.cpu(), os.path.join('checkpoints', 'old_B.t'))

    # Evaluate
    query_code = generate_code(model, query_dataloader, code_length, device)
    mAP = evaluate.mean_average_precision(
        query_code.to(device),
        B,
        query_dataloader.dataset.get_onehot_targets().to(device),
        retrieval_dataloader.dataset.get_onehot_targets().to(device),
        device,
        topk,
    )
    logger.info('[ADSH map:{:.4f}]'.format(mAP))
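
Finally, the criterion itself. A sketch of ADSH_Loss matching the single-value usage in Example #5 (the variant in Example #1 additionally returns the hash and quantization terms for logging); F are the real-valued network outputs for the sampled batch, B the current database codes, S the (soft) similarity sub-matrix, and omega the batch's indices into B:

import torch
import torch.nn as nn

class ADSH_Loss(nn.Module):
    # Sketch: loss form inferred from the ADSH objective used above, not the original source.
    """Pairwise hashing term plus a gamma-weighted quantization term."""

    def __init__(self, code_length, gamma):
        super(ADSH_Loss, self).__init__()
        self.code_length = code_length
        self.gamma = gamma

    def forward(self, F, B, S, omega):
        hash_loss = ((self.code_length * S - F @ B.t()) ** 2).sum()
        quantization_loss = ((F - B[omega, :]) ** 2).sum()
        loss = (hash_loss + self.gamma * quantization_loss) / (F.shape[0] * B.shape[0])
        return loss
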