Example #1
# Imports assumed by every example on this page (the snippets use them but do
# not show them; `utils`, `net`, and the model dicts are project-local):
import numpy as np
import torch
import torch.nn.functional as F
from collections import OrderedDict
from torch import autograd
def pre_calculate(train_data, class_names, net, args):
    with torch.no_grad():
        all_classes = np.unique(train_data['label'])
        num_classes = len(all_classes)

        # Probability matrix used when sampling classes
        train_class_names = {}
        train_class_names['text'] = class_names['text'][all_classes]
        train_class_names['text_len'] = class_names['text_len'][all_classes]
        train_class_names['label'] = class_names['label'][all_classes]
        train_class_names = utils.to_tensor(train_class_names, args.cuda)
        train_class_names_ebd = net.ebd(train_class_names)  # [10, 36, 300]
        # Mean-pool the word embeddings over each class name's length
        train_class_names_ebd = torch.sum(
            train_class_names_ebd, dim=1) / train_class_names['text_len'].view(
                (-1, 1))  # [10, 300]
        dist_matrix = -neg_dist(train_class_names_ebd,
                                train_class_names_ebd)  # [10, 10]

        # Drop each row's diagonal (self-distance) entry
        for i, d in enumerate(dist_matrix):
            if i == 0:
                dist_matrix_nodiag = del_tensor_ele(d, i).view((1, -1))
            else:
                dist_matrix_nodiag = torch.cat(
                    (dist_matrix_nodiag, del_tensor_ele(d, i).view((1, -1))),
                    dim=0)

        prob_matrix = F.softmax(dist_matrix_nodiag, dim=1)  # [10, 9]
        prob_matrix = prob_matrix.cpu().numpy()

        # Probability matrix used when sampling examples within each class
        example_prob_matrix = []
        for i, label in enumerate(all_classes):
            mask = train_data['label'] == label
            train_examples = {}
            train_examples['text'] = train_data['text'][mask]
            train_examples['text_len'] = train_data['text_len'][mask]
            train_examples['label'] = train_data['label'][mask]
            train_examples = utils.to_tensor(train_examples, args.cuda)
            train_examples_ebd = net.ebd(train_examples)
            train_examples_ebd = torch.sum(
                train_examples_ebd, dim=1) / train_examples['text_len'].view(
                    (-1, 1))  # [N, 300]
            example_prob_matrix_one = -neg_dist(
                train_class_names_ebd[i].view((1, -1)), train_examples_ebd)
            example_prob_matrix_one = F.softmax(example_prob_matrix_one,
                                                dim=1)  # [1, 1000]
            example_prob_matrix_one = example_prob_matrix_one.cpu().numpy()
            example_prob_matrix.append(example_prob_matrix_one)

        return prob_matrix, example_prob_matrix
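
# --- Helper sketches (assumptions, not part of the original snippet) ---
# pre_calculate relies on neg_dist, del_tensor_ele, and utils.to_tensor, none
# of which are shown on this page. Minimal plausible implementations, assuming
# neg_dist returns the negative squared Euclidean distance matrix and
# utils.to_tensor moves a dict of numpy arrays onto the GPU:
def neg_dist(instances, prototypes):
    # instances: [N, D], prototypes: [C, D] -> [N, C]
    return -torch.pow(instances.unsqueeze(1) - prototypes.unsqueeze(0),
                      2).sum(dim=-1)

def del_tensor_ele(t, index):
    # Return a copy of the 1-D tensor `t` with the element at `index` removed.
    return torch.cat((t[:index], t[index + 1:]), dim=0)

def to_tensor(data, cuda, exclude_keys=()):
    # Convert each numpy array in `data` to a tensor, optionally on the GPU
    # (assumes `cuda` is a device flag where -1 means CPU).
    out = {}
    for k, v in data.items():
        if k in exclude_keys:
            out[k] = v
            continue
        t = torch.from_numpy(v) if isinstance(v, np.ndarray) else v
        out[k] = t.cuda() if cuda != -1 else t
    return out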
Example #2
def test_one(task, class_names, model, optCLF, args, grad):
    '''
        Evaluate the model on one sampled task.
    '''

    support, query = task
    # print("support, query:", support, query)
    # print("class_names_dict:", class_names_dict)

    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()
    # print("sampled_classes:", sampled_classes)

    class_names_dict = {}
    class_names_dict['label'] = class_names['label'][sampled_classes]
    # print("class_names_dict['label']:", class_names_dict['label'])
    class_names_dict['text'] = class_names['text'][sampled_classes]
    class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
    class_names_dict['is_support'] = False
    class_names_dict = utils.to_tensor(class_names_dict, args.cuda, exclude_keys=['is_support'])

    # Embedding the document
    XS = model['G'](support)  # XS:[N*K, 256(hidden_size*2)]
    # print("XS:", XS.shape)
    YS = support['label']
    # print('YS:', YS)

    CN = model['G'](class_names_dict)  # CN: [N, 256(hidden_size*2)]
    # print("CN:", CN.shape)

    XQ = model['G'](query)
    YQ = query['label']
    # print('YQ:', YQ)

    YS, YQ = reidx_y(args, YS, YQ)

    for _ in range(args.test_iter):

        # Embedding the document
        XS_mlp = model['clf'](XS)  # [N*K, 256(hidden_size*2)] -> [N*K, 128]

        CN_mlp = model['clf'](CN)  # [N, 256(hidden_size*2)] -> [N, 128]

        neg_d = neg_dist(XS_mlp, CN_mlp)  # [N*K, N]
        # print("neg_d:", neg_d.shape)

        mlp_loss = model['clf'].loss(neg_d, YS)
        # print("mlp_loss:", mlp_loss)

        optCLF.zero_grad()
        # retain_graph keeps the embedding graph alive, since XS and CN are
        # computed once outside the loop and reused across iterations
        mlp_loss.backward(retain_graph=True)
        optCLF.step()

    XQ_mlp = model['clf'](XQ)
    CN_mlp = model['clf'](CN)
    neg_d = neg_dist(XQ_mlp, CN_mlp)

    _, pred = torch.max(neg_d, 1)
    acc_q = model['clf'].accuracy(pred, YQ)

    return acc_q
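
# A minimal sketch of the reidx_y helper used above (an assumption: it remaps
# the sampled class ids in YS/YQ to contiguous indices 0..N-1 so they can be
# used as cross-entropy targets):
def reidx_y(args, YS, YQ):
    unique_classes = torch.unique(YS)
    mapping = {c.item(): i for i, c in enumerate(unique_classes)}
    YS_new = YS.new_tensor([mapping[y.item()] for y in YS])
    YQ_new = YQ.new_tensor([mapping[y.item()] for y in YQ])
    return YS_new, YQ_new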
Example #3
def test_one(task, class_names_dict, model, optG, optCLF, args, grad):
    '''
        Evaluate the model on one sampled task.
    '''

    support, query = task
    # print("support, query:", support, query)
    # print("class_names_dict:", class_names_dict)

    'Step 1: update G so the class-name embeddings move farther apart'
    cn_loss_all = 0
    for _ in range(5):
        CN = model['G'](class_names_dict)  # CN: [N, 256(hidden_size*2)]
        # print("CN:", CN.shape)
        dis = neg_dist(CN, CN)  # [N, N]
        # Sum of off-diagonal (negative) distances; minimizing it pushes the
        # class-name embeddings apart
        cn_loss = torch.sum(dis) - torch.sum(torch.diag(dis))
        cn_loss_all += cn_loss
        optG.zero_grad()
        cn_loss.backward(retain_graph=True)
        optG.step()
    print('***********[TEST] cn_loss:', cn_loss_all / 5)

    'Pass the class names through the fine-tuned G; pass S and Q through G2'
    # Embedding the document
    XS = model['G2'](support)  # XS:[N*K, 256(hidden_size*2)]
    # print("XS:", XS.shape)
    YS = support['label']
    # print('YS:', YS)

    CN = model['G'](class_names_dict)  # CN: [N, 256(hidden_size*2)]
    # print("CN:", CN.shape)

    XQ = model['G2'](query)
    YQ = query['label']
    # print('YQ:', YQ)

    YS, YQ = reidx_y(args, YS, YQ)

    'Step 2: update the MLP on the support set'
    for _ in range(args.test_iter):

        # Embedding the document
        XS_mlp = model['clf'](XS)  # [N*K, 256(hidden_size*2)] -> [N*K, 256]

        neg_d = neg_dist(XS_mlp, CN)  # [N*K, N]
        # print("neg_d:", neg_d.shape)

        mlp_loss = model['clf'].loss(neg_d, YS)
        # print("mlp_loss:", mlp_loss)

        optCLF.zero_grad()
        mlp_loss.backward(retain_graph=True)
        optCLF.step()

    XQ_mlp = model['clf'](XQ)
    neg_d = neg_dist(XQ_mlp, CN)

    _, pred = torch.max(neg_d, 1)
    acc_q = model['clf'].accuracy(pred, YQ)

    return acc_q
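
# A minimal evaluation-loop sketch for test_one above (assumptions: `sampler`
# yields (support, query) tasks and args.test_episodes exists; neither is
# defined in the original):
accs = []
for _ in range(args.test_episodes):
    task = next(sampler)
    accs.append(float(test_one(task, class_names_dict, model,
                               optG, optCLF, args, grad=None)))
print('mean test accuracy:', sum(accs) / len(accs))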
Example #4
def train_one(task, class_names_dict, model, optG, optG2, optCLF, args, grad):
    '''
        Train the model on one sampled task.
    '''
    model['G'].train()
    model['G2'].train()
    model['clf'].train()

    support, query = task
    # print("support, query:", support, query)
    # print("class_names_dict:", class_names_dict)

    'Step 1: update G so the class-name embeddings move farther apart'
    cn_loss_all = 0
    for _ in range(5):
        CN = model['G'](class_names_dict)  # CN: [N, 256(hidden_size*2)]
        # Negative distances, normalized by their per-column mean
        dist = neg_dist(CN, CN)
        dis = dist / torch.mean(dist, dim=0)  # [N, N]
        m = torch.tensor(2.0).cuda()
        cn_loss = torch.tensor(0.0).cuda()
        # Squared-hinge margin over all off-diagonal pairs, scaled by 1/40
        for i, row in enumerate(dis):
            for j, dd in enumerate(row):
                if i != j:
                    cn_loss = cn_loss + ((torch.max(
                        torch.tensor(0.0).cuda(), m + dd))**2) / 2 / 20
        cn_loss_all += cn_loss
        optG.zero_grad()
        cn_loss.backward()
        optG.step()
    print('***********cn_loss:', cn_loss_all / 5)
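    # A vectorized equivalent of the double loop above (a sketch, not in the
    # original): mask out the diagonal and apply the squared hinge in one shot.
    #   mask = ~torch.eye(dis.size(0), dtype=torch.bool, device=dis.device)
    #   cn_loss = (torch.clamp(m + dis[mask], min=0.0) ** 2).sum() / 40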

    'Pass the class names through the fine-tuned G; pass S and Q through G2'
    CN = model['G'](class_names_dict)  # CN:[N, 256(hidden_size*2)]
    # Embedding the document
    XS = model['G2'](support)  # XS:[N*K, 256(hidden_size*2)]
    # print("XS:", XS.shape)
    YS = support['label']
    # print('YS:', YS)

    XQ = model['G2'](query)
    YQ = query['label']
    # print('YQ:', YQ)

    YS, YQ = reidx_y(args, YS, YQ)  # remap labels to start from 0

    'Step 2: update the MLP on the support set'
    for _ in range(args.train_iter):

        # Embedding the document
        XS_mlp = model['clf'](XS)  # [N*K, 256(hidden_size*2)] -> [N*K, 256]

        neg_d = neg_dist(XS_mlp, CN)  # [N*K, N]
        # print("neg_d:", neg_d.shape)

        mlp_loss = model['clf'].loss(neg_d, YS)
        # print("mlp_loss:", mlp_loss)

        optCLF.zero_grad()
        mlp_loss.backward(retain_graph=True)
        optCLF.step()

    'Step 3: update G2 on the query set'
    XQ_mlp = model['clf'](XQ)
    neg_d = neg_dist(XQ_mlp, CN)
    q_loss = model['clf'].loss(neg_d, YQ)

    optG2.zero_grad()
    q_loss.backward()
    optG2.step()

    _, pred = torch.max(neg_d, 1)
    acc_q = model['clf'].accuracy(pred, YQ)

    return q_loss, acc_q
Example #5
def test_one(task, class_names, model, optG, criterion, args, grad):
    '''
        Evaluate the model on one sampled task.
    '''

    support, query = task
    # print("support, query:", support, query)
    # print("class_names_dict:", class_names_dict)
    '''Build sample pairs'''
    YS = support['label']
    YQ = query['label']

    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()
    # print("sampled_classes:", sampled_classes)

    class_names_dict = {}
    class_names_dict['label'] = class_names['label'][sampled_classes]
    # print("class_names_dict['label']:", class_names_dict['label'])
    class_names_dict['text'] = class_names['text'][sampled_classes]
    class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
    class_names_dict['is_support'] = False
    class_names_dict = utils.to_tensor(class_names_dict,
                                       args.cuda,
                                       exclude_keys=['is_support'])

    YS, YQ = reidx_y(args, YS, YQ)
    # print('YS:', support['label'])
    # print('YQ:', query['label'])
    # print("class_names_dict:", class_names_dict['label'])
    """维度填充"""
    if support['text'].shape[1] != class_names_dict['text'].shape[1]:
        zero = torch.zeros(
            (class_names_dict['text'].shape[0],
             support['text'].shape[1] - class_names_dict['text'].shape[1]),
            dtype=torch.long)
        class_names_dict['text'] = torch.cat(
            (class_names_dict['text'], zero.cuda()), dim=-1)

    support['text'] = torch.cat((support['text'], class_names_dict['text']),
                                dim=0)
    support['text_len'] = torch.cat(
        (support['text_len'], class_names_dict['text_len']), dim=0)
    support['label'] = torch.cat((support['label'], class_names_dict['label']),
                                 dim=0)
    # print("support['text']:", support['text'].shape)
    # print("support['label']:", support['label'])

    # Build all n*n ordered pairs over the augmented support set: the *_1
    # tensors repeat each sample n times in place, the *_2 tensors tile the
    # whole set n times, so row k pairs sample k // n with sample k % n.
    n = support['text'].shape[0]
    support['text_1'] = support['text'].repeat_interleave(n, dim=0)
    support['text_len_1'] = support['text_len'].repeat_interleave(n, dim=0)
    support['label_1'] = support['label'].repeat_interleave(n, dim=0)

    support['text_2'] = support['text'].repeat(n, 1)
    support['text_len_2'] = support['text_len'].repeat(n)
    support['label_2'] = support['label'].repeat(n)

    # print("support['text_1']:", support['text_1'].shape, support['text_len_1'].shape, support['label_1'].shape)
    # print("support['text_2']:", support['text_2'].shape, support['text_len_2'].shape, support['label_2'].shape)
    support['label_final'] = support['label_1'].eq(support['label_2']).int()

    support_1 = {}
    support_1['text'] = support['text_1']
    support_1['text_len'] = support['text_len_1']
    support_1['label'] = support['label_1']

    support_2 = {}
    support_2['text'] = support['text_2']
    support_2['text_len'] = support['text_len_2']
    support_2['label'] = support['label_2']
    # print("**************************************")
    # print("1111111", support['label_1'])
    # print("2222222", support['label_2'])
    # print(support['label_final'])
    '''First inner-loop step'''
    S_out1, S_out2 = model['G'](support_1, support_2)
    loss = criterion(S_out1, S_out2, support['label_final'])
    zero_grad(model['G'].parameters())
    grads = autograd.grad(loss, model['G'].fc.parameters(), allow_unused=True)
    fast_weights, ordered_params = model['G'].cloned_fc_dict(), OrderedDict()
    for (key, val), g in zip(model['G'].fc.named_parameters(), grads):
        if g is not None:  # allow_unused=True can return None grads
            fast_weights[key] = ordered_params[key] = val - args.task_lr * g
    '''Remaining inner-loop steps'''
    for k in range(args.train_iter - 1):
        S_out1, S_out2 = model['G'](support_1, support_2, fast_weights)
        loss = criterion(S_out1, S_out2, support['label_final'])
        zero_grad(ordered_params.values())
        grads = torch.autograd.grad(loss,
                                    ordered_params.values(),
                                    allow_unused=True)
        for (key, val), g in zip(ordered_params.items(), grads):
            if g is not None:
                fast_weights[key] = ordered_params[key] = val - args.task_lr * g
    """计算Q上的损失"""
    CN = model['G'].forward_once_with_param(class_names_dict, fast_weights)
    XQ = model['G'].forward_once_with_param(query, fast_weights)
    logits_q = neg_dist(XQ, CN)
    _, pred = torch.max(logits_q, 1)
    acc_q = model['G'].accuracy(pred, YQ)

    return acc_q
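
# Sketches of two helpers the MAML-style examples above rely on (assumptions;
# neither definition appears in the original). `zero_grad` clears stale .grad
# fields before autograd.grad, and `criterion` is consistent with a standard
# Siamese contrastive loss over the sample pairs:
def zero_grad(params):
    for p in params:
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()

class ContrastiveLoss(torch.nn.Module):
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, out1, out2, label):
        # label == 1 for same-class pairs, 0 for different-class pairs
        label = label.float()
        d = F.pairwise_distance(out1, out2)
        return torch.mean(label * d.pow(2) +
                          (1 - label) *
                          torch.clamp(self.margin - d, min=0).pow(2))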
Example #6
def train_one(task, class_names, model, optG, criterion, args, grad):
    '''
        Train the model on one sampled task.
    '''
    model['G'].train()
    # model['G2'].train()
    # model['clf'].train()

    support, query = task
    # print("support, query:", support, query)
    # print("class_names_dict:", class_names_dict)
    '''Build sample pairs'''
    YS = support['label']
    YQ = query['label']

    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()
    # print("sampled_classes:", sampled_classes)

    class_names_dict = {}
    class_names_dict['label'] = class_names['label'][sampled_classes]
    # print("class_names_dict['label']:", class_names_dict['label'])
    class_names_dict['text'] = class_names['text'][sampled_classes]
    class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
    class_names_dict['is_support'] = False
    class_names_dict = utils.to_tensor(class_names_dict,
                                       args.cuda,
                                       exclude_keys=['is_support'])

    YS, YQ = reidx_y(args, YS, YQ)
    # print('YS:', support['label'])
    # print('YQ:', query['label'])
    # print("class_names_dict:", class_names_dict['label'])
    """维度填充"""
    if support['text'].shape[1] != class_names_dict['text'].shape[1]:
        zero = torch.zeros(
            (class_names_dict['text'].shape[0],
             support['text'].shape[1] - class_names_dict['text'].shape[1]),
            dtype=torch.long)
        class_names_dict['text'] = torch.cat(
            (class_names_dict['text'], zero.cuda()), dim=-1)

    support['text'] = torch.cat((support['text'], class_names_dict['text']),
                                dim=0)
    support['text_len'] = torch.cat(
        (support['text_len'], class_names_dict['text_len']), dim=0)
    support['label'] = torch.cat((support['label'], class_names_dict['label']),
                                 dim=0)
    # print("support['text']:", support['text'].shape)
    # print("support['label']:", support['label'])

    # Build all n*n ordered pairs over the augmented support set: the *_1
    # tensors repeat each sample n times in place, the *_2 tensors tile the
    # whole set n times, so row k pairs sample k // n with sample k % n.
    n = support['text'].shape[0]
    support['text_1'] = support['text'].repeat_interleave(n, dim=0)
    support['text_len_1'] = support['text_len'].repeat_interleave(n, dim=0)
    support['label_1'] = support['label'].repeat_interleave(n, dim=0)

    support['text_2'] = support['text'].repeat(n, 1)
    support['text_len_2'] = support['text_len'].repeat(n)
    support['label_2'] = support['label'].repeat(n)

    # print("support['text_1']:", support['text_1'].shape, support['text_len_1'].shape, support['label_1'].shape)
    # print("support['text_2']:", support['text_2'].shape, support['text_len_2'].shape, support['label_2'].shape)
    support['label_final'] = support['label_1'].eq(support['label_2']).int()

    support_1 = {}
    support_1['text'] = support['text_1']
    support_1['text_len'] = support['text_len_1']
    support_1['label'] = support['label_1']

    support_2 = {}
    support_2['text'] = support['text_2']
    support_2['text_len'] = support['text_len_2']
    support_2['label'] = support['label_2']
    # print("**************************************")
    # print("1111111", support['label_1'])
    # print("2222222", support['label_2'])
    # print(support['label_final'])
    '''First inner-loop step'''
    S_out1, S_out2 = model['G'](support_1, support_2)
    loss = criterion(S_out1, S_out2, support['label_final'])
    zero_grad(model['G'].parameters())
    grads = autograd.grad(loss, model['G'].fc.parameters(), allow_unused=True)
    fast_weights, ordered_params = model['G'].cloned_fc_dict(), OrderedDict()
    for (key, val), g in zip(model['G'].fc.named_parameters(), grads):
        if g is not None:  # allow_unused=True can return None grads
            fast_weights[key] = ordered_params[key] = val - args.task_lr * g
    '''Remaining inner-loop steps'''
    for k in range(args.train_iter - 1):
        S_out1, S_out2 = model['G'](support_1, support_2, fast_weights)
        loss = criterion(S_out1, S_out2, support['label_final'])
        zero_grad(ordered_params.values())
        grads = torch.autograd.grad(loss,
                                    ordered_params.values(),
                                    allow_unused=True)
        for (key, val), g in zip(ordered_params.items(), grads):
            if g is not None:
                fast_weights[key] = ordered_params[key] = val - args.task_lr * g
    """计算Q上的损失"""
    CN = model['G'].forward_once_with_param(class_names_dict, fast_weights)
    XQ = model['G'].forward_once_with_param(query, fast_weights)
    logits_q = neg_dist(XQ, CN)
    q_loss = model['G'].loss(logits_q, YQ)
    _, pred = torch.max(logits_q, 1)
    acc_q = model['G'].accuracy(pred, YQ)

    optG.zero_grad()
    q_loss.backward()
    optG.step()

    return q_loss, acc_q
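
# A minimal meta-training loop sketch for train_one above (assumptions:
# `train_sampler` yields tasks and args.train_episodes exists; neither is
# defined in the original; criterion uses the ContrastiveLoss sketch above):
criterion = ContrastiveLoss(margin=1.0)
for episode in range(args.train_episodes):
    task = next(train_sampler)
    q_loss, acc_q = train_one(task, class_names, model, optG,
                              criterion, args, grad=None)
    if episode % 100 == 0:
        print('episode {}: loss={:.4f} acc={:.4f}'.format(
            episode, float(q_loss), float(acc_q)))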
Example #7
def train_one(task, class_names, model, optG, optCLF, args, grad):
    '''
        Train the model on one sampled task.
    '''
    model['G'].train()
    model['clf'].train()

    support, query = task
    # print("support, query:", support, query)
    # print("class_names_dict:", class_names_dict)
    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()
    # print("sampled_classes:", sampled_classes)

    class_names_dict = {}
    class_names_dict['label'] = class_names['label'][sampled_classes]
    # print("class_names_dict['label']:", class_names_dict['label'])
    class_names_dict['text'] = class_names['text'][sampled_classes]
    class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
    class_names_dict['is_support'] = False
    class_names_dict = utils.to_tensor(class_names_dict,
                                       args.cuda,
                                       exclude_keys=['is_support'])

    # Embedding the document
    XS = model['G'](support)  # XS:[N*K, 256(hidden_size*2)]
    # print("XS:", XS.shape)
    YS = support['label']
    # print('YS:', YS)

    CN = model['G'](class_names_dict)  # CN: [N, 256(hidden_size*2)]
    # print("CN:", CN.shape)

    XQ = model['G'](query)
    YQ = query['label']
    # print('YQ:', YQ)

    YS, YQ = reidx_y(args, YS, YQ)

    for _ in range(args.train_iter):

        # Embedding the document
        XS_mlp = model['clf'](XS)  # [N*K, 256(hidden_size*2)] -> [N*K, 128]

        CN_mlp = model['clf'](CN)  # [N, 256(hidden_size*2)] -> [N, 128]

        neg_d = neg_dist(XS_mlp, CN_mlp)  # [N*K, N]
        # print("neg_d:", neg_d.shape)

        mlp_loss = model['clf'].loss(neg_d, YS)
        # print("mlp_loss:", mlp_loss)

        optCLF.zero_grad()
        mlp_loss.backward(retain_graph=True)
        optCLF.step()

    XQ_mlp = model['clf'](XQ)
    CN_mlp = model['clf'](CN)
    neg_d = neg_dist(XQ_mlp, CN_mlp)
    g_loss = model['clf'].loss(neg_d, YQ)

    optG.zero_grad()
    g_loss.backward()
    optG.step()

    _, pred = torch.max(neg_d, 1)
    acc_q = model['clf'].accuracy(pred, YQ)

    return g_loss, acc_q