import torch


def pretrain(enc, clf, src_loader, params):
    """Pretrain the encoder and classifier on source data with cross-entropy.

    enc is the encoder, clf is the classifier.
    """
    opt_enc = torch.optim.Adam(enc.parameters(), lr=params.learning_rate)
    opt_clf = torch.optim.Adam(clf.parameters(), lr=params.learning_rate)
    criterion = torch.nn.CrossEntropyLoss()
    loss_records = {"clf": []}
    for i in range(params.iterations):
        acc_loss = {key: 0.0 for key in loss_records}
        for step, (images_src, labels_src) in enumerate(src_loader):
            print("epoch {} / batch {}".format(i, step))
            images_src = make_variable(images_src)
            labels = make_variable(labels_src.long(), requires_grad=False)

            # forward: encode the source images, then classify
            g_src = enc(images_src)
            clf_src_out = clf(g_src)

            # cross-entropy loss against the source labels
            clf_loss = criterion(clf_src_out, labels)
            acc_loss["clf"] += clf_loss.item()

            # backward and update both the encoder and the classifier
            opt_enc.zero_grad()
            opt_clf.zero_grad()
            clf_loss.backward()
            opt_enc.step()
            opt_clf.step()

        # record the average loss over the epoch
        for key in loss_records:
            loss_records[key].append(acc_loss[key] / (step + 1))

    models = {"enc": enc, "clf": clf}
    return {"models": models, "loss_records": loss_records}
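# `make_variable` is the repo's helper and is not shown in this listing. A
# minimal stand-in, assuming it only has to move tensors to the GPU when one
# is available (on modern PyTorch the old Variable wrapper is unnecessary,
# so `requires_grad` is accepted purely for call-site compatibility):
def make_variable(tensor, requires_grad=False):
    """Move a tensor to the GPU if one is available."""
    if torch.cuda.is_available():
        tensor = tensor.cuda()
    return tensor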
def compute_dcd_confusion_loss(outputs, labels):
    """Confusion loss: make the DCD unable to tell G2 from G1, and G4 from G3."""
    G2_outputs = []  # DCD outputs whose true pair label is 1 (G2)
    G4_outputs = []  # DCD outputs whose true pair label is 3 (G4)

    # collect the DCD outputs of the G2 and G4 pairs
    for i, label in enumerate(labels):
        if label == 1:
            G2_outputs.append(outputs[i])
        elif label == 3:
            G4_outputs.append(outputs[i])
    G2_outputs = torch.stack(G2_outputs)
    G4_outputs = torch.stack(G4_outputs)

    # cross-entropy against the *other* group's label: G2 -> G1 (0), G4 -> G3 (2)
    criterion = torch.nn.CrossEntropyLoss()
    G1_labels = torch.LongTensor([0] * len(G2_outputs))
    G3_labels = torch.LongTensor([2] * len(G4_outputs))
    return 0.5 * criterion(G2_outputs, make_variable(G1_labels, requires_grad=False)) + \
           0.5 * criterion(G4_outputs, make_variable(G3_labels, requires_grad=False))
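# Usage sketch with dummy data (assumed conventions: the DCD emits 4-way
# logits and pair labels take values 0..3 for G1..G4):
def _demo_confusion_loss():
    dummy_outputs = make_variable(torch.randn(8, 4))      # 8 pairs, 4 DCD classes
    dummy_labels = torch.tensor([0, 1, 2, 3, 1, 3, 1, 3])
    loss = compute_dcd_confusion_loss(dummy_outputs, dummy_labels)
    print(loss.item())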
def pretrain(enc, h, src_loader, params):
    """Pretrain the encoder and hash layer on source data.

    enc is the encoder, h is the hash-code generator.
    """
    opt_enc = torch.optim.Adam(enc.parameters(), lr=params.learning_rate)
    opt_h = torch.optim.Adam(h.parameters(), lr=params.learning_rate)
    loss_records = {"hash": []}
    for i in range(params.iterations):
        acc_loss = {key: 0.0 for key in loss_records}
        for step, (images_src, labels_src) in enumerate(src_loader):
            print("epoch {} / batch {}".format(i, step))
            images_src = make_variable(images_src)

            # forward: encode the source images, then produce hash codes
            g_src = enc(images_src)
            hash_src = h(g_src)

            # pairwise similarity loss on the hash codes
            hash_loss = get_pairwise_sim_loss(feats=hash_src,
                                              labels=labels_src,
                                              num_classes=params.num_classes)
            acc_loss["hash"] += hash_loss.item()

            # backward and update both the encoder and the hash layer
            opt_enc.zero_grad()
            opt_h.zero_grad()
            hash_loss.backward()
            opt_enc.step()
            opt_h.step()

        # record the average loss over the epoch
        for key in loss_records:
            loss_records[key].append(acc_loss[key] / (step + 1))

    models = {"enc": enc, "h": h}
    return {"models": models, "loss_records": loss_records}
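# `get_pairwise_sim_loss` is not defined in this listing. A common objective
# for pairwise supervised hashing (DHN-style) is the negative log-likelihood
# of pairwise similarities; the version below is a sketch under that
# assumption, not necessarily the repo's implementation (`num_classes` is
# accepted to match the call site but unused here):
def get_pairwise_sim_loss(feats, labels, num_classes):
    """Pull inner products of same-class codes up, different-class ones down."""
    labels = torch.as_tensor(labels, device=feats.device)
    sim = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()    # S_ij in {0, 1}
    inner = feats @ feats.t() / 2.0                               # <b_i, b_j> / 2
    # pairwise NLL: log(1 + exp(x)) - S_ij * x, with a numerically stable softplus
    return (torch.nn.functional.softplus(inner) - sim * inner).mean()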
def train(enc, dcd, h, src_loader, tgt_loader, params):
    """Adversarial training: alternate between updating (enc, h) and the DCD.

    enc is the encoder, dcd is the domain-class discriminator (DCD),
    h is the hash-code generator.
    """
    opt_enc = torch.optim.Adam(enc.parameters(), lr=params.learning_rate)
    opt_dcd = torch.optim.Adam(dcd.parameters(), lr=params.learning_rate)
    opt_h = torch.optim.Adam(h.parameters(), lr=params.learning_rate)
    loss_records = {"dcd_confusion": [], "hash": [], "dcd_clf": []}
    for i in range(params.iterations):
        acc_loss = {key: 0.0 for key in loss_records}
        for step, ((images_src, labels_src), (images_tgt, labels_tgt)) in \
                enumerate(zip(src_loader, tgt_loader)):
            print("epoch {} / batch {}".format(i, step))
            images_src = make_variable(images_src)
            images_tgt = make_variable(images_tgt)
            g_src = enc(images_src)
            g_tgt = enc(images_tgt)

            # prepare the G1..G4 pairs and pass them to the discriminator
            pairs, pair_labels = make_pairs(src_feats=g_src, tgt_feats=g_tgt,
                                            src_labels=labels_src,
                                            tgt_labels=labels_tgt)
            dcd_out = dcd(pairs)

            #######################
            # 1. update enc and h #
            #######################
            hash_src = h(g_src)
            hash_tgt = h(g_tgt)
            hash_loss = compute_hash_loss(hash=torch.cat([hash_src, hash_tgt]),
                                          labels=torch.cat([labels_src, labels_tgt]))
            dcd_confusion_loss = compute_dcd_confusion_loss(outputs=dcd_out,
                                                            labels=pair_labels)
            combined_loss = params.gamma * dcd_confusion_loss + hash_loss

            opt_enc.zero_grad()
            opt_h.zero_grad()
            combined_loss.backward()
            opt_enc.step()
            opt_h.step()

            # record losses
            acc_loss["dcd_confusion"] += dcd_confusion_loss.item()
            acc_loss["hash"] += hash_loss.item()

            #################
            # 2. update DCD #
            #################
            # re-run the DCD on detached pairs: the backward pass above has
            # already consumed the encoder graph, and detaching keeps this
            # update from pushing gradients back into the encoder
            dcd_clf_loss = compute_dcd_classification_loss(outputs=dcd(pairs.detach()),
                                                           labels=pair_labels)
            opt_dcd.zero_grad()   # drop the confusion-loss gradients left in the DCD
            dcd_clf_loss.backward()
            opt_dcd.step()

            # record loss
            acc_loss["dcd_clf"] += dcd_clf_loss.item()

        # record the average loss over the epoch
        for key in loss_records:
            loss_records[key].append(acc_loss[key] / (step + 1))

    models = {"enc": enc, "h": h, "dcd": dcd}
    return {"models": models, "loss_records": loss_records}
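# Two helpers used by `train` are external to this listing. In the FADA-style
# setup the discriminator sees four pair groups: G1 = two source samples of
# the same class, G2 = a source and a target sample of the same class,
# G3 = two source samples of different classes, G4 = a source and a target
# sample of different classes. The `make_pairs` sketch below exhaustively
# enumerates pairs and is an assumption about the interface (concatenated
# feature pairs, labels 0..3 for G1..G4), not the repo's exact sampling:
def make_pairs(src_feats, tgt_feats, src_labels, tgt_labels):
    """Return (pairs, pair_labels): concatenated feature pairs and their
    group index 0..3 (G1..G4)."""
    src_labels = torch.as_tensor(src_labels)
    tgt_labels = torch.as_tensor(tgt_labels)
    pairs, pair_labels = [], []
    for i in range(len(src_feats)):
        # source-source pairs -> G1 (same class) or G3 (different class)
        for j in range(len(src_feats)):
            if i == j:
                continue
            pairs.append(torch.cat([src_feats[i], src_feats[j]]))
            pair_labels.append(0 if src_labels[i] == src_labels[j] else 2)
        # source-target pairs -> G2 (same class) or G4 (different class)
        for j in range(len(tgt_feats)):
            pairs.append(torch.cat([src_feats[i], tgt_feats[j]]))
            pair_labels.append(1 if src_labels[i] == tgt_labels[j] else 3)
    return torch.stack(pairs), torch.LongTensor(pair_labels)


# `compute_hash_loss` is also not shown; judging from its call signature it
# plausibly applies the same pairwise objective as `get_pairwise_sim_loss`
# to the concatenated source and target codes — again an assumption:
def compute_hash_loss(hash, labels):   # `hash` mirrors the caller's keyword
    return get_pairwise_sim_loss(feats=hash, labels=labels, num_classes=None)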
def compute_dcd_classification_loss(outputs, labels):
    """Standard cross-entropy for the DCD's four-way pair classification."""
    criterion = torch.nn.CrossEntropyLoss()
    labels = torch.as_tensor(labels, dtype=torch.long)
    return criterion(outputs, make_variable(labels, requires_grad=False))
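# Usage sketch for the classification side, mirroring the confusion-loss
# demo above; together the two losses form the adversarial game in `train`:
def _demo_dcd_classification_loss():
    dummy_outputs = make_variable(torch.randn(8, 4))
    dummy_labels = [0, 1, 2, 3, 1, 3, 1, 3]   # a plain list is accepted
    loss = compute_dcd_classification_loss(dummy_outputs, dummy_labels)
    print(loss.item())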