Пример #1
0
Файл: itq.py Проект: Sue-syx/L2H
def train(
    train_data,
    query_data,
    query_targets,
    retrieval_data,
    retrieval_targets,
    code_length,
    max_iter,
    device,
    topk,
):
    """
    Learn a rotation of PCA-projected training data by alternating between
    binarising the rotated data and re-fitting the rotation, then evaluate
    the resulting query/retrieval codes.

    Args
        train_data(torch.Tensor): Training data.
        query_data(torch.Tensor): Query data.
        query_targets(torch.Tensor): Query targets.
        retrieval_data(torch.Tensor): Retrieval data.
        retrieval_targets(torch.Tensor): Retrieval targets.
        code_length(int): Hash code length.
        max_iter(int): Number of iterations.
        device(torch.device): GPU or CPU.
        topk(int): Calculate top k data points map.

    Returns
        checkpoint(dict): Query/retrieval codes and targets, the fitted PCA
            object, the rotation matrix, the P-R curve and the mAP.
    """
    # Initialization
    query_data = query_data.to(device)
    query_targets = query_targets.to(device)
    retrieval_data = retrieval_data.to(device)
    retrieval_targets = retrieval_targets.to(device)

    # Starting rotation: left singular vectors of a random Gaussian matrix.
    R = torch.randn(code_length, code_length).to(device)
    U, _, _ = torch.svd(R)
    R = U[:, :code_length]

    # PCA
    # .cpu() keeps the NumPy conversion valid when train_data is not on the CPU.
    pca = PCA(n_components=code_length)
    V = torch.from_numpy(pca.fit_transform(train_data.cpu().numpy())).to(device)

    # Training
    for _ in range(max_iter):
        # Fix R, update the binary codes.
        B = (V @ R).sign()

        # Fix B, update R (orthogonal Procrustes problem).
        # torch.svd returns (U, S, V) with input == U @ diag(S) @ V.t(); the
        # third result is V itself, NOT its transpose. trace(B^T V R) is
        # maximised by R = V_svd @ U^T, so no extra transpose is applied to
        # the right singular vectors.
        U, _, right_vectors = torch.svd(B.t() @ V)
        R = right_vectors @ U.t()

    # Evaluate
    # Generate query code and retrieval code
    query_code = generate_code(query_data.cpu(), code_length, R, pca)
    retrieval_code = generate_code(retrieval_data.cpu(), code_length, R, pca)

    # Compute map
    mAP = mean_average_precision(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
        topk,
    )

    # P-R curve
    P, Recall = pr_curve(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
    )

    # Save checkpoint
    checkpoint = {
        'qB': query_code,
        'rB': retrieval_code,
        'qL': query_targets,
        'rL': retrieval_targets,
        'pca': pca,
        'rotation_matrix': R,
        'P': P,
        'R': Recall,
        'map': mAP,
    }

    return checkpoint
Пример #2
0
def train(
    train_data,
    train_targets,
    query_data,
    query_targets,
    retrieval_data,
    retrieval_targets,
    code_length,
    num_anchor,
    max_iter,
    lamda,
    nu,
    sigma,
    device,
    topk,
):
    """
    Training model.

    Alternates three updates for max_iter rounds (classifier W, kernel
    projection P, binary codes B), then encodes the query and retrieval sets
    and evaluates them.

    Args
        train_data(torch.Tensor): Training data.
        train_targets(torch.Tensor): Training targets.
        query_data(torch.Tensor): Query data.
        query_targets(torch.Tensor): Query targets.
        retrieval_data(torch.Tensor): Retrieval data.
        retrieval_targets(torch.Tensor): Retrieval targets.
        code_length(int): Hash code length.
        num_anchor(int): Number of anchors.
        max_iter(int): Number of iterations.
        lamda, nu, sigma(float): Hyper-parameters.
        device(torch.device): GPU or CPU.
        topk(int): Compute mAP using top k retrieval result.

    Returns
        checkpoint(dict): Checkpoint. 'tB' holds the training codes (one
            column per sample) and 'tL' the training targets (one row per
            sample), both in the same permuted sample order.
    """
    # Initialization
    n = train_data.shape[0]
    L = code_length
    m = num_anchor
    t = max_iter
    X = train_data.t()
    Y = train_targets.t()
    B = torch.randn(L, n).sign()

    # Permute data
    perm_index = torch.randperm(n)
    X = X[:, perm_index]
    Y = Y[:, perm_index]

    # Randomly select num_anchor samples from the training data
    # (the first m columns of the permuted data).
    anchor = X[:, :m]

    # Map training data via RBF kernel
    phi_x = torch.from_numpy(rbf_kernel(X.numpy().T,
                                        anchor.numpy().T, sigma)).t()

    # Training
    B = B.to(device)
    Y = Y.to(device)
    phi_x = phi_x.to(device)

    # Both factors below do not depend on B, so they are computed once
    # instead of on every iteration.
    ridge = lamda * torch.eye(code_length, device=device)
    kernel_solver = torch.pinverse(phi_x @ phi_x.t()) @ phi_x

    for it in range(t):
        # G-Step
        W = torch.pinverse(B @ B.t() + ridge) @ B @ Y.t()

        # F-Step
        P = kernel_solver @ B.t()
        F_X = P.t() @ phi_x

        # B-Step
        B = solve_dcc(B, W, Y, F_X, nu)

    # Evaluate
    # NOTE: P is only defined after at least one training iteration.
    query_code = generate_code(query_data.t(), anchor, P, sigma)
    retrieval_code = generate_code(retrieval_data.t(), anchor, P, sigma)

    # Compute map
    mAP = mean_average_precision(
        query_code.t().to(device),
        retrieval_code.t().to(device),
        query_targets.to(device),
        retrieval_targets.to(device),
        device,
        topk,
    )

    # PR curve
    Precision, R = pr_curve(
        query_code.t().to(device),
        retrieval_code.t().to(device),
        query_targets.to(device),
        retrieval_targets.to(device),
        device,
    )

    # Save checkpoint
    checkpoint = {
        'tB': B,
        # B was learned on the permuted samples, so the stored targets are
        # permuted the same way to stay aligned with the columns of 'tB'.
        'tL': train_targets[perm_index],
        'qB': query_code,
        'qL': query_targets,
        'rB': retrieval_code,
        'rL': retrieval_targets,
        'anchor': anchor,
        'projection': P,
        'P': Precision,
        'R': R,
        'map': mAP,
    }

    return checkpoint
Пример #3
0
def train(
    train_dataloader,
    query_dataloader,
    retrieval_dataloader,
    arch,
    code_length,
    device,
    lr,
    max_iter,
    alpha,
    topk,
    evaluate_interval,
):
    """
    Training model.

    Trains with a pairwise similarity loss and, every evaluate_interval
    iterations, evaluates retrieval quality and keeps the best result.

    Args
        train_dataloader, query_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        arch(str): CNN model name.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        alpha(float): Hyper-parameters.
        topk(int): Compute top k map.
        evaluate_interval(int): Interval of evaluation.

    Returns
        checkpoint(dict | None): Checkpoint of the evaluation with the best
            mAP, or None when no evaluation ran or none scored above 0.
    """
    import copy

    # Load model
    model = load_model(arch, code_length).to(device)

    # Create criterion, optimizer, scheduler
    criterion = HashNetLoss(alpha)
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=5e-4,
    )
    scheduler = CosineAnnealingLR(
        optimizer,
        max_iter,
        lr / 100,
    )

    # Initialization
    running_loss = 0.
    best_map = 0.
    training_time = 0.
    # Defined up front so the final return is valid even if no evaluation
    # ever beats best_map.
    checkpoint = None

    # Training
    # In this implementation, I do not use "scaled tanh".
    # It is useless and hard to tune parameters, sometimes it may decrease performance.
    # Refer to https://github.com/thuml/HashNet/issues/29
    for it in range(max_iter):
        tic = time.time()
        for data, targets, index in train_dataloader:
            data, targets, index = data.to(device), targets.to(
                device), index.to(device)
            optimizer.zero_grad()

            # Create similarity matrix: 1 where two samples share a label.
            S = (targets @ targets.t() > 0).float()
            outputs = model(data)
            loss = criterion(outputs, S)

            running_loss += loss.item()
            loss.backward()
            optimizer.step()
        scheduler.step()
        training_time += time.time() - tic

        # Evaluate on the last iteration of every evaluate_interval window.
        if it % evaluate_interval == evaluate_interval - 1:
            # Generate hash code
            query_code = generate_code(model, query_dataloader, code_length,
                                       device)
            retrieval_code = generate_code(model, retrieval_dataloader,
                                           code_length, device)

            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets(
            )

            # Compute map
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )

            # Compute pr curve
            P, R = pr_curve(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
            )

            # Log
            logger.info(
                '[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                    it + 1,
                    max_iter,
                    running_loss / evaluate_interval,
                    mAP,
                    training_time,
                ))
            running_loss = 0.

            # Checkpoint
            if best_map < mAP:
                best_map = mAP

                checkpoint = {
                    # state_dict() holds references to the live parameters;
                    # deep-copy so later optimizer steps cannot overwrite the
                    # weights that produced this mAP.
                    'model': copy.deepcopy(model.state_dict()),
                    'qB': query_code.cpu(),
                    'rB': retrieval_code.cpu(),
                    'qL': query_targets.cpu(),
                    'rL': retrieval_targets.cpu(),
                    'P': P,
                    'R': R,
                    'map': best_map,
                }

    return checkpoint
Пример #4
0
def train(train_dataloader, query_dataloader, retrieval_dataloader, arch, code_length, device, lr,
          max_iter, topk, evaluate_interval, anchor_num, proportion
          ):
    """
    Training model with an ADMM-style update of the target codes B followed
    by a pass of network training in every iteration.

    Args
        train_dataloader, query_dataloader, retrieval_dataloader: Data loaders.
            The train loader yields (data, targets, index) triples.
        arch(str): Model name passed to load_model.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        topk(int): Compute top k map.
        evaluate_interval(int): Interval of evaluation.
        anchor_num, proportion: Not used by this function.

    Returns
        checkpoint(dict | None): Checkpoint of the evaluation with the best
            mAP, or None when no evaluation ran or none scored above 0.
    """
    import copy

    rho1 = 1e-2  # ρ1
    rho2 = 1e-2  # ρ2
    rho3 = 1e-3  # µ1 (not used in this function)
    rho4 = 1e-3  # µ2 (not used in this function)
    gamma = 1e-3  # γ

    # n is the number of training samples (rows of B, Y1, Y2). Counting the
    # batches avoids concatenating the whole training set just to read its size.
    n = 0
    for data, _, _ in train_dataloader:
        n += data.size(0)
    Y1 = torch.rand(n, code_length).to(device)
    Y2 = torch.rand(n, code_length).to(device)
    B = torch.rand(n, code_length).to(device)

    # Load model
    model = load_model(arch, code_length).to(device)
    # Create criterion, optimizer, scheduler
    criterion = PrototypicalLoss()
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=5e-4,
    )
    scheduler = CosineAnnealingLR(
        optimizer,
        max_iter,
        lr / 100,
    )

    # Initialization
    running_loss = 0.
    best_map = 0.
    training_time = 0.
    # Defined up front so the final return is valid even if no evaluation
    # ever beats best_map.
    checkpoint = None

    # Training
    for it in range(max_iter):
        # timer
        tic = time.time()

        # ADMM step
        with torch.no_grad():
            output_mo = torch.tensor([]).to(device)

            for data, _, _ in train_dataloader:
                # The model lives on `device`, so the batch must as well.
                data = data.to(device)
                output_B, output_A = model(data)
                output_mo = torch.cat((output_mo, output_A), 0)
                torch.cuda.empty_cache()

            # NOTE(review): rows of output_mo follow the loader's iteration
            # order; confirm that matches the row order assumed for B.
            dist = euclidean_dist(output_mo, output_mo)
            dist = torch.exp(-1 * dist / torch.max(dist)).to(device)
            # Rescale the affinities linearly so their spread is 2.
            A = (2 / (torch.max(dist) - torch.min(dist))) * dist - 1

            # NOTE(review): the global_* names below are plain locals (there
            # is no `global` statement) and are not read again in this
            # function. If Baim_func expects them as module-level state it
            # will not see these values - confirm against Baim_func.
            global_A = A.cpu().numpy()

            # Project onto the box [-1, 1].
            Z1 = B + 1 / rho1 * Y1
            Z1[Z1 > 1] = 1
            Z1[Z1 < -1] = -1

            # Rescale to Frobenius norm sqrt(n * code_length). torch.sqrt only
            # accepts tensors, so the scalar root is taken with ** 0.5.
            Z2 = B + 1 / rho2 * Y2
            norm_B = torch.norm(Z2)
            Z2 = (n * code_length) ** 0.5 * Z2 / norm_B

            # Dual variable updates.
            Y1 = Y1 + gamma * rho1 * (B - Z1)
            Y2 = Y2 + gamma * rho2 * (B - Z2)
            global_Z1 = Z1.cpu().numpy()
            global_Z2 = Z2.cpu().numpy()
            global_Y1 = Y1.cpu().numpy()
            global_Y2 = Y2.cpu().numpy()

            B0 = B.cpu().numpy()
            # fmin_l_bfgs_b returns (x, f, info); only the minimiser x is
            # needed. x comes back flattened, so restore B's shape and dtype.
            # NOTE(review): without fprime/approx_grad, Baim_func must return
            # (value, gradient) - confirm.
            solution = scipy.optimize.fmin_l_bfgs_b(Baim_func, B0)
            B = torch.from_numpy(solution[0]).reshape(B.shape).to(
                device=device, dtype=B.dtype)

        # self-supervised deep learning
        model.train()
        for data, targets, index in train_dataloader:
            data, targets, index = data.to(device), targets.to(device), index.to(device)
            optimizer.zero_grad()

            # output_B for hash code .output_A for result without hash layer
            output_B, output_A = model(data)

            # NOTE(review): B has one row per training sample while output_B
            # has one row per batch sample; confirm whether the loss expects
            # the full B here or a batch slice of it.
            loss = criterion(output_B, B)

            running_loss += loss.item()
            loss.backward()

            optimizer.step()

        scheduler.step()
        training_time += time.time() - tic

        # Evaluate on the last iteration of every evaluate_interval window.
        if it % evaluate_interval == evaluate_interval - 1:
            # Generate hash code
            query_code = generate_code(model, query_dataloader, code_length, device)
            retrieval_code = generate_code(model, retrieval_dataloader, code_length, device)

            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets()

            # Compute map
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )

            # Compute pr curve
            P, R = pr_curve(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
            )

            # Log
            logger.info('[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                it + 1,
                max_iter,
                running_loss / evaluate_interval,
                mAP,
                training_time,
            ))
            running_loss = 0.

            # Checkpoint
            if best_map < mAP:
                best_map = mAP

                checkpoint = {
                    # state_dict() holds references to the live parameters;
                    # deep-copy so later optimizer steps cannot overwrite the
                    # weights that produced this mAP.
                    'model': copy.deepcopy(model.state_dict()),
                    'qB': query_code.cpu(),
                    'rB': retrieval_code.cpu(),
                    'qL': query_targets.cpu(),
                    'rL': retrieval_targets.cpu(),
                    'P': P,
                    'R': R,
                    'map': best_map,
                }

    return checkpoint
Пример #5
0
def train(
    query_data,
    query_targets,
    retrieval_data,
    retrieval_targets,
    code_length,
    device,
    topk,
):
    """
    Hash the query and retrieval sets with a random Gaussian projection,
    evaluate the codes and save the results to disk.

    Args
        query_data(torch.Tensor): Query data.
        query_targets(torch.Tensor): One-hot query targets.
        retrieval_data(torch.Tensor): Retrieval data.
        retrieval_targets(torch.Tensor): One-hot retrieval targets.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        topk(int): Calculate top k data map.

    Returns
        checkpoint(dict): Codes, targets, projection matrix, P-R curve and
            mAP; the same dict is also written under 'checkpoints/'.
    """
    # Put every input on the evaluation device.
    query_data = query_data.to(device)
    retrieval_data = retrieval_data.to(device)
    query_targets = query_targets.to(device)
    retrieval_targets = retrieval_targets.to(device)

    # Random projection: one Gaussian column per hash bit.
    feature_dim = query_data.shape[1]
    W = torch.randn(feature_dim, code_length).to(device)

    # The sign of each projection is the hash bit.
    query_code = torch.sign(query_data @ W)
    retrieval_code = torch.sign(retrieval_data @ W)

    # Both metrics share the same leading arguments.
    eval_args = (
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
    )
    mAP = mean_average_precision(*eval_args, topk)
    P, R = pr_curve(*eval_args)

    # Assemble and persist the checkpoint.
    checkpoint = {
        'qB': query_code,
        'rB': retrieval_code,
        'qL': query_targets,
        'rL': retrieval_targets,
        'W': W,
        'P': P,
        'R': R,
        'map': mAP,
    }
    save_path = 'checkpoints/code_{}_map_{:.4f}.pt'.format(code_length, mAP)
    torch.save(checkpoint, save_path)

    return checkpoint
def train(train_dataloader, query_dataloader, retrieval_dataloader, arch, code_length, device, lr,
          max_iter, topk, evaluate_interval, anchor_num, proportion
          ):
    """
    Training model with anchor-based pseudo-similarities produced by a
    momentum-updated copy of the network.

    Every iteration: (1) compute anchors from the momentum model's outputs on
    the training set, (2) train the main model against a batch similarity
    matrix derived from sample-to-anchor distances, (3) blend the main
    model's parameters into the momentum model.

    Args
        train_dataloader, query_dataloader, retrieval_dataloader: Data loaders.
            The train loader yields (data, targets, index) triples.
        arch(str): Model name passed to load_model / load_model_mo.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        topk(int): Compute top k map.
        evaluate_interval(int): Interval of evaluation.
        anchor_num(int): Number of anchors passed to get_anchor.
        proportion(float): Weight kept from the momentum model's old
            parameters at each momentum update.

    Returns
        checkpoint(dict | None): Checkpoint of the evaluation with the best
            mAP, or None when no evaluation ran or none scored above 0.
    """
    import copy

    # Load model
    model = load_model(arch, code_length).to(device)
    model_mo = load_model_mo(arch).to(device)

    # Create criterion, optimizer, scheduler
    criterion = PrototypicalLoss()
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=5e-4,
    )
    scheduler = CosineAnnealingLR(
        optimizer,
        max_iter,
        lr / 100,
    )

    # Initialization
    running_loss = 0.
    best_map = 0.
    training_time = 0.
    # Defined up front so the final return is valid even if no evaluation
    # ever beats best_map.
    checkpoint = None

    # Training
    for it in range(max_iter):
        # timer
        tic = time.time()

        # harvest prototypes/anchors
        with torch.no_grad():
            # Collect per-batch outputs and concatenate once, rather than
            # re-copying the growing tensor on every batch.
            momentum_chunks = []
            for data, _, _ in train_dataloader:
                data = data.to(device)
                momentum_chunks.append(model_mo(data))
                torch.cuda.empty_cache()
            if momentum_chunks:
                output_mo = torch.cat(momentum_chunks, 0)
            else:
                output_mo = torch.tensor([]).to(device)

            anchor = get_anchor(output_mo, anchor_num, device)  # compute anchor

        # self-supervised deep learning
        model.train()
        for data, targets, index in train_dataloader:
            data, targets, index = data.to(device), targets.to(device), index.to(device)
            optimizer.zero_grad()

            # output
            output_B = model(data)

            # prototypes/anchors based similarity
            # The momentum model only provides targets, so its forward pass
            # runs without gradient tracking as well.
            with torch.no_grad():
                output_mo_batch = model_mo(data)
                # Squared distance from every batch sample to every anchor.
                dist = torch.sum((output_mo_batch[:, None, :] - anchor.to(device)) ** 2, dim=2)
                k = dist.size(1)
                dist = torch.exp(-1 * dist / torch.max(dist)).to(device)
                # Divide each row by the square root of its sum (plus a small
                # constant to avoid division by zero).
                Z_su = torch.ones(k, 1).to(device)
                Z_sum = torch.sqrt(dist.mm(Z_su)) + 1e-12
                Z_simi = torch.div(dist, Z_sum).to(device)
                S = (Z_simi.mm(Z_simi.t()))
                # Rescale the similarities linearly so their spread is 2.
                S = (2 / (torch.max(S) - torch.min(S))) * S - 1

            loss = criterion(output_B, S)

            running_loss += loss.item()
            loss.backward()

            optimizer.step()

        with torch.no_grad():
            # momentum update:
            for param_q, param_k in zip(model.parameters(), model_mo.parameters()):
                param_k.data = param_k.data * proportion + param_q.data * (1. - proportion)  # proportion = 0.999 for update

        scheduler.step()
        training_time += time.time() - tic

        # Evaluate on the last iteration of every evaluate_interval window.
        if it % evaluate_interval == evaluate_interval - 1:
            # Generate hash code
            query_code = generate_code(model, query_dataloader, code_length, device)
            retrieval_code = generate_code(model, retrieval_dataloader, code_length, device)

            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets()

            # Compute map
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )

            # Compute pr curve
            P, R = pr_curve(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
            )

            # Log
            logger.info('[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                it + 1,
                max_iter,
                running_loss / evaluate_interval,
                mAP,
                training_time,
            ))
            running_loss = 0.

            # Checkpoint
            if best_map < mAP:
                best_map = mAP

                checkpoint = {
                    # state_dict() holds references to the live parameters;
                    # deep-copy so later optimizer steps cannot overwrite the
                    # weights that produced this mAP.
                    'model': copy.deepcopy(model.state_dict()),
                    'qB': query_code.cpu(),
                    'rB': retrieval_code.cpu(),
                    'qL': query_targets.cpu(),
                    'rL': retrieval_targets.cpu(),
                    'P': P,
                    'R': R,
                    'map': best_map,
                }

    return checkpoint
Пример #7
0
def train(
    train_data,
    query_data,
    query_targets,
    retrieval_data,
    retrieval_targets,
    code_length,
    device,
    topk,
):
    """
    Training model.

    Fits PCA on the training data, bounds each principal direction with a
    box, selects code_length low-frequency modes over that box, and
    evaluates the resulting query/retrieval codes.

    Args
        train_data(torch.Tensor): Training data.
        query_data(torch.Tensor): Query data.
        query_targets(torch.Tensor): Query targets.
        retrieval_data(torch.Tensor): Retrieval data.
        retrieval_targets(torch.Tensor): Retrieval targets.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        topk(int): Calculate top k data points map.

    Returns
        checkpoint(dict): Checkpoint.
    """
    # PCA
    # .cpu() keeps the NumPy conversion valid when train_data is not on the CPU.
    pca = PCA(n_components=code_length)
    X = pca.fit_transform(train_data.cpu().numpy())

    # Fit uniform distribution: per-dimension bounds, widened by machine
    # epsilon on each side.
    eps = np.finfo(float).eps
    mn = X.min(0) - eps
    mx = X.max(0) + eps

    # Enumerate eigenfunctions
    # R is the extent of each dimension; wider dimensions get more candidate
    # modes. The builtin int replaces the removed np.int alias.
    R = mx - mn
    max_mode = np.ceil((code_length + 1) * R / R.max()).astype(int)
    n_modes = max_mode.sum() - len(max_mode) + 1
    modes = np.ones([n_modes, code_length])
    m = 0
    for i in range(code_length):
        # Rows m+1 .. m+max_mode[i]-1 vary only in dimension i (values 2..max_mode[i]).
        modes[m + 1:m + max_mode[i], i] = np.arange(1, max_mode[i]) + 1
        m = m + max_mode[i] - 1

    # Shift to zero-based mode numbers; row 0 becomes the all-zero mode.
    modes -= 1
    omega0 = np.pi / R
    omegas = modes * omega0.reshape(1, -1).repeat(n_modes, 0)
    eig_val = -(omegas**2).sum(1)
    # Sort by increasing squared frequency, skip the all-zero mode at the
    # front, and keep the next code_length modes.
    ii = (-eig_val).argsort()
    modes = modes[ii[1:code_length + 1], :]

    # Evaluate
    # Generate query code and retrieval code
    query_code = generate_code(query_data.cpu(), code_length, pca, mn, R,
                               modes).to(device)
    retrieval_code = generate_code(retrieval_data.cpu(), code_length, pca, mn,
                                   R, modes).to(device)
    query_targets = query_targets.to(device)
    retrieval_targets = retrieval_targets.to(device)

    # Compute map
    mAP = mean_average_precision(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
        topk,
    )

    # P-R curve
    P, Recall = pr_curve(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
    )

    # Save checkpoint
    checkpoint = {
        'qB': query_code,
        'rB': retrieval_code,
        'qL': query_targets,
        'rL': retrieval_targets,
        'P': P,
        'R': Recall,
        'map': mAP,
    }

    return checkpoint