Example 1
def train(
    train_dataloader,
    query_dataloader,
    retrieval_dataloader,
    arch,
    code_length,
    device,
    lr,
    max_iter,
    mu,
    nu,
    eta,
    topk,
    evaluate_interval,
):
    """
    Training model.

    Args
        train_dataloader, query_dataloader, retrieval_dataloader(torch.utils.data.DataLoader): Data loader.
        arch(str): CNN model name.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Maximum iterations.
        mu, nu, eta(float): Hyper-parameters.
        topk(int): Compute mAP using top k retrieval result
        evaluate_interval(int): Evaluation interval.

    Returns
        checkpoint(dict): Checkpoint.
    """
    # Construct network, optimizer, loss
    model = load_model(arch, code_length).to(device)
    criterion = DSDHLoss(eta)
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    scheduler = CosineAnnealingLR(optimizer, max_iter, 1e-7)

    # Initialize
    N = len(train_dataloader.dataset)
    B = torch.randn(code_length, N).sign().to(device)
    U = torch.zeros(code_length, N).to(device)
    train_targets = train_dataloader.dataset.get_onehot_targets().to(device)
    S = (train_targets @ train_targets.t() > 0).float()
    Y = train_targets.t()
    best_map = 0.
    iter_time = time.time()

    for it in range(max_iter):
        model.train()
        # CNN-step
        for data, targets, index in train_dataloader:
            data, targets = data.to(device), targets.to(device)
            optimizer.zero_grad()

            U_batch = model(data).t()
            U[:, index] = U_batch.data
            loss = criterion(U_batch, U, S[:, index], B[:, index])

            loss.backward()
            optimizer.step()
        scheduler.step()

        # W-step
        W = torch.inverse(B @ B.t() + nu / mu *
                          torch.eye(code_length, device=device)) @ B @ Y.t()

        # B-step
        B = solve_dcc(W, Y, U, B, eta, mu)

        # Evaluate
        if it % evaluate_interval == evaluate_interval - 1:
            iter_time = time.time() - iter_time
            epoch_loss = calc_loss(U, S, Y, W, B, mu, nu, eta)

            # Generate hash code
            query_code = generate_code(model, query_dataloader, code_length,
                                       device)
            retrieval_code = generate_code(model, retrieval_dataloader,
                                           code_length, device)
            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets(
            )

            # Compute map
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )
            logger.info(
                '[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                    it + 1, max_iter, epoch_loss, mAP, iter_time))

            # Save checkpoint
            if best_map < mAP:
                best_map = mAP
                checkpoint = {
                    'qB': query_code,
                    'qL': query_targets,
                    'rB': retrieval_code,
                    'rL': retrieval_targets,
                    'model': model.state_dict(),
                    'map': best_map,
                }
            iter_time = time.time()

    return checkpoint
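
The W-step above solves a ridge-regression subproblem in closed form: with codes B (code_length x N) and one-hot labels Y (classes x N), W = (B Bᵀ + ν/μ I)⁻¹ B Yᵀ minimizes the label-fitting term plus the ν‖W‖² regularizer used in the loss. A minimal, self-contained sketch of that update on random tensors (toy sizes and hyper-parameters are assumed for illustration only):

import torch

# Assumed toy sizes and hyper-parameters, purely for illustration
code_length, N, num_classes = 16, 100, 10
mu, nu = 1e-2, 1.0

B = torch.randn(code_length, N).sign()                             # binary codes, one column per sample
Y = torch.eye(num_classes)[torch.randint(num_classes, (N,))].t()   # one-hot labels, classes x N

# Closed-form W-step: ridge regression of labels onto codes
W = torch.inverse(B @ B.t() + nu / mu * torch.eye(code_length)) @ B @ Y.t()
print(W.shape)  # torch.Size([16, 10]): code_length x classes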
Example 2
def train(
    train_data,
    train_targets,
    query_data,
    query_targets,
    retrieval_data,
    retrieval_targets,
    code_length,
    num_anchor,
    max_iter,
    lamda,
    nu,
    sigma,
    device,
    topk,
):
    """
    Training model.

    Args
        train_data(torch.Tensor): Training data.
        train_targets(torch.Tensor): Training targets.
        query_data(torch.Tensor): Query data.
        query_targets(torch.Tensor): Query targets.
        retrieval_data(torch.Tensor): Retrieval data.
        retrieval_targets(torch.Tensor): Retrieval targets.
        code_length(int): Hash code length.
        num_anchor(int): Number of anchors.
        max_iter(int): Number of iterations.
        lamda, nu, sigma(float): Hyper-parameters.
        device(torch.device): GPU or CPU.
        topk(int): Compute mAP using top k retrieval result.

    Returns
        checkpoint(dict): Checkpoint.
    """
    # Initialization
    n = train_data.shape[0]
    L = code_length
    m = num_anchor
    t = max_iter
    X = train_data.t()
    Y = train_targets.t()
    B = torch.randn(L, n).sign()

    # Permute data
    perm_index = torch.randperm(n)
    X = X[:, perm_index]
    Y = Y[:, perm_index]

    # Randomly select num_anchor samples from the training data
    anchor = X[:, :m]

    # Map training data via RBF kernel
    phi_x = torch.from_numpy(rbf_kernel(X.numpy().T,
                                        anchor.numpy().T, sigma)).t()

    # Training
    B = B.to(device)
    Y = Y.to(device)
    phi_x = phi_x.to(device)
    for it in range(t):
        # G-Step
        W = torch.pinverse(B @ B.t() + lamda *
                           torch.eye(code_length, device=device)) @ B @ Y.t()

        # F-Step
        P = torch.pinverse(phi_x @ phi_x.t()) @ phi_x @ B.t()
        F_X = P.t() @ phi_x

        # B-Step
        B = solve_dcc(B, W, Y, F_X, nu)

    # Evaluate
    query_code = generate_code(query_data.t(), anchor, P, sigma)
    retrieval_code = generate_code(retrieval_data.t(), anchor, P, sigma)

    # Compute map
    mAP = mean_average_precision(
        query_code.t().to(device),
        retrieval_code.t().to(device),
        query_targets.to(device),
        retrieval_targets.to(device),
        device,
        topk,
    )

    # PR curve
    Precision, R = pr_curve(
        query_code.t().to(device),
        retrieval_code.t().to(device),
        query_targets.to(device),
        retrieval_targets.to(device),
        device,
    )

    # Save checkpoint
    checkpoint = {
        'tB': B,
        'tL': train_targets,
        'qB': query_code,
        'qL': query_targets,
        'rB': retrieval_code,
        'rL': retrieval_targets,
        'anchor': anchor,
        'projection': P,
        'P': Precision,
        'R': R,
        'map': mAP,
    }

    return checkpoint
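
The training data are embedded through an RBF kernel against a small anchor set before hashing. A minimal sketch of that mapping, assuming rbf_kernel(X, anchors, sigma) computes exp(-‖x − a‖² / (2σ²)) for every data/anchor pair (the helper in the repository may differ, and the example above works with the transposed layout):

import torch

def rbf_features(X, anchors, sigma):
    """Map rows of X to RBF similarities against the anchor rows."""
    # Squared Euclidean distances between every sample and every anchor
    dist2 = torch.cdist(X, anchors) ** 2
    return torch.exp(-dist2 / (2 * sigma ** 2))

X = torch.randn(1000, 512)                # assumed feature matrix: samples x dims
anchors = X[torch.randperm(1000)[:300]]   # 300 randomly chosen anchors
phi_x = rbf_features(X, anchors, sigma=1.0)
print(phi_x.shape)                        # torch.Size([1000, 300])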
Example 3
def train(
    query_dataloader,
    retrieval_dataloader,
    code_length,
    args,
    # args.device,
    # lr,
    # args.max_iter,
    # args.max_epoch,
    # args.num_samples,
    # args.batch_size,
    # args.root,
    # dataset,
    # args.gamma,
    # topk,
):
    """
    Training model.

    Args
        query_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        code_length(int): Hashing code length.
        args.device(torch.device): GPU or CPU.
        args.lr(float): Learning rate.
        args.max_iter(int): Number of iterations.
        args.max_epoch(int): Number of epochs.
        args.num_samples(int): Number of sampled training data points.
        args.batch_size(int): Batch size.
        args.root(str): Path of dataset.
        args.dataset(str): Dataset name.
        args.gamma(float): Hyper-parameter.
        args.topk(int): Compute mAP over the top k retrieved results.

    Returns
        mAP(float): Mean Average Precision.
    """
    # Initialization
    # model = alexnet.load_model(code_length).to(args.device)
    model = resnet.resnet50(pretrained=args.pretrain,
                            num_classes=code_length).to(args.device)
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.wd,
                              momentum=args.momen,
                              nesterov=args.nesterov)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.wd)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, args.lr_step)
    criterion = ADSH_Loss(code_length, args.gamma)

    num_retrieval = len(retrieval_dataloader.dataset)
    U = torch.zeros(args.num_samples, code_length).to(args.device)
    B = torch.randn(num_retrieval, code_length).to(args.device)
    retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets().to(
        args.device)
    cnn_losses, hash_losses, quan_losses = AverageMeter(), AverageMeter(
    ), AverageMeter()
    start = time.time()
    best_mAP = 0
    for it in range(args.max_iter):
        iter_start = time.time()
        # Sample training data for cnn learning
        train_dataloader, sample_index = sample_dataloader(
            retrieval_dataloader, args.num_samples, args.batch_size, args.root,
            args.dataset)

        # Create Similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(
            args.device)
        S = (train_targets @ retrieval_targets.t() > 0).float()
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soft similarity matrix, benefit to converge
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(args.max_epoch):
            cnn_losses.reset()
            hash_losses.reset()
            quan_losses.reset()
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(args.device), targets.to(
                    args.device), index.to(args.device)
                optimizer.zero_grad()

                F = model(data)
                U[index, :] = F.data
                cnn_loss, hash_loss, quan_loss = criterion(
                    F, B, S[index, :], sample_index[index])
                cnn_losses.update(cnn_loss.item())
                hash_losses.update(hash_loss.item())
                quan_losses.update(quan_loss.item())
                cnn_loss.backward()
                optimizer.step()
            logger.info(
                '[epoch:{}/{}][cnn_loss:{:.6f}][hash_loss:{:.6f}][quan_loss:{:.6f}]'
                .format(epoch + 1, args.max_epoch, cnn_losses.avg,
                        hash_losses.avg, quan_losses.avg))
        scheduler.step()
        # Update B
        expand_U = torch.zeros(B.shape).to(args.device)
        expand_U[sample_index, :] = U
        B = solve_dcc(B, U, expand_U, S, code_length, args.gamma)

        # Total loss
        iter_loss = calc_loss(U, B, S, code_length, sample_index, args.gamma)
        # logger.debug('[iter:{}/{}][loss:{:.2f}][iter_time:{:.2f}]'.format(it+1, args.max_iter, iter_loss, time.time()-iter_start))
        logger.info('[iter:{}/{}][loss:{:.6f}][iter_time:{:.2f}]'.format(
            it + 1, args.max_iter, iter_loss,
            time.time() - iter_start))

        # Evaluate
        if (it + 1) % 1 == 0:
            query_code = generate_code(model, query_dataloader, code_length,
                                       args.device)
            mAP = evaluate.mean_average_precision(
                query_code.to(args.device),
                B,
                query_dataloader.dataset.get_onehot_targets().to(args.device),
                retrieval_targets,
                args.device,
                args.topk,
            )
            if mAP > best_mAP:
                best_mAP = mAP
                # Save checkpoints
                ret_path = os.path.join('checkpoints', args.info,
                                        str(code_length))
                # ret_path = 'checkpoints/' + args.info
                if not os.path.exists(ret_path):
                    os.makedirs(ret_path)
                torch.save(query_code.cpu(),
                           os.path.join(ret_path, 'query_code.t'))
                torch.save(B.cpu(), os.path.join(ret_path, 'database_code.t'))
                torch.save(query_dataloader.dataset.get_onehot_targets().cpu(),
                           os.path.join(ret_path, 'query_targets.t'))
                torch.save(retrieval_targets.cpu(),
                           os.path.join(ret_path, 'database_targets.t'))
                torch.save(model.cpu(), os.path.join(ret_path, 'model.t'))
                model = model.to(args.device)
            logger.info(
                '[iter:{}/{}][code_length:{}][mAP:{:.5f}][best_mAP:{:.5f}]'.
                format(it + 1, args.max_iter, code_length, mAP, best_mAP))
    logger.info('[Training time:{:.2f}]'.format(time.time() - start))

    return best_mAP
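
Several of these examples rebalance the ±1 similarity matrix before training ("soft similarity matrix, benefit to converge"): dissimilar pairs vastly outnumber similar ones, so S is shifted so the two kinds of pairs contribute more evenly. A small self-contained demonstration of that transform on toy one-hot labels (sizes are assumed):

import torch

# Toy one-hot labels: 6 training samples, 8 database samples, 3 classes
train_targets = torch.eye(3)[torch.randint(3, (6,))]
retrieval_targets = torch.eye(3)[torch.randint(3, (8,))]

# Raw similarity: +1 if the pair shares a label, -1 otherwise
S = (train_targets @ retrieval_targets.t() > 0).float()
S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

# Soft rebalancing: maps {-1, 1} to {-(1 + 2r), 1}; with many dissimilar pairs
# r < 0, so the negative entries are pulled toward 0
r = S.sum() / (1 - S).sum()
S_soft = S * (1 + r) - r
print(S.unique(), S_soft.unique())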
Example 4
def train(
    query_dataloader,
    retrieval_dataloader,
    code_length,
    device,
    lr,
    max_iter,
    max_epoch,
    num_samples,
    batch_size,
    root,
    dataset,
    parameters,
    topk,
):
    """
    Training model.

    Args
        query_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        code_length(int): Hashing code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        max_epoch(int): Number of epochs.
        num_samples(int): Number of sampled training data points.
        batch_size(int): Batch size.
        root(str): Path of dataset.
        dataset(str): Dataset name.
        parameters(dict): Hyper-parameters (eta, mu, gamma, varphi).
        topk(int): Compute mAP over the top k retrieved results.

    Returns
        mAP(float): Mean Average Precision.
    """

    # parameters = {'eta':2, 'mu':0.2, 'gamma':1, 'varphi':200}
    eta = parameters['eta']
    mu = parameters['mu']
    gamma = parameters['gamma']
    varphi = parameters['varphi']

    # Initialization
    model = alexnet.load_model(code_length).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    criterion = DADSH2_Loss(code_length, eta, mu, gamma, varphi, device)

    num_retrieval = len(retrieval_dataloader.dataset)
    U = torch.zeros(num_samples,
                    code_length).to(device)  # U (m*l, l:code_length)
    B = torch.randn(num_retrieval,
                    code_length).to(device)  # V (n*l, l:code_length)
    retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets().to(
        device)  # Y2 (n*c, c:classes)

    start = time.time()
    for it in range(max_iter):
        iter_start = time.time()
        # Sample training data for cnn learning
        train_dataloader, sample_index = sample_dataloader(
            retrieval_dataloader, num_samples, batch_size, root, dataset)

        # Create Similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(
            device)  # Y1 (m*c, c:classes)
        S = (train_targets @ retrieval_targets.t() >
             0).float()  # S (m*n, c:classes)
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soft similarity matrix, benefit to converge
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(max_epoch):
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(device), targets.to(
                    device), index.to(device)
                optimizer.zero_grad()

                F = model(data)  # output (m1*l)
                U[index, :] = F.data
                cnn_loss = criterion(F, B, S[index, :], sample_index[index])

                cnn_loss.backward()
                optimizer.step()

        # Update B
        expand_U = torch.zeros(B.shape).to(device)
        expand_U[sample_index, :] = U
        B = solve_dcc(B, U, expand_U, S, code_length, varphi)

        # Total loss
        # iter_loss = calc_loss(U, B, S, code_length, sample_index, gamma)

        # logger.debug('[iter:{}/{}][loss:{:.2f}][iter_time:{:.2f}]'.format(it+1, max_iter, iter_loss, time.time()-iter_start))
    logger.info('[Training time:{:.2f}]'.format(time.time() - start))

    # Evaluate
    query_code = generate_code(model, query_dataloader, code_length, device)
    mAP = evaluate.mean_average_precision(
        query_code.to(device),
        B,
        query_dataloader.dataset.get_onehot_targets().to(device),
        retrieval_targets,
        device,
        topk,
    )

    # Save checkpoints
    torch.save(query_code.cpu(), os.path.join('checkpoints', 'query_code.t'))
    torch.save(B.cpu(), os.path.join('checkpoints', 'database_code.t'))
    torch.save(query_dataloader.dataset.get_onehot_targets().cpu(),
               os.path.join('checkpoints', 'query_targets.t'))
    torch.save(retrieval_targets.cpu(),
               os.path.join('checkpoints', 'database_targets.t'))
    torch.save(model.cpu(), os.path.join('checkpoints', 'model.t'))

    return mAP
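
In the asymmetric schemes above only a sampled subset of the database is pushed through the CNN, so the network outputs U (num_samples x code_length) are scattered back into a full-database-sized buffer before the B update. A tiny sketch of that scatter with assumed toy sizes:

import torch

num_retrieval, num_samples, code_length = 10, 4, 8   # assumed toy sizes
U = torch.randn(num_samples, code_length)            # CNN outputs for the sampled points
sample_index = torch.tensor([1, 3, 6, 9])            # positions of those samples in the database

expand_U = torch.zeros(num_retrieval, code_length)
expand_U[sample_index, :] = U                         # rows for unsampled points stay zero
print(expand_U.abs().sum(dim=1))                      # non-zero only at rows 1, 3, 6, 9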
Example 5
def train(
    train_dataloader,
    query_dataloader,
    retrieval_dataloader,
    arch,
    code_length,
    device,
    eta,
    lr,
    max_iter,
    topk,
    evaluate_interval,
):
    """
    Training model.

    Args
        train_dataloader, query_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        arch(str): CNN model name.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        eta(float): Hyper-parameter.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        topk(int): Calculate map of top k.
        evaluate_interval(int): Evaluation interval.

    Returns
        checkpoint(dict): Checkpoint.
    """
    # Create model, optimizer, criterion, scheduler
    model = load_model(arch, code_length).to(device)
    criterion = DPSHLoss(eta)
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    scheduler = CosineAnnealingLR(optimizer, max_iter, 1e-7)

    # Initialization
    N = len(train_dataloader.dataset)
    U = torch.zeros(N, code_length).to(device)
    train_targets = train_dataloader.dataset.get_onehot_targets().to(device)

    # Training
    best_map = 0.0
    iter_time = time.time()
    for it in range(max_iter):
        model.train()
        running_loss = 0.
        for data, targets, index in train_dataloader:
            data, targets = data.to(device), targets.to(device)
            optimizer.zero_grad()

            S = (targets @ train_targets.t() > 0).float()
            U_cnn = model(data)
            U[index, :] = U_cnn.data
            loss = criterion(U_cnn, U, S)

            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        scheduler.step()

        # Evaluate
        if it % evaluate_interval == evaluate_interval - 1:
            iter_time = time.time() - iter_time

            # Generate hash code and one-hot targets
            query_code = generate_code(model, query_dataloader, code_length,
                                       device)
            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_code = generate_code(model, retrieval_dataloader,
                                           code_length, device)
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets(
            )

            # Compute map
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )

            # Save checkpoint
            if best_map < mAP:
                best_map = mAP
                checkpoint = {
                    'qB': query_code,
                    'qL': query_targets,
                    'rB': retrieval_code,
                    'rL': retrieval_targets,
                    'model': model.state_dict(),
                    'map': best_map,
                }
            logger.info(
                '[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                    it + 1,
                    max_iter,
                    running_loss,
                    mAP,
                    iter_time,
                ))
            iter_time = time.time()

    return checkpoint
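
The generate_code helper used throughout these examples is not shown. A typical implementation, given dataloaders that yield (data, targets, index), simply forwards every sample in eval mode and binarizes the outputs with sign(); the sketch below is an assumption in that spirit, not the repository's exact code:

import torch

def generate_code(model, dataloader, code_length, device):
    """Assumed sketch: forward all samples and binarize the outputs with sign()."""
    model.eval()
    N = len(dataloader.dataset)
    code = torch.zeros(N, code_length)
    with torch.no_grad():
        for data, _, index in dataloader:
            outputs = model(data.to(device))
            code[index, :] = outputs.sign().cpu()
    model.train()
    return code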
Example 6
def train(
    train_dataloader,
    query_dataloader,
    retrieval_dataloader,
    arch,
    code_length,
    device,
    lr,
    max_iter,
    alpha,
    topk,
    evaluate_interval,
):
    """
    Training model.

    Args
        train_dataloader, query_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        arch(str): CNN model name.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        alpha(float): Hyper-parameters.
        topk(int): Compute top k map.
        evaluate_interval(int): Interval of evaluation.

    Returns
        checkpoint(dict): Checkpoint.
    """
    # Load model
    model = load_model(arch, code_length).to(device)

    # Create criterion, optimizer, scheduler
    criterion = HashNetLoss(alpha)
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=5e-4,
    )
    scheduler = CosineAnnealingLR(
        optimizer,
        max_iter,
        lr / 100,
    )

    # Initialization
    running_loss = 0.
    best_map = 0.
    training_time = 0.

    # Training
    # In this implementation, I do not use "scaled tanh".
    # It adds little, its parameters are hard to tune, and it can sometimes decrease performance.
    # Refer to https://github.com/thuml/HashNet/issues/29
    for it in range(max_iter):
        tic = time.time()
        for data, targets, index in train_dataloader:
            data, targets, index = data.to(device), targets.to(
                device), index.to(device)
            optimizer.zero_grad()

            # Create similarity matrix
            S = (targets @ targets.t() > 0).float()
            outputs = model(data)
            loss = criterion(outputs, S)

            running_loss += loss.item()
            loss.backward()
            optimizer.step()
        scheduler.step()
        training_time += time.time() - tic

        # Evaluate
        if it % evaluate_interval == evaluate_interval - 1:
            # Generate hash code
            query_code = generate_code(model, query_dataloader, code_length,
                                       device)
            retrieval_code = generate_code(model, retrieval_dataloader,
                                           code_length, device)

            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets(
            )

            # Compute map
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )

            # Compute pr curve
            P, R = pr_curve(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
            )

            # Log
            logger.info(
                '[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                    it + 1,
                    max_iter,
                    running_loss / evaluate_interval,
                    mAP,
                    training_time,
                ))
            running_loss = 0.

            # Checkpoint
            if best_map < mAP:
                best_map = mAP

                checkpoint = {
                    'model': model.state_dict(),
                    'qB': query_code.cpu(),
                    'rB': retrieval_code.cpu(),
                    'qL': query_targets.cpu(),
                    'rL': retrieval_targets.cpu(),
                    'P': P,
                    'R': R,
                    'map': best_map,
                }

    return checkpoint
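
The comment above refers to HashNet's "scaled tanh" activation, which this implementation deliberately omits. For reference, the HashNet paper anneals the output activation tanh(β·z) with β growing over training so the outputs gradually approach binary values; a minimal sketch of that idea (the β schedule values here are assumed, not the paper's):

import torch

def scaled_tanh(z, beta):
    """HashNet-style annealed activation: approaches sign(z) as beta grows."""
    return torch.tanh(beta * z)

z = torch.randn(4, 8)                                 # toy raw network outputs
for step, beta in enumerate([1.0, 2.0, 5.0, 10.0]):   # assumed annealing schedule
    out = scaled_tanh(z, beta)
    print(step, out.abs().mean().item())              # mean magnitude drifts toward 1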
Example 7
File: itq.py Project: Sue-syx/L2H
def train(
    train_data,
    query_data,
    query_targets,
    retrieval_data,
    retrieval_targets,
    code_length,
    max_iter,
    device,
    topk,
    ):
    """
    Training model.

    Args
        train_data(torch.Tensor): Training data.
        query_data(torch.Tensor): Query data.
        query_targets(torch.Tensor): Query targets.
        retrieval_data(torch.Tensor): Retrieval data.
        retrieval_targets(torch.Tensor): Retrieval targets.
        code_length(int): Hash code length.
        max_iter(int): Number of iterations.
        device(torch.device): GPU or CPU.
        topk(int): Calculate top k data points map.

    Returns
        checkpoint(dict): Checkpoint.
    """
    # Initialization
    query_data, query_targets = query_data.to(device), query_targets.to(device)
    retrieval_data, retrieval_targets = retrieval_data.to(device), retrieval_targets.to(device)
    R = torch.randn(code_length, code_length).to(device)
    [U, _, _] = torch.svd(R)
    R = U[:, :code_length]

    # PCA
    pca = PCA(n_components=code_length)
    V = torch.from_numpy(pca.fit_transform(train_data.numpy())).to(device)

    # Training
    for i in range(max_iter):
        V_tilde = V @ R
        B = V_tilde.sign()
        # torch.svd returns (U, S, V) with input = U @ diag(S) @ V.t(),
        # so the Procrustes-optimal rotation is V_svd @ U.t()
        U, _, V_svd = torch.svd(B.t() @ V)
        R = V_svd @ U.t()

    # Evaluate
    # Generate query code and retrieval code
    query_code = generate_code(query_data.cpu(), code_length, R, pca)
    retrieval_code = generate_code(retrieval_data.cpu(), code_length, R, pca)

    # Compute map
    mAP = mean_average_precision(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
        topk,
    )

    # P-R curve
    P, Recall = pr_curve(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
    )

    # Save checkpoint
    checkpoint = {
        'qB': query_code,
        'rB': retrieval_code,
        'qL': query_targets,
        'rL': retrieval_targets,
        'pca': pca,
        'rotation_matrix': R,
        'P': P,
        'R': Recall,
        'map': mAP,
    }

    return checkpoint
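
generate_code is not shown in this snippet; for ITQ it typically amounts to projecting new data with the fitted PCA, rotating by R, and taking the sign. A hedged sketch under that assumption (the repository's helper may differ in layout or dtype):

import torch
from sklearn.decomposition import PCA

def generate_code(data, code_length, R, pca):
    """Assumed ITQ out-of-sample coding: PCA projection, rotation, sign."""
    # code_length is kept only to match the call site above
    V = torch.from_numpy(pca.transform(data.numpy())).float()
    return (V @ R.cpu().float()).sign()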
Example 8
def train(
    query_dataloader, train_dataloader, retrieval_dataloader, code_length, args
    # device,
    # lr,
    # args.max_iter,
    # args.max_epoch,
    # args.num_samples,
    # args.batch_size,
    # args.root,
    # dataset,
    # args.gamma,
    # args.topk,
):
    """
    Training model.

    Args
        query_dataloader, train_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        code_length(int): Hashing code length.
        args.device(torch.device): GPU or CPU.
        args.lr(float): Learning rate.
        args.max_iter(int): Number of iterations.
        args.max_epoch(int): Number of epochs.
        args.num_samples(int): Number of sampled training data points.
        args.batch_size(int): Batch size.
        args.root(str): Path of dataset.
        args.dataset(str): Dataset name.
        args.gamma(float): Hyper-parameter.
        args.topk(int): Compute mAP over the top k retrieved results.

    Returns
        mAP(float): Mean Average Precision.
    """
    # Initialization
    # model = alexnet.load_model(code_length).to(device)
    # model = resnet.resnet50(pretrained=True, num_classes=code_length).to(device)
    num_classes, att_size, feat_size = args.num_classes, 4, 2048
    model = exchnet.exchnet(code_length=code_length,
                            num_classes=num_classes,
                            att_size=att_size,
                            feat_size=feat_size,
                            device=args.device,
                            pretrained=args.pretrain).to(args.device)
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.wd,
                              momentum=args.momen,
                              nesterov=args.nesterov)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.wd)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, args.lr_step)
    criterion = Exch_Loss(code_length, args.device, lambd_sp=1.0, lambd_ch=1.0)

    criterion.quanting = args.quan_loss
    num_retrieval = len(retrieval_dataloader.dataset)
    U = torch.zeros(args.num_samples, code_length).to(args.device)
    B = torch.randn(num_retrieval, code_length).to(args.device)
    # B = torch.zeros(num_retrieval, code_length).to(args.device)
    retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets().to(
        args.device)
    C = torch.zeros((num_classes, att_size, feat_size)).to(args.device)
    start = time.time()
    best_mAP = 0
    for it in range(args.max_iter):
        iter_start = time.time()
        # Sample training data for cnn learning
        train_dataloader, sample_index = sample_dataloader(
            retrieval_dataloader, args.num_samples, args.batch_size, args.root,
            args.dataset)

        # Create Similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(
            args.device)
        S = (train_targets @ retrieval_targets.t() > 0).float()
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soft similarity matrix, benefit to converge
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r
        cnn_losses, hash_losses, quan_losses,  sp_losses, ch_losses, align_losses = AverageMeter(), \
            AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()

        # Training CNN model
        for epoch in range(args.max_epoch):
            cnn_losses.reset()
            hash_losses.reset()
            quan_losses.reset()
            sp_losses.reset()
            ch_losses.reset()
            align_losses.reset()
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(args.device), targets.to(
                    args.device), index.to(args.device)
                optimizer.zero_grad()
                F, sp_v, ch_v, avg_local_f = model(data, targets)
                U[index, :] = F.data
                batch_anchor_local_f = C[torch.argmax(targets, dim=1)]
                # print(index)
                cnn_loss, hash_loss, quan_loss, sp_loss, ch_loss, align_loss = criterion(
                    F, B, S[index, :], sample_index[index], sp_v, ch_v,
                    avg_local_f, batch_anchor_local_f)
                cnn_losses.update(cnn_loss.item())
                hash_losses.update(hash_loss.item())
                quan_losses.update(quan_loss.item())
                sp_losses.update(sp_loss.item())
                ch_losses.update(ch_loss.item())
                align_losses.update(align_loss.item())
                # print(ch_v)
                cnn_loss.backward()
                optimizer.step()
            logger.info(
                '[epoch:{}/{}][cnn_loss:{:.6f}][h_loss:{:.6f}][q_loss:{:.6f}][s_loss:{:.4f}][c_loss:{:.4f}][a_loss:{:.4f}]'
                .format(epoch + 1, args.max_epoch, cnn_losses.avg,
                        hash_losses.avg, quan_losses.avg, sp_losses.avg,
                        ch_losses.avg, align_losses.avg))
        scheduler.step()
        # Update B
        expand_U = torch.zeros(B.shape).to(args.device)
        expand_U[sample_index, :] = U
        if args.quan_loss:
            B = solve_dcc_adsh(B, U, expand_U, S, code_length, args.gamma)
        else:
            B = solve_dcc_exch(B, U, expand_U, S, code_length, args.gamma)

        # Update C
        if (it + 1) >= args.align_step:
            model.exchanging = True
            criterion.aligning = True
            model.eval()
            with torch.no_grad():
                C = torch.zeros(
                    (num_classes, att_size, feat_size)).to(args.device)
                feat_cnt = torch.zeros((num_classes, 1, 1)).to(args.device)
                for batch, (data, targets,
                            index) in enumerate(retrieval_dataloader):
                    data, targets, index = data.to(args.device), targets.to(
                        args.device), index.to(args.device)
                    _, _, _, avg_local_f = model(data, targets)
                    class_idx = targets.argmax(dim=1)
                    for i in range(targets.shape[0]):
                        C[class_idx[i]] += avg_local_f[i]
                        feat_cnt[class_idx[i]] += 1
                C /= feat_cnt
                model.anchor_local_f = C
            model.train()

        # Total loss
        iter_loss = calc_loss(U, B, S, code_length, sample_index, args.gamma)
        # logger.debug('[iter:{}/{}][loss:{:.2f}][iter_time:{:.2f}]'.format(it+1, args.max_iter, iter_loss, time.time()-iter_start))
        logger.info('[iter:{}/{}][loss:{:.6f}][iter_time:{:.2f}]'.format(
            it + 1, args.max_iter, iter_loss,
            time.time() - iter_start))

        # Evaluate
        if (it + 1) % 1 == 0:
            query_code = generate_code(model, query_dataloader, code_length,
                                       args.device)
            mAP = evaluate.mean_average_precision(
                query_code.to(args.device),
                B,
                query_dataloader.dataset.get_onehot_targets().to(args.device),
                retrieval_targets,
                args.device,
                args.topk,
            )
            if mAP > best_mAP:
                best_mAP = mAP
                # Save checkpoints
                ret_path = os.path.join('checkpoints', args.info,
                                        str(code_length))
                if not os.path.exists(ret_path):
                    os.makedirs(ret_path)
                torch.save(query_code.cpu(),
                           os.path.join(ret_path, 'query_code.t'))
                torch.save(B.cpu(), os.path.join(ret_path, 'database_code.t'))
                torch.save(query_dataloader.dataset.get_onehot_targets().cpu(),
                           os.path.join(ret_path, 'query_targets.t'))
                torch.save(retrieval_targets.cpu(),
                           os.path.join(ret_path, 'database_targets.t'))
                torch.save(model.cpu(), os.path.join(ret_path, 'model.t'))
                model = model.to(args.device)
            logger.info(
                '[iter:{}/{}][code_length:{}][mAP:{:.5f}][best_mAP:{:.5f}]'.
                format(it + 1, args.max_iter, code_length, mAP, best_mAP))

    logger.info('[Training time:{:.2f}]'.format(time.time() - start))

    return best_mAP
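
The "Update C" block accumulates a per-class average of local features to serve as class anchors. A compact, self-contained sketch of that averaging with assumed toy shapes:

import torch

num_classes, att_size, feat_size = 5, 4, 16                           # assumed toy sizes
targets = torch.eye(num_classes)[torch.randint(num_classes, (32,))]   # one-hot labels for 32 samples
avg_local_f = torch.randn(32, att_size, feat_size)                    # per-sample local features

C = torch.zeros(num_classes, att_size, feat_size)
feat_cnt = torch.zeros(num_classes, 1, 1)
class_idx = targets.argmax(dim=1)
for i in range(targets.shape[0]):
    C[class_idx[i]] += avg_local_f[i]
    feat_cnt[class_idx[i]] += 1
C /= feat_cnt.clamp(min=1)    # clamp avoids dividing by zero for empty classes
print(C.shape)                # torch.Size([5, 4, 16])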
Example 9
def train(train_dataloader, query_dataloader, retrieval_dataloader, arch, code_length, device, lr,
          max_iter, topk, evaluate_interval, anchor_num, proportion
          ):
    rho1 = 1e-2   # ρ1
    rho2 = 1e-2   # ρ2
    rho3 = 1e-3   # µ1
    rho4 = 1e-3   # µ2
    gamma = 1e-3  # γ
    with torch.no_grad():
        data_mo = torch.tensor([]).to(device)
        for data, _, _ in train_dataloader:
            data = data.to(device)
            data_mo = torch.cat((data_mo, data), 0)
            torch.cuda.empty_cache()
        n = data_mo.size(0)  # number of training samples
        Y1 = torch.rand(n, code_length).to(device)
        Y2 = torch.rand(n, code_length).to(device)
        B = torch.rand(n, code_length).to(device)
    # Load model
    model = load_model(arch, code_length).to(device)
    # Create criterion, optimizer, scheduler
    criterion = PrototypicalLoss()
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=5e-4,
    )
    scheduler = CosineAnnealingLR(
        optimizer,
        max_iter,
        lr / 100,
    )

    # Initialization
    running_loss = 0.
    best_map = 0.
    training_time = 0.

    # Training
    for it in range(max_iter):
        # timer
        tic = time.time()

        # ADMM uses anchors in the first step but drops them later
        '''
        with torch.no_grad():
            output_mo = torch.tensor([]).to(device)
            for data, _, _ in train_dataloader:
                data = data.to(device)
                output_mo_temp = model_mo(data)
                output_mo = torch.cat((output_mo, output_mo_temp), 0)
                torch.cuda.empty_cache()
            anchor = get_anchor(output_mo, anchor_num, device)  # compute anchor
        '''
        with torch.no_grad():
            output_mo = torch.tensor([]).to(device)

            for data, _, _ in train_dataloader:
                data = data.to(device)
                output_B, output_A = model(data)
                output_mo = torch.cat((output_mo, output_A), 0)
                torch.cuda.empty_cache()

            dist = euclidean_dist(output_mo, output_mo)
            dist = torch.exp(-1 * dist / torch.max(dist)).to(device)
            A = (2 / (torch.max(dist) - torch.min(dist))) * dist - 1
            global_A = A.cpu().numpy()
            # Box projection of Z1 onto [-1, 1]
            Z1 = B + 1 / rho1 * Y1
            Z1[Z1 > 1] = 1
            Z1[Z1 < -1] = -1
            # Norm projection of Z2 onto the sphere of radius sqrt(n * code_length)
            Z2 = B + 1 / rho2 * Y2
            norm_B = torch.norm(Z2)
            Z2 = (n * code_length) ** 0.5 * Z2 / norm_B
            # Dual updates
            Y1 = Y1 + gamma * rho1 * (B - Z1)
            Y2 = Y2 + gamma * rho2 * (B - Z2)
            global_Z1 = Z1.cpu().numpy()
            global_Z2 = Z2.cpu().numpy()
            global_Y1 = Y1.cpu().numpy()
            global_Y2 = Y2.cpu().numpy()
            B0 = B.cpu().numpy()
            # fmin_l_bfgs_b returns (x, f, d); x comes back flattened, so reshape it
            B_opt, _, _ = scipy.optimize.fmin_l_bfgs_b(Baim_func, B0)
            B = torch.from_numpy(B_opt.reshape(n, code_length)).float().to(device)
        # self-supervised deep learning
        model.train()
        for data, targets, index in train_dataloader:
            data, targets, index = data.to(device), targets.to(device), index.to(device)
            optimizer.zero_grad()

            # output_B for hash code .output_A for result without hash layer
            output_B, output_A= model(data)

            loss = criterion(output_B, B)

            running_loss += loss.item()
            loss.backward()

            optimizer.step()

        scheduler.step()
        training_time += time.time() - tic

        # Evaluate
        if it % evaluate_interval == evaluate_interval - 1:
            # Generate hash code
            query_code = generate_code(model, query_dataloader, code_length, device)
            retrieval_code = generate_code(model, retrieval_dataloader, code_length, device)

            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets()

            # Compute map
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )

            # Compute pr curve
            P, R = pr_curve(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
            )

            # Log
            logger.info('[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                it + 1,
                max_iter,
                running_loss / evaluate_interval,
                mAP,
                training_time,
            ))
            running_loss = 0.

            # Checkpoint
            if best_map < mAP:
                best_map = mAP

                checkpoint = {
                    'model': model.state_dict(),
                    'qB': query_code.cpu(),
                    'rB': retrieval_code.cpu(),
                    'qL': query_targets.cpu(),
                    'rL': retrieval_targets.cpu(),
                    'P': P,
                    'R': R,
                    'map': best_map,
                }

    return checkpoint
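
The ADMM block projects the auxiliary variables twice per iteration: Z1 is clipped into the box [-1, 1] and Z2 is rescaled onto the sphere of radius sqrt(n·code_length), the Frobenius norm of an exact ±1 code matrix. A small self-contained sketch of both projections:

import torch

n, code_length = 100, 16
B = torch.randn(n, code_length)

# Box projection: clamp every entry into [-1, 1]
Z1 = B.clamp(min=-1, max=1)

# Norm projection: rescale so ||Z2||_F equals sqrt(n * code_length)
Z2 = (n * code_length) ** 0.5 * B / torch.norm(B)

print(Z1.min().item(), Z1.max().item())                  # within [-1, 1]
print(torch.norm(Z2).item(), (n * code_length) ** 0.5)   # both ~40.0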
Example 10
def increment(
        query_dataloader,
        unseen_dataloader,
        retrieval_dataloader,
        old_B,
        code_length,
        device,
        lr,
        max_iter,
        max_epoch,
        num_samples,
        batch_size,
        root,
        dataset,
        gamma,
        mu,
        topk,
):
    """
    Increment model.

    Args
        query_dataloader, unseen_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        old_B(torch.Tensor): Old binary hash code.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        max_epoch(int): Number of epochs.
        num_samples(int): Number of sampled training data points.
        batch_size(int): Batch size.
        root(str): Path of dataset.
        dataset(str): Dataset name.
        gamma, mu(float): Hyper-parameters.
        topk(int): Top k map.

    Returns
        mAP(float): Mean Average Precision.
    """
    # Initialization
    model = alexnet.load_model(code_length)
    model.to(device)
    model.train()
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    criterion = DIHN_Loss(code_length, gamma, mu)
    lr_scheduler = ExponentialLR(optimizer, 0.91)

    num_unseen = len(unseen_dataloader.dataset)
    num_seen = len(old_B)
    U = torch.zeros(num_samples, code_length).to(device)
    old_B = old_B.to(device)
    new_B = torch.randn(num_unseen, code_length).sign().to(device)
    B = torch.cat((old_B, new_B), dim=0).to(device)
    retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets().to(device)

    total_time = time.time()
    for it in range(max_iter):
        iter_time = time.time()
        lr_scheduler.step()

        # Sample training data for cnn learning
        train_dataloader, sample_index, unseen_sample_in_unseen_index, unseen_sample_in_sample_index = sample_dataloader(retrieval_dataloader, num_samples, batch_size, root, dataset)

        # Create Similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(device)
        S = (train_targets @ retrieval_targets.t() > 0).float()
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soft similarity matrix, benefit to converge
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(max_epoch):
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(device), targets.to(device), index.to(device)
                optimizer.zero_grad()

                F = model(data)
                U[index, :] = F.data
                cnn_loss = criterion(F, B, S[index, :], index)

                cnn_loss.backward()
                optimizer.step()

        # Update B
        expand_U = torch.zeros(num_unseen, code_length).to(device)
        expand_U[unseen_sample_in_unseen_index, :] = U[unseen_sample_in_sample_index, :]
        new_B = solve_dcc(new_B, U, expand_U, S[:, unseen_dataloader.dataset.UNSEEN_INDEX], code_length, gamma)
        B = torch.cat((old_B, new_B), dim=0).to(device)

        # Total loss
        iter_loss = calc_loss(U, B, S, code_length, sample_index, gamma, mu)
        logger.debug('[iter:{}/{}][loss:{:.2f}][time:{:.2f}]'.format(it + 1, max_iter, iter_loss, time.time() - iter_time))

    logger.info('[DIHN time:{:.2f}]'.format(time.time() - total_time))

    # Evaluate
    query_code = generate_code(model, query_dataloader, code_length, device)
    mAP = evaluate.mean_average_precision(
        query_code.to(device),
        B,
        query_dataloader.dataset.get_onehot_targets().to(device),
        retrieval_targets,
        device,
        topk,
    )

    return mAP
Example 11
def train(
    query_data,
    query_targets,
    retrieval_data,
    retrieval_targets,
    code_length,
    device,
    topk,
):
    """
    Training model

    Args
        query_data(torch.Tensor): Query data.
        query_targets(torch.Tensor): One-hot query targets.
        retrieval_data(torch.Tensor): Retrieval data.
        retrieval_targets(torch.Tensor): One-hot retrieval targets.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        topk(int): Calculate top k data map.

    Returns
        checkpoint(dict): Checkpoint.
    """
    # Initialization
    query_data, retrieval_data, query_targets, retrieval_targets = query_data.to(
        device), retrieval_data.to(device), query_targets.to(
            device), retrieval_targets.to(device)

    # Generate random projection matrix
    W = torch.randn(query_data.shape[1], code_length).to(device)

    # Generate query and retrieval code
    query_code = (query_data @ W).sign()
    retrieval_code = (retrieval_data @ W).sign()

    # Compute map
    mAP = mean_average_precision(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
        topk,
    )

    # P-R curve
    P, R = pr_curve(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
    )

    # Save checkpoint
    checkpoint = {
        'qB': query_code,
        'rB': retrieval_code,
        'qL': query_targets,
        'rL': retrieval_targets,
        'W': W,
        'P': P,
        'R': R,
        'map': mAP,
    }
    torch.save(checkpoint,
               'checkpoints/code_{}_map_{:.4f}.pt'.format(code_length, mAP))

    return checkpoint
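
LSH retrieval compares the ±1 codes produced above by Hamming distance, which for sign codes reduces to a dot product: d_H(q, r) = (code_length − q·r) / 2. A short self-contained sketch of that relation:

import torch

code_length, d = 32, 128
W = torch.randn(d, code_length)                   # random projection matrix, as in the example
query_code = (torch.randn(5, d) @ W).sign()       # 5 query codes
retrieval_code = (torch.randn(20, d) @ W).sign()  # 20 database codes

# Hamming distance between every query and every database item
hamming = 0.5 * (code_length - query_code @ retrieval_code.t())
print(hamming.shape, hamming.min().item(), hamming.max().item())  # values in [0, 32]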
Example 12
def train(train_dataloader, query_dataloader, retrieval_dataloader, arch, code_length, device, lr,
          max_iter, topk, evaluate_interval, anchor_num, proportion
          ):
    #print("using device")
    #print(torch.cuda.current_device())
    #print(torch.cuda.get_device_name(torch.cuda.current_device()))
    # Load model
    model = load_model(arch, code_length).to(device)
    model_mo = load_model_mo(arch).to(device)

    # Create criterion, optimizer, scheduler
    criterion = PrototypicalLoss()
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=5e-4,
    )
    scheduler = CosineAnnealingLR(
        optimizer,
        max_iter,
        lr / 100,
    )

    # Initialization
    running_loss = 0.
    best_map = 0.
    training_time = 0.

    # Training
    for it in range(max_iter):
        # timer
        tic = time.time()

        # Harvest prototypes/anchors (this pass is sometimes killed; try another way)
        with torch.no_grad():
            output_mo = torch.tensor([]).to(device)
            for data, _, _ in train_dataloader:
                data = data.to(device)
                output_mo_temp = model_mo(data)
                output_mo = torch.cat((output_mo, output_mo_temp), 0)
                torch.cuda.empty_cache()

            anchor = get_anchor(output_mo, anchor_num, device)  # compute anchor

        # self-supervised deep learning
        model.train()
        for data, targets, index in train_dataloader:
            data, targets, index = data.to(device), targets.to(device), index.to(device)
            optimizer.zero_grad()

            # output
            output_B = model(data)
            output_mo_batch = model_mo(data)

            # prototypes/anchors based similarity

            #sample_anchor_distance = torch.sqrt(torch.sum((output_mo_batch[:, None, :] - anchor) ** 2, dim=2)).to(device)
            #sample_anchor_dist_normalize = F.normalize(sample_anchor_distance, p=2, dim=1).to(device)
            #S = sample_anchor_dist_normalize @ sample_anchor_dist_normalize.t()

            # loss
            #loss = criterion(output_B, S)
            #running_loss = running_loss + loss.item()
            #loss.backward(retain_graph=True)
            with torch.no_grad():
                dist = torch.sum((output_mo_batch[:, None, :] - anchor.to(device)) ** 2, dim=2)
                k = dist.size(1)
                dist = torch.exp(-1 * dist / torch.max(dist)).to(device)
                Z_su = torch.ones(k, 1).to(device)
                Z_sum = torch.sqrt(dist.mm(Z_su)) + 1e-12
                Z_simi = torch.div(dist, Z_sum).to(device)
                S = Z_simi.mm(Z_simi.t())
                S = (2 / (torch.max(S) - torch.min(S))) * S - 1


            loss = criterion(output_B, S)

            running_loss += loss.item()
            loss.backward()

            optimizer.step()
        with torch.no_grad():
            # momentum update:
            for param_q, param_k in zip(model.parameters(), model_mo.parameters()):
                param_k.data = param_k.data * proportion + param_q.data * (1. - proportion)  # proportion = 0.999 for update

        scheduler.step()
        training_time += time.time() - tic

        # Evaluate
        if it % evaluate_interval == evaluate_interval - 1:
            # Generate hash code
            query_code = generate_code(model, query_dataloader, code_length, device)
            retrieval_code = generate_code(model, retrieval_dataloader, code_length, device)

            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets()

            # Compute map
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )

            # Compute pr curve
            P, R = pr_curve(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
            )

            # Log
            logger.info('[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                it + 1,
                max_iter,
                running_loss / evaluate_interval,
                mAP,
                training_time,
            ))
            running_loss = 0.

            # Checkpoint
            if best_map < mAP:
                best_map = mAP

                checkpoint = {
                    'model': model.state_dict(),
                    'qB': query_code.cpu(),
                    'rB': retrieval_code.cpu(),
                    'qL': query_targets.cpu(),
                    'rL': retrieval_targets.cpu(),
                    'P': P,
                    'R': R,
                    'map': best_map,
                }

    return checkpoint
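
The momentum update above keeps model_mo as an exponential moving average of model (a MoCo-style key encoder). A minimal self-contained sketch of that update on two tiny stand-in networks (proportion is the momentum coefficient, e.g. 0.999 as noted in the comment):

import torch
import torch.nn as nn

model = nn.Linear(8, 4)                          # online encoder (toy stand-in)
model_mo = nn.Linear(8, 4)                       # momentum encoder
model_mo.load_state_dict(model.state_dict())     # start from the same weights

proportion = 0.999
with torch.no_grad():
    for param_q, param_k in zip(model.parameters(), model_mo.parameters()):
        # EMA: keep most of the old momentum weights, blend in a little of the online weights
        param_k.data = param_k.data * proportion + param_q.data * (1. - proportion)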
Example 13
def train(
        query_dataloader,
        seen_dataloader,
        retrieval_dataloader,
        code_length,
        device,
        lr,
        max_iter,
        max_epoch,
        num_samples,
        batch_size,
        root,
        dataset,
        gamma,
        topk,
):
    """
    Training model.

    Args
        query_dataloader, seen_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        code_length(int): Hashing code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        max_epoch(int): Number of epochs.
        num_samples(int): Number of sampling training data points.
        batch_size(int): Batch size.
        root(str): Path of dataset.
        dataset(str): Dataset name.
        gamma(float): Hyper-parameter.
        topk(int): Compute mAP over the top k retrieved results.

    Returns
        None
    """
    # Initialization
    model = alexnet.load_model(code_length).to(device)
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    criterion = ADSH_Loss(code_length, gamma)
    lr_scheduler = ExponentialLR(optimizer, 0.9)

    num_seen = len(seen_dataloader.dataset)
    U = torch.zeros(num_samples, code_length).to(device)
    B = torch.randn(num_seen, code_length).sign().to(device)
    seen_targets = seen_dataloader.dataset.get_onehot_targets().to(device)

    total_time = time.time()
    for it in range(max_iter):
        iter_time = time.time()
        lr_scheduler.step()

        # Sample training data for cnn learning
        train_dataloader, sample_index, _, _ = sample_dataloader(seen_dataloader, num_samples, batch_size, root, dataset)

        # Create Similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(device)
        S = (train_targets @ seen_targets.t() > 0).float()
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soft similarity matrix, benefit to converge
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(max_epoch):
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(device), targets.to(device), index.to(device)
                optimizer.zero_grad()

                F = model(data)
                U[index, :] = F.data
                cnn_loss = criterion(F, B, S[index, :], index)

                cnn_loss.backward()
                optimizer.step()

        # Update B
        expand_U = torch.zeros(B.shape).to(device)
        expand_U[sample_index, :] = U
        B = solve_dcc(B, U, expand_U, S, code_length, gamma)

        # Total loss
        iter_loss = calc_loss(U, B, S, code_length, sample_index, gamma)
        logger.debug('[iter:{}/{}][loss:{:.2f}][time:{:.2f}]'.format(it + 1, max_iter, iter_loss, time.time() - iter_time))

    logger.info('Training ADSH finished, time:{:.2f}'.format(time.time() - total_time))

    # Save checkpoints
    torch.save(B.cpu(), os.path.join('checkpoints', 'old_B.t'))

    # Evaluate
    query_code = generate_code(model, query_dataloader, code_length, device)
    mAP = evaluate.mean_average_precision(
        query_code.to(device),
        B,
        query_dataloader.dataset.get_onehot_targets().to(device),
        retrieval_dataloader.dataset.get_onehot_targets().to(device),
        device,
        topk,
    )
    logger.info('[ADSH map:{:.4f}]'.format(mAP))
Example 14
def train(
    train_data,
    query_data,
    query_targets,
    retrieval_data,
    retrieval_targets,
    code_length,
    device,
    topk,
):
    """
    Training model.

    Args
        train_data(torch.Tensor): Training data.
        query_data(torch.Tensor): Query data.
        query_targets(torch.Tensor): Query targets.
        retrieval_data(torch.Tensor): Retrieval data.
        retrieval_targets(torch.Tensor): Retrieval targets.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        topk(int): Calculate top k data points map.

    Returns
        checkpoint(dict): Checkpoint.
    """
    # PCA
    pca = PCA(n_components=code_length)
    X = pca.fit_transform(train_data.numpy())

    # Fit uniform distribution
    eps = np.finfo(float).eps
    mn = X.min(0) - eps
    mx = X.max(0) + eps

    # Enumerate eigenfunctions
    R = mx - mn
    max_mode = np.ceil((code_length + 1) * R / R.max()).astype(int)
    n_modes = max_mode.sum() - len(max_mode) + 1
    modes = np.ones([n_modes, code_length])
    m = 0
    for i in range(code_length):
        modes[m + 1:m + max_mode[i], i] = np.arange(1, max_mode[i]) + 1
        m = m + max_mode[i] - 1

    modes -= 1
    omega0 = np.pi / R
    omegas = modes * omega0.reshape(1, -1).repeat(n_modes, 0)
    eig_val = -(omegas**2).sum(1)
    ii = (-eig_val).argsort()
    modes = modes[ii[1:code_length + 1], :]

    # Evaluate
    # Generate query code and retrieval code
    query_code = generate_code(query_data.cpu(), code_length, pca, mn, R,
                               modes).to(device)
    retrieval_code = generate_code(retrieval_data.cpu(), code_length, pca, mn,
                                   R, modes).to(device)
    query_targets = query_targets.to(device)
    retrieval_targets = retrieval_targets.to(device)

    # Compute map
    mAP = mean_average_precision(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
        topk,
    )

    # P-R curve
    P, Recall = pr_curve(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
    )

    # Save checkpoint
    checkpoint = {
        'qB': query_code,
        'rB': retrieval_code,
        'qL': query_targets,
        'rL': retrieval_targets,
        'P': P,
        'R': Recall,
        'map': mAP,
    }

    return checkpoint
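
generate_code for spectral hashing is not shown; following the reference implementation of Weiss et al., the out-of-sample code evaluates the selected sinusoidal eigenfunctions on the PCA-projected, shifted data and thresholds at zero. A hedged sketch under that assumption (the layout and helper here are guesses, not the repository's exact code):

import numpy as np
import torch

def generate_code(data, code_length, pca, mn, R, modes):
    """Assumed SH out-of-sample coding: product of 1-D eigenfunctions per selected mode."""
    X = pca.transform(data.numpy()) - mn          # shift into the fitted box [0, R]
    omega0 = np.pi / R
    omegas = modes * omega0.reshape(1, -1)        # one frequency vector per selected mode
    U = np.zeros((X.shape[0], code_length))
    for k in range(code_length):
        # Eigenfunction for mode k: product over dimensions of sin(pi/2 + omega * x)
        U[:, k] = np.prod(np.sin(X * omegas[k, :] + np.pi / 2), axis=1)
    return torch.from_numpy((U > 0).astype(np.float32) * 2 - 1)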