import torch

def pretrain(enc, clf, src_loader, params):
    "enc is the encoder, clf is the source classifier"
    opt_enc = torch.optim.Adam(enc.parameters(), lr=params.learning_rate)
    opt_clf = torch.optim.Adam(clf.parameters(), lr=params.learning_rate)

    loss_records = {"clf":[]}

    for i in range(params.iterations):

        loader = enumerate(src_loader)
        acc_loss = {key:0 for key in loss_records}

        for step, (images_src, labels_src) in loader:

            print("epoch {}/ batch {}".format(i,step))

            images_src = make_variable(images_src)
            g_src = enc(images_src)
            clf_src_out = clf(g_src)

            # cross-entropy loss on the source predictions
            labels = make_variable(torch.LongTensor(labels_src), requires_grad=False)
            clf_loss = criterion(clf_src_out, labels)

            acc_loss["clf"] += clf_loss.cpu().data.numpy()[0]

            clf_loss.backward()
            opt_enc.step()
            opt_clf.step()
            opt_enc.zero_grad()
            opt_clf.zero_grad()

        # record average loss
        for key in loss_records.keys():
            loss_records[key].append(acc_loss[key] / (step + 1))

    models = {
        "enc": enc,
        "clf": clf
    }
    return {
        "models": models,
        "loss_records": loss_records
    }
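
# ---------------------------------------------------------------------------
# Hypothetical smoke test for pretrain (not part of the original project):
# the model shapes, hyperparameters, and random data below are made up, and
# make_variable is stubbed since the project's own helper is not in this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    def make_variable(tensor, requires_grad=False):
        # stand-in for the project's helper, which normally also moves
        # tensors to the GPU when one is available
        return torch.autograd.Variable(tensor, requires_grad=requires_grad)

    enc = nn.Sequential(nn.Linear(784, 64), nn.ReLU())
    clf = nn.Linear(64, 10)
    dataset = TensorDataset(torch.randn(128, 784),
                            (torch.rand(128) * 10).long())
    src_loader = DataLoader(dataset, batch_size=32, shuffle=True)
    params = SimpleNamespace(learning_rate=1e-3, iterations=2)

    result = pretrain(enc, clf, src_loader, params)
    print(result["loss_records"]["clf"])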
def compute_dcd_confusion_loss(outputs, labels):
    "this loss trains the encoder so the DCD cannot tell G1 from G2, or G3 from G4"
    G2_outputs = []  # pairs with group label 1
    G4_outputs = []  # pairs with group label 3

    # collect DCD outputs of the cross-domain pairs (G2 and G4); assumes the
    # batch contains at least one pair of each group
    for i, label in enumerate(labels):
        if label == 1:
            G2_outputs.append(outputs[i])
        elif label == 3:
            G4_outputs.append(outputs[i])

    G2_outputs = torch.stack(G2_outputs)
    G4_outputs = torch.stack(G4_outputs)

    # cross-entropy against the within-domain group labels: G2 pairs are pushed
    # toward the G1 label (0) and G4 pairs toward the G3 label (2)
    criterion = torch.nn.CrossEntropyLoss()
    G1_labels = torch.LongTensor([0] * len(G2_outputs))
    G3_labels = torch.LongTensor([2] * len(G4_outputs))
    return 0.5 * criterion(G2_outputs, make_variable(G1_labels, requires_grad=False)) + \
           0.5 * criterion(G4_outputs, make_variable(G3_labels, requires_grad=False))
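
# ---------------------------------------------------------------------------
# make_pairs (used in train below) is a project helper that is not shown in
# this file. A minimal, hypothetical sketch of its contract, matching the
# group convention the two DCD losses assume (labels 0-3):
#   G1 (0): source-source, same class    G2 (1): source-target, same class
#   G3 (2): source-source, diff. class   G4 (3): source-target, diff. class
# ---------------------------------------------------------------------------
def make_pairs_sketch(src_feats, tgt_feats, src_labels, tgt_labels):
    pairs, pair_labels = [], []
    for i in range(len(src_feats)):
        for j in range(len(src_feats)):
            if i == j:
                continue
            pairs.append(torch.cat([src_feats[i], src_feats[j]]))
            pair_labels.append(0 if src_labels[i] == src_labels[j] else 2)
        for j in range(len(tgt_feats)):
            pairs.append(torch.cat([src_feats[i], tgt_feats[j]]))
            pair_labels.append(1 if src_labels[i] == tgt_labels[j] else 3)
    return torch.stack(pairs), pair_labels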
def pretrain(enc, h, src_loader, params):
    "enc is the encoder, h is the hash code generator"
    opt_enc = torch.optim.Adam(enc.parameters(), lr=params.learning_rate)
    opt_h = torch.optim.Adam(h.parameters(), lr=params.learning_rate)

    loss_records = {"hash":[]}

    for i in range(params.iterations):

        loader = enumerate(src_loader)
        acc_loss = {key:0 for key in loss_records}

        for step, (images_src, labels_src) in loader:

            print("epoch {}/ batch {}".format(i,step))

            images_src = make_variable(images_src)
            g_src = enc(images_src)
            hash_src = h(g_src)
            hash_loss = get_pairwise_sim_loss(feats=hash_src,
                                              labels=labels_src,
                                              num_classes=params.num_classes)

            acc_loss["hash"] += hash_loss.cpu().data.numpy()[0]

            hash_loss.backward()
            opt_enc.step()
            opt_h.step()
            opt_enc.zero_grad()
            opt_h.zero_grad()

        # record average loss
        for key in loss_records.keys():
            loss_records[key].append(acc_loss[key] / (step + 1))

    models = {
        "enc": enc,
        "h": h
    }
    return {
        "models": models,
        "loss_records": loss_records
    }
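
# ---------------------------------------------------------------------------
# get_pairwise_sim_loss (used above) and compute_hash_loss (used in train
# below) are project helpers that are not defined in this file. A minimal,
# hypothetical sketch of one common formulation: push inner products of hash
# codes toward +1 for same-class pairs and -1 for different-class pairs.
# num_classes is kept only for signature compatibility here.
# ---------------------------------------------------------------------------
def get_pairwise_sim_loss_sketch(feats, labels, num_classes=None):
    labels = torch.as_tensor(labels)
    sim = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()  # 1 = same class
    inner = feats.mm(feats.t()) / feats.size(1)  # normalized inner products
    return ((inner - (2 * sim - 1)) ** 2).mean()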
def train(enc, dcd, h, src_loader, tgt_loader, params):
    "enc is the encoder, dcd is the domain-class discriminator (DCD), h is the hash code generator"
    opt_enc = torch.optim.Adam(enc.parameters(), lr=params.learning_rate)
    opt_dcd = torch.optim.Adam(dcd.parameters(), lr=params.learning_rate)
    opt_h = torch.optim.Adam(h.parameters(), lr=params.learning_rate)

    loss_records = {"dcd_confusion": [], "hash": [], "dcd_clf": []}

    for i in range(params.iterations):

        loader = enumerate(zip(src_loader, tgt_loader))
        acc_loss = {key: 0 for key in loss_records}

        for step, ((images_src, labels_src),
                   (images_tgt, labels_tgt)) in loader:

            print("epoch {} / batch {}".format(i, step))

            images_src = make_variable(images_src)
            images_tgt = make_variable(images_tgt)

            g_src = enc(images_src)
            g_tgt = enc(images_tgt)

            # prepare G1, G2, G3, G4 pairs, and pass to discriminator
            pairs, pair_labels = make_pairs(src_feats=g_src,
                                            tgt_feats=g_tgt,
                                            src_labels=labels_src,
                                            tgt_labels=labels_tgt)
            dcd_out = dcd(pairs)
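            # dcd_out is used twice: first to compute the confusion loss that
            # updates enc and h, then to compute the DCD's own classification
            # loss, so the graph must be retained between the two backward calls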

            #######################
            # 1. update enc and h #
            #######################
            hash_src = h(g_src)
            hash_tgt = h(g_tgt)
            hash_loss = compute_hash_loss(
                hash=torch.cat([hash_src, hash_tgt]),
                labels=torch.cat([labels_src, labels_tgt]))
            dcd_confusion_loss = compute_dcd_confusion_loss(outputs=dcd_out,
                                                            labels=pair_labels)
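            # params.gamma balances domain confusion against hash code quality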
            combined_loss = params.gamma * dcd_confusion_loss + hash_loss

            # retain the graph: dcd_out is reused for the DCD update below
            combined_loss.backward(retain_graph=True)
            opt_enc.step()
            opt_h.step()
            opt_enc.zero_grad()
            opt_h.zero_grad()
            opt_dcd.zero_grad()

            # record loss
            acc_loss["dcd_confusion"] += dcd_confusion_loss.cpu().data.numpy(
            )[0]
            acc_loss["hash"] += hash_loss.cpu().data.numpy()[0]

            ##################
            # 2. update DCD  #
            ##################
            dcd_clf_loss = compute_dcd_classification_loss(outputs=dcd_out,
                                                           labels=pair_labels)
            dcd_clf_loss.backward()
            opt_dcd.step()
            opt_dcd.zero_grad()
            opt_enc.zero_grad()

            # record loss
            acc_loss["dcd_clf"] += dcd_clf_loss.cpu().data.numpy()[0]

        # record average loss
        for key in loss_records.keys():
            loss_records[key].append(acc_loss[key] / (step + 1))

    models = {"enc": enc, "h": h, "dcd": dcd}
    return {"models": models, "loss_records": loss_records}
def compute_dcd_classification_loss(outputs, labels):
    "standard cross-entropy loss for training the DCD on the G1-G4 pair labels"
    criterion = torch.nn.CrossEntropyLoss()
    labels = torch.LongTensor(labels)
    return criterion(outputs, make_variable(labels, requires_grad=False))
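
# ---------------------------------------------------------------------------
# make_variable is called throughout but never defined in this file; a
# minimal sketch consistent with how it is used above (an ADDA-style helper
# that wraps tensors for autograd and moves them to the GPU when available):
# ---------------------------------------------------------------------------
def make_variable(tensor, requires_grad=False):
    if torch.cuda.is_available():
        tensor = tensor.cuda()
    return torch.autograd.Variable(tensor, requires_grad=requires_grad)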