Example no. 1
def load_model(code_length, device):
    path = './checkpoints/adsh_nuswide_checkpoints/ADSH_NUSWIDE_48bits.pt'
    M = torch.load(path, map_location=lambda storage, loc: storage)
    B = M['rB']
    model = alexnet.load_model(code_length).to(device)
    model.load_state_dict(M['model'])
    model.eval()
    # Load the list of database image paths (imgrecord)
    root = '../datasets/NUS-WIDE'
    img_txt = 'database_img.txt'
    img_txt_path = os.path.join(root, img_txt)
    with open(img_txt_path, 'r') as f:
        imgrecord = np.array([i.strip() for i in f])
    return B, imgrecord, model
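A minimal sketch of how the returned triple might be used for retrieval follows; the random query tensor and the top-10 cutoff are illustrative assumptions, not part of the original snippet.

import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
B, imgrecord, model = load_model(48, device)  # 48 bits must match the saved checkpoint

# Hypothetical query: a stand-in tensor for one preprocessed 224x224 RGB image.
query = torch.randn(1, 3, 224, 224).to(device)

with torch.no_grad():
    query_code = model(query).sign()  # binarize the network output to +/-1

# Hamming distance between the query code and every database code in B.
code_length = B.shape[1]
hamming = 0.5 * (code_length - query_code @ B.to(device).float().t())

# Paths of the 10 closest database images.
topk_index = torch.argsort(hamming.squeeze())[:10].cpu().numpy()
print(imgrecord[topk_index])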
Example no. 2
def load_model(arch, code_length):
    """
    Load CNN model.

    Args
        arch(str): CNN model name.
        code_length(int): Hash code length.

    Returns
        model(torch.nn.Module): CNN model.
    """
    if arch == 'alexnet':
        model = alexnet.load_model(code_length)
    elif arch == 'vgg16':
        model = vgg16.load_model(code_length)
    else:
        raise ValueError('Invalid cnn model name!')

    return model
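A short usage sketch of this dispatcher (the 48-bit code length is arbitrary):

model = load_model('alexnet', 48)   # AlexNet-based hashing model
model = load_model('vgg16', 48)     # VGG16-based hashing model

try:
    load_model('resnet50', 48)      # any other name raises
except ValueError as err:
    print(err)                      # Invalid cnn model name!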
Example no. 3
def load_model(name, pretrained=True, num_classes=None):
    """加载模型

    Parameters
        name: str
        模型名称

        pretrained: bool
        True: 加载预训练模型; False: 加载未训练模型

        num_classes: int
        CNN最后一层输出类别

    Returns
        model: model
        模型
    """
    if name == 'alexnet':
        return alexnet.load_model(pretrained=pretrained,
                                  num_classes=num_classes)
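Note that only 'alexnet' is handled, so any other name falls through and returns None implicitly. A defensive call might look like the sketch below; the class count is a hypothetical value.

model = load_model('alexnet', pretrained=True, num_classes=10)  # hypothetical class count
if model is None:
    raise ValueError('Unsupported model name')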
Example no. 4
def train(
    query_dataloader,
    retrieval_dataloader,
    code_length,
    device,
    lr,
    max_iter,
    max_epoch,
    num_samples,
    batch_size,
    root,
    dataset,
    parameters,
    topk,
):
    """
    Training model.

    Args
        query_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        code_length(int): Hashing code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        max_epoch(int): Number of epochs.
        num_samples(int): Number of sampled training data points.
        batch_size(int): Batch size.
        root(str): Path of dataset.
        dataset(str): Dataset name.
        parameters(dict): Hyper-parameters eta, mu, gamma, varphi.
        topk(int): Calculate mAP over the top k retrieved results.

    Returns
        mAP(float): Mean Average Precision.
    """

    # parameters = {'eta':2, 'mu':0.2, 'gamma':1, 'varphi':200}
    eta = parameters['eta']
    mu = parameters['mu']
    gamma = parameters['gamma']
    varphi = parameters['varphi']

    # Initialization
    model = alexnet.load_model(code_length).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    criterion = DADSH2_Loss(code_length, eta, mu, gamma, varphi, device)

    num_retrieval = len(retrieval_dataloader.dataset)
    U = torch.zeros(num_samples,
                    code_length).to(device)  # U (m*l, l:code_length)
    B = torch.randn(num_retrieval,
                    code_length).to(device)  # V (n*l, l:code_length)
    retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets().to(
        device)  # Y2 (n*c, c:classes)

    start = time.time()
    for it in range(max_iter):
        iter_start = time.time()
        # Sample training data for cnn learning
        train_dataloader, sample_index = sample_dataloader(
            retrieval_dataloader, num_samples, batch_size, root, dataset)

        # Create Similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(
            device)  # Y1 (m*c, c:classes)
        S = (train_targets @ retrieval_targets.t() >
             0).float()  # S (m*n)
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soft similarity matrix, benefit to converge
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(max_epoch):
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(device), targets.to(
                    device), index.to(device)
                optimizer.zero_grad()

                F = model(data)  # output (m1*l)
                U[index, :] = F.data
                cnn_loss = criterion(F, B, S[index, :], sample_index[index])

                cnn_loss.backward()
                optimizer.step()

        # Update B
        expand_U = torch.zeros(B.shape).to(device)
        expand_U[sample_index, :] = U
        B = solve_dcc(B, U, expand_U, S, code_length, varphi)

        # Total loss
        # iter_loss = calc_loss(U, B, S, code_length, sample_index, gamma)

        # logger.debug('[iter:{}/{}][loss:{:.2f}][iter_time:{:.2f}]'.format(it+1, max_iter, iter_loss, time.time()-iter_start))
    logger.info('[Training time:{:.2f}]'.format(time.time() - start))

    # Evaluate
    query_code = generate_code(model, query_dataloader, code_length, device)
    mAP = evaluate.mean_average_precision(
        query_code.to(device),
        B,
        query_dataloader.dataset.get_onehot_targets().to(device),
        retrieval_targets,
        device,
        topk,
    )

    # Save checkpoints
    torch.save(query_code.cpu(), os.path.join('checkpoints', 'query_code.t'))
    torch.save(B.cpu(), os.path.join('checkpoints', 'database_code.t'))
    torch.save(query_dataloader.dataset.get_onehot_targets().cpu(),
               os.path.join('checkpoints', 'query_targets.t'))
    torch.save(retrieval_targets.cpu(),
               os.path.join('checkpoints', 'database_targets.t'))
    torch.save(model.cpu(), os.path.join('checkpoints', 'model.t'))

    return mAP
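To make the "soft similarity" step used above concrete, here is a self-contained sketch of the same transform on a toy label matrix; the tiny tensors are invented purely for illustration.

import torch

# Toy one-hot labels: 3 sampled points, 4 database points, 2 classes.
train_targets = torch.tensor([[1., 0.], [0., 1.], [1., 0.]])
retrieval_targets = torch.tensor([[1., 0.], [0., 1.], [1., 1.], [0., 1.]])

# Hard similarity: +1 if a pair shares at least one label, -1 otherwise.
S = (train_targets @ retrieval_targets.t() > 0).float()
S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

# Soften: r measures the imbalance between similar and dissimilar pairs;
# +1 entries stay +1, while -1 entries become -(1 + 2r).
r = S.sum() / (1 - S).sum()
S_soft = S * (1 + r) - r
print(S)
print(S_soft)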
Example no. 5
def increment(
        query_dataloader,
        unseen_dataloader,
        retrieval_dataloader,
        old_B,
        code_length,
        device,
        lr,
        max_iter,
        max_epoch,
        num_samples,
        batch_size,
        root,
        dataset,
        gamma,
        mu,
        topk,
):
    """
    Increment model.

    Args
        query_dataloader, unseen_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        old_B(torch.Tensor): Old binary hash code.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        max_epoch(int): Number of epochs.
        num_samples(int): Number of sampled training data points.
        batch_size(int): Batch size.
        root(str): Path of dataset.
        dataset(str): Dataset name.
        gamma, mu(float): Hyper-parameters.
        topk(int): Calculate mAP over the top k retrieved results.

    Returns
        mAP(float): Mean Average Precision.
    """
    # Initialization
    model = alexnet.load_model(code_length)
    model.to(device)
    model.train()
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    criterion = DIHN_Loss(code_length, gamma, mu)
    lr_scheduler = ExponentialLR(optimizer, 0.91)

    num_unseen = len(unseen_dataloader.dataset)
    num_seen = len(old_B)
    U = torch.zeros(num_samples, code_length).to(device)
    old_B = old_B.to(device)
    new_B = torch.randn(num_unseen, code_length).sign().to(device)
    B = torch.cat((old_B, new_B), dim=0).to(device)
    retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets().to(device)

    total_time = time.time()
    for it in range(max_iter):
        iter_time = time.time()
        lr_scheduler.step()

        # Sample training data for cnn learning
        train_dataloader, sample_index, unseen_sample_in_unseen_index, unseen_sample_in_sample_index = sample_dataloader(retrieval_dataloader, num_samples, batch_size, root, dataset)

        # Create Similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(device)
        S = (train_targets @ retrieval_targets.t() > 0).float()
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soft similarity matrix, benefit to converge
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(max_epoch):
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(device), targets.to(device), index.to(device)
                optimizer.zero_grad()

                F = model(data)
                U[index, :] = F.data
                cnn_loss = criterion(F, B, S[index, :], index)

                cnn_loss.backward()
                optimizer.step()

        # Update B
        expand_U = torch.zeros(num_unseen, code_length).to(device)
        expand_U[unseen_sample_in_unseen_index, :] = U[unseen_sample_in_sample_index, :]
        new_B = solve_dcc(new_B, U, expand_U, S[:, unseen_dataloader.dataset.UNSEEN_INDEX], code_length, gamma)
        B = torch.cat((old_B, new_B), dim=0).to(device)

        # Total loss
        iter_loss = calc_loss(U, B, S, code_length, sample_index, gamma, mu)
        logger.debug('[iter:{}/{}][loss:{:.2f}][time:{:.2f}]'.format(it + 1, max_iter, iter_loss, time.time() - iter_time))

    logger.info('[DIHN time:{:.2f}]'.format(time.time() - total_time))

    # Evaluate
    query_code = generate_code(model, query_dataloader, code_length, device)
    mAP = evaluate.mean_average_precision(
        query_code.to(device),
        B,
        query_dataloader.dataset.get_onehot_targets().to(device),
        retrieval_targets,
        device,
        topk,
    )

    return mAP
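The incremental bookkeeping above amounts to keeping the seen database codes fixed and appending randomly initialized +/-1 codes for the unseen points; a minimal sketch of that initialization, with invented sizes, is:

import torch

code_length = 12
num_seen, num_unseen = 5, 3

old_B = torch.randn(num_seen, code_length).sign()    # stand-in for codes learned on the seen data
new_B = torch.randn(num_unseen, code_length).sign()  # random +/-1 start for the unseen data
B = torch.cat((old_B, new_B), dim=0)                 # full database code matrix

assert B.shape == (num_seen + num_unseen, code_length)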
Example no. 6
def train(
        query_dataloader,
        seen_dataloader,
        retrieval_dataloader,
        code_length,
        device,
        lr,
        max_iter,
        max_epoch,
        num_samples,
        batch_size,
        root,
        dataset,
        gamma,
        topk,
):
    """
    Training model.

    Args
        query_dataloader, seen_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        code_length(int): Hashing code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        max_epoch(int): Number of epochs.
        num_samples(int): Number of sampled training data points.
        batch_size(int): Batch size.
        root(str): Path of dataset.
        dataset(str): Dataset name.
        gamma(float): Hyper-parameter.
        topk(int): Calculate mAP over the top k retrieved results.

    Returns
        None
    """
    # Initialization
    model = alexnet.load_model(code_length).to(device)
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    criterion = ADSH_Loss(code_length, gamma)
    lr_scheduler = ExponentialLR(optimizer, 0.9)

    num_seen = len(seen_dataloader.dataset)
    U = torch.zeros(num_samples, code_length).to(device)
    B = torch.randn(num_seen, code_length).sign().to(device)
    seen_targets = seen_dataloader.dataset.get_onehot_targets().to(device)

    total_time = time.time()
    for it in range(max_iter):
        iter_time = time.time()
        lr_scheduler.step()

        # Sample training data for cnn learning
        train_dataloader, sample_index, _, _ = sample_dataloader(seen_dataloader, num_samples, batch_size, root, dataset)

        # Create Similarity matrix
        train_targets = train_dataloader.dataset.get_onehot_targets().to(device)
        S = (train_targets @ seen_targets.t() > 0).float()
        S = torch.where(S == 1, torch.full_like(S, 1), torch.full_like(S, -1))

        # Soft similarity matrix, benefit to converge
        r = S.sum() / (1 - S).sum()
        S = S * (1 + r) - r

        # Training CNN model
        for epoch in range(max_epoch):
            for batch, (data, targets, index) in enumerate(train_dataloader):
                data, targets, index = data.to(device), targets.to(device), index.to(device)
                optimizer.zero_grad()

                F = model(data)
                U[index, :] = F.data
                cnn_loss = criterion(F, B, S[index, :], index)

                cnn_loss.backward()
                optimizer.step()

        # Update B
        expand_U = torch.zeros(B.shape).to(device)
        expand_U[sample_index, :] = U
        B = solve_dcc(B, U, expand_U, S, code_length, gamma)

        # Total loss
        iter_loss = calc_loss(U, B, S, code_length, sample_index, gamma)
        logger.debug('[iter:{}/{}][loss:{:.2f}][time:{:.2f}]'.format(it + 1, max_iter, iter_loss, time.time() - iter_time))

    logger.info('Training ADSH finished, time:{:.2f}'.format(time.time()-total_time))

    # Save checkpoints
    torch.save(B.cpu(), os.path.join('checkpoints', 'old_B.t'))

    # Evaluate
    query_code = generate_code(model, query_dataloader, code_length, device)
    mAP = evaluate.mean_average_precision(
        query_code.to(device),
        B,
        query_dataloader.dataset.get_onehot_targets().to(device),
        retrieval_dataloader.dataset.get_onehot_targets().to(device),
        device,
        topk,
    )
    logger.info('[ADSH map:{:.4f}]'.format(mAP))
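generate_code is called in the examples above but not shown. A plausible implementation consistent with those call sites (model, dataloader, code length and device in, a num_query x code_length sign matrix out) could be the sketch below; it is an assumption about the helper, not the repository's actual code.

import torch

def generate_code(model, dataloader, code_length, device):
    """Forward every image through the model and binarize the outputs to +/-1."""
    model.eval()
    with torch.no_grad():
        code = torch.zeros(len(dataloader.dataset), code_length).to(device)
        for data, _, index in dataloader:
            data = data.to(device)
            code[index, :] = model(data).sign()
    model.train()
    return code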