Example 1
def test_one(task, class_names, model, optG, criterion, args, grad):
    '''
        Fine-tune on the support set of one sampled task and evaluate on the query set.
    '''
    model['G'].eval()

    support, query = task
    # print("support, query:", support, query)
    # print("class_names_dict:", class_names_dict)
    '''Build sample pairs'''
    YS = support['label']
    YQ = query['label']

    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()
    # print("sampled_classes:", sampled_classes)

    class_names_dict = {}
    class_names_dict['label'] = class_names['label'][sampled_classes]
    # print("class_names_dict['label']:", class_names_dict['label'])
    class_names_dict['text'] = class_names['text'][sampled_classes]
    class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
    class_names_dict['is_support'] = False
    class_names_dict = utils.to_tensor(class_names_dict,
                                       args.cuda,
                                       exclude_keys=['is_support'])

    YS, YQ = reidx_y(args, YS, YQ)
    # print('YS:', support['label'])
    # print('YQ:', query['label'])
    # print("class_names_dict:", class_names_dict['label'])
    """维度填充"""
    if support['text'].shape[1] > class_names_dict['text'].shape[1]:
        zero = torch.zeros(
            (class_names_dict['text'].shape[0],
             support['text'].shape[1] - class_names_dict['text'].shape[1]),
            dtype=torch.long)
        class_names_dict['text'] = torch.cat(
            (class_names_dict['text'], zero.cuda()), dim=-1)
    elif support['text'].shape[1] < class_names_dict['text'].shape[1]:
        zero = torch.zeros(
            (support['text'].shape[0],
             class_names_dict['text'].shape[1] - support['text'].shape[1]),
            dtype=torch.long)
        support['text'] = torch.cat((support['text'], zero.cuda()), dim=-1)

    support['text'] = torch.cat((support['text'], class_names_dict['text']),
                                dim=0)
    support['text_len'] = torch.cat(
        (support['text_len'], class_names_dict['text_len']), dim=0)
    support['label'] = torch.cat((support['label'], class_names_dict['label']),
                                 dim=0)
    # print("support['text']:", support['text'].shape)
    # print("support['label']:", support['label'])

    text_sample_len = support['text'].shape[0]
    # print("support['text'].shape[0]:", support['text'].shape[0])
    support['text_1'] = support['text'][0].view((1, -1))
    support['text_len_1'] = support['text_len'][0].view(-1)
    support['label_1'] = support['label'][0].view(-1)
    for i in range(text_sample_len):
        # Pair side 1: repeat each support row once per sampled class so it
        # can meet every class name; row 0 was seeded above.
        start = 1 if i == 0 else 0
        for _ in range(start, len(sampled_classes)):
            support['text_1'] = torch.cat(
                (support['text_1'], support['text'][i].view((1, -1))),
                dim=0)
            support['text_len_1'] = torch.cat(
                (support['text_len_1'], support['text_len'][i].view(-1)),
                dim=0)
            support['label_1'] = torch.cat(
                (support['label_1'], support['label'][i].view(-1)), dim=0)

    support['text_2'] = class_names_dict['text'][0].view((1, -1))
    support['text_len_2'] = class_names_dict['text_len'][0].view(-1)
    support['label_2'] = class_names_dict['label'][0].view(-1)
    for i in range(text_sample_len):
        # Pair side 2: tile the class-name rows across every support row;
        # class-name row 0 was seeded above.
        start = 1 if i == 0 else 0
        for j in range(start, len(sampled_classes)):
            support['text_2'] = torch.cat(
                (support['text_2'], class_names_dict['text'][j].view(
                    (1, -1))),
                dim=0)
            support['text_len_2'] = torch.cat(
                (support['text_len_2'],
                 class_names_dict['text_len'][j].view(-1)),
                dim=0)
            support['label_2'] = torch.cat(
                (support['label_2'],
                 class_names_dict['label'][j].view(-1)),
                dim=0)

    # print("support['text_1']:", support['text_1'].shape, support['text_len_1'].shape, support['label_1'].shape)
    # print("support['text_2']:", support['text_2'].shape, support['text_len_2'].shape, support['label_2'].shape)
    support['label_final'] = support['label_1'].eq(support['label_2']).int()

    support_1 = {}
    support_1['text'] = support['text_1']
    support_1['text_len'] = support['text_len_1']
    support_1['label'] = support['label_1']

    support_2 = {}
    support_2['text'] = support['text_2']
    support_2['text_len'] = support['text_len_2']
    support_2['label'] = support['label_2']
    # print("**************************************")
    # print("1111111", support['label_1'])
    # print("2222222", support['label_2'])
    # print(support['label_final'])
    '''first step'''
    S_out1, S_out2 = model['G'](support_1, support_2)

    supp_, que_ = model['G'](support, query)
    loss_weight = get_weight_of_test_support(supp_, que_, args)

    loss = criterion(S_out1, S_out2, support['label_final'], loss_weight)
    # print("s_1_loss:", loss)
    zero_grad(model['G'].parameters())

    grads_fc = autograd.grad(loss,
                             model['G'].fc.parameters(),
                             allow_unused=True,
                             retain_graph=True)
    fast_weights_fc, orderd_params_fc = model['G'].cloned_fc_dict(
    ), OrderedDict()
    for (key, val), grad in zip(model['G'].fc.named_parameters(), grads_fc):
        fast_weights_fc[key] = orderd_params_fc[
            key] = val - args.task_lr * grad

    grads_conv11 = autograd.grad(loss,
                                 model['G'].conv11.parameters(),
                                 allow_unused=True,
                                 retain_graph=True)
    fast_weights_conv11, orderd_params_conv11 = model['G'].cloned_conv11_dict(
    ), OrderedDict()
    for (key, val), grad in zip(model['G'].conv11.named_parameters(),
                                grads_conv11):
        fast_weights_conv11[key] = orderd_params_conv11[
            key] = val - args.task_lr * grad

    grads_conv12 = autograd.grad(loss,
                                 model['G'].conv12.parameters(),
                                 allow_unused=True,
                                 retain_graph=True)
    fast_weights_conv12, orderd_params_conv12 = model['G'].cloned_conv12_dict(
    ), OrderedDict()
    for (key, val), grad in zip(model['G'].conv12.named_parameters(),
                                grads_conv12):
        fast_weights_conv12[key] = orderd_params_conv12[
            key] = val - args.task_lr * grad

    grads_conv13 = autograd.grad(loss,
                                 model['G'].conv13.parameters(),
                                 allow_unused=True)
    fast_weights_conv13, orderd_params_conv13 = model['G'].cloned_conv13_dict(
    ), OrderedDict()
    for (key, val), grad in zip(model['G'].conv13.named_parameters(),
                                grads_conv13):
        fast_weights_conv13[key] = orderd_params_conv13[
            key] = val - args.task_lr * grad

    fast_weights = {}
    fast_weights['fc'] = fast_weights_fc
    fast_weights['conv11'] = fast_weights_conv11
    fast_weights['conv12'] = fast_weights_conv12
    fast_weights['conv13'] = fast_weights_conv13
    '''steps remaining'''
    for k in range(args.test_iter - 1):
        S_out1, S_out2 = model['G'](support_1, support_2, fast_weights)

        supp_, que_ = model['G'](support, query, fast_weights)
        loss_weight = get_weight_of_test_support(supp_, que_, args)

        loss = criterion(S_out1, S_out2, support['label_final'], loss_weight)
        # print("train_iter: {} s_loss:{}".format(k, loss))
        zero_grad(orderd_params_fc.values())
        zero_grad(orderd_params_conv11.values())
        zero_grad(orderd_params_conv12.values())
        zero_grad(orderd_params_conv13.values())
        grads_fc = torch.autograd.grad(loss,
                                       orderd_params_fc.values(),
                                       allow_unused=True,
                                       retain_graph=True)
        grads_conv11 = torch.autograd.grad(loss,
                                           orderd_params_conv11.values(),
                                           allow_unused=True,
                                           retain_graph=True)
        grads_conv12 = torch.autograd.grad(loss,
                                           orderd_params_conv12.values(),
                                           allow_unused=True,
                                           retain_graph=True)
        grads_conv13 = torch.autograd.grad(loss,
                                           orderd_params_conv13.values(),
                                           allow_unused=True)

        for (key, val), grad in zip(orderd_params_fc.items(), grads_fc):
            if grad is not None:
                fast_weights['fc'][key] = orderd_params_fc[
                    key] = val - args.task_lr * grad

        for (key, val), grad in zip(orderd_params_conv11.items(),
                                    grads_conv11):
            if grad is not None:
                fast_weights['conv11'][key] = orderd_params_conv11[
                    key] = val - args.task_lr * grad

        for (key, val), grad in zip(orderd_params_conv12.items(),
                                    grads_conv12):
            if grad is not None:
                fast_weights['conv12'][key] = orderd_params_conv12[
                    key] = val - args.task_lr * grad

        for (key, val), grad in zip(orderd_params_conv13.items(),
                                    grads_conv13):
            if grad is not None:
                fast_weights['conv13'][key] = orderd_params_conv13[
                    key] = val - args.task_lr * grad
    """计算Q上的损失"""
    CN = model['G'].forward_once_with_param(class_names_dict, fast_weights)
    XQ = model['G'].forward_once_with_param(query, fast_weights)
    logits_q = pos_dist(XQ, CN)
    logits_q = dis_to_level(logits_q)
    _, pred = torch.max(logits_q, 1)
    acc_q = model['G'].accuracy(pred, YQ)

    return acc_q
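Example 1 repeats the same first-step update four times, once per submodule (fc, conv11, conv12, conv13): take gradients of the support loss, then store val - task_lr * grad in both a cloned state dict and an OrderedDict for the remaining steps. A minimal sketch of a helper that factors the pattern out; inner_update is a hypothetical name, not part of the repo, and unlike the original first step it also skips parameters whose gradient comes back as None under allow_unused=True:

from collections import OrderedDict

from torch import autograd


def inner_update(loss, module, cloned_dict, task_lr):
    # Hypothetical helper: one inner-loop SGD step on a single submodule.
    grads = autograd.grad(loss, module.parameters(),
                          allow_unused=True, retain_graph=True)
    fast, ordered = cloned_dict, OrderedDict()
    for (key, val), g in zip(module.named_parameters(), grads):
        if g is not None:  # allow_unused=True can return None
            fast[key] = ordered[key] = val - task_lr * g
    return fast, ordered

The fc block above would then reduce to fast_weights['fc'], orderd_params_fc = inner_update(loss, model['G'].fc, model['G'].cloned_fc_dict(), args.task_lr), and likewise for the three conv branches.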
Example 2
def test_one(task, class_names, model, optCLF, args, grad):
    '''
        Fine-tune on the support set of one sampled task and evaluate on the query set.
    '''
    # model['G'].eval()
    # model['clf'].train()

    support, query = task
    # print("support, query:", support, query)
    # print("class_names_dict:", class_names_dict)

    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()
    # print("sampled_classes:", sampled_classes)

    class_names_dict = {}
    class_names_dict['label'] = class_names['label'][sampled_classes]
    # print("class_names_dict['label']:", class_names_dict['label'])
    class_names_dict['text'] = class_names['text'][sampled_classes]
    class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
    class_names_dict['is_support'] = False
    class_names_dict = utils.to_tensor(class_names_dict, args.cuda, exclude_keys=['is_support'])

    # Embedding the document
    XS = model['G'](support)  # XS:[N*K, 256(hidden_size*2)]
    # print("XS:", XS.shape)
    YS = support['label']
    # print('YS:', YS)

    CN = model['G'](class_names_dict)  # CN:[N, 256(hidden_size*2)]
    # print("CN:", CN.shape)

    XQ = model['G'](query)
    YQ = query['label']
    # print('YQ:', YQ)

    YS, YQ = reidx_y(args, YS, YQ)

    for _ in range(args.test_iter):

        # Embedding the document
        XS_mlp = model['clf'](XS)  # [N*K, 256(hidden_size*2)] -> [N*K, 128]

        CN_mlp = model['clf'](CN)  # [N, 256(hidden_size*2)] -> [N, 128]

        neg_d = neg_dist(XS_mlp, CN_mlp)  # [N*K, N]
        # print("neg_d:", neg_d.shape)

        mlp_loss = model['clf'].loss(neg_d, YS)
        # print("mlp_loss:", mlp_loss)

        optCLF.zero_grad()
        mlp_loss.backward(retain_graph=True)
        optCLF.step()

    XQ_mlp = model['clf'](XQ)
    CN_mlp = model['clf'](CN)
    neg_d = neg_dist(XQ_mlp, CN_mlp)

    _, pred = torch.max(neg_d, 1)
    acc_q = model['clf'].accuracy(pred, YQ)

    return acc_q
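neg_dist (and Example 1's pos_dist and dis_to_level) is defined outside these snippets. From its use here, where its row-wise argmax picks the predicted class, larger must mean closer, so it plausibly returns negated pairwise Euclidean distances. A minimal sketch under that assumption (whether the distance is squared is also a guess):

import torch


def neg_dist(x, y):
    # Assumed implementation: negative squared Euclidean distance
    # between every row of x [A, D] and every row of y [B, D] -> [A, B].
    # The row-wise argmax is then the nearest class, so the matrix can
    # serve directly as logits for cross-entropy or prediction.
    return -torch.cdist(x, y, p=2) ** 2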
Example 3
def test_one(task, class_names_dict, model, optG, optCLF, args, grad):
    '''
        Fine-tune on the support set of one sampled task and evaluate on the query set.
    '''

    support, query = task
    # print("support, query:", support, query)
    # print("class_names_dict:", class_names_dict)

    'Step 1: update G so the class-name embeddings move farther apart'
    cn_loss_all = 0
    for _ in range(5):
        CN = model['G'](class_names_dict)  # CN:[N, 256(hidden_size*2)]
        # print("CN:", CN.shape)
        dis = neg_dist(CN, CN)  # [N, N]
        # Sum of off-diagonal (negative) distances: minimizing it drives
        # the class-name embeddings apart.
        cn_loss = torch.sum(dis) - torch.sum(torch.diag(dis))
        cn_loss_all += cn_loss
        optG.zero_grad()
        cn_loss.backward(retain_graph=True)
        optG.step()
    print('***********[TEST] cn_loss:', cn_loss_all / 5)

    'Pass CN through the fine-tuned G; pass S and Q through G2'
    # Embedding the document
    XS = model['G2'](support)  # XS:[N*K, 256(hidden_size*2)]
    # print("XS:", XS.shape)
    YS = support['label']
    # print('YS:', YS)

    CN = model['G'](class_names_dict)  # CN:[N, 256(hidden_size*2)]
    # print("CN:", CN.shape)

    XQ = model['G2'](query)
    YQ = query['label']
    # print('YQ:', YQ)

    YS, YQ = reidx_y(args, YS, YQ)

    'Step 2: update the MLP on the support set'
    for _ in range(args.test_iter):

        # Embedding the document
        XS_mlp = model['clf'](XS)  # [N*K, 256(hidden_size*2)] -> [N*K, 256]

        neg_d = neg_dist(XS_mlp, CN)  # [N*K, N]
        # print("neg_d:", neg_d.shape)

        mlp_loss = model['clf'].loss(neg_d, YS)
        # print("mlp_loss:", mlp_loss)

        optCLF.zero_grad()
        mlp_loss.backward(retain_graph=True)
        optCLF.step()

    XQ_mlp = model['clf'](XQ)
    neg_d = neg_dist(XQ_mlp, CN)

    _, pred = torch.max(neg_d, 1)
    acc_q = model['clf'].accuracy(pred, YQ)

    return acc_q
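reidx_y is likewise external; the comment in Example 6 ("remap labels to start from 0") indicates it maps the task's sampled global label ids onto 0..N-1 so they can index the N-way logits. A sketch under that reading (the signature matches the calls above, the body is an assumption):

import torch


def reidx_y(args, YS, YQ):
    # Assumed behavior: rank each label within the task's sorted set of
    # distinct labels, giving consistent 0..N-1 ids for support and query.
    classes = torch.unique(YS)  # sorted distinct labels, shape [N]
    return torch.searchsorted(classes, YS), torch.searchsorted(classes, YQ)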
Example 4
def train_one(task, class_names_dict, model, optG, optG2, optCLF, args, grad):
    '''
        Train the model on one sampled task.
    '''
    model['G'].train()
    model['G2'].train()
    model['clf'].train()

    support, query = task
    # print("support, query:", support, query)
    # print("class_names_dict:", class_names_dict)

    'Step 1: update G so the class-name embeddings move farther apart'
    cn_loss_all = 0
    for _ in range(5):
        CN = model['G'](class_names_dict)  # CN:[N, 256(hidden_size*2)]
        # print("CN:", CN)
        d = neg_dist(CN, CN)  # [N, N]
        dis = d / torch.mean(d, dim=0)  # normalize each column by its mean
        # print("dis:", dis)
        m = torch.tensor(2.0).cuda()
        cn_loss = torch.tensor(0.0).cuda()
        for i, d in enumerate(dis):
            for j, dd in enumerate(d):
                if i != j:
                    cn_loss = cn_loss + ((torch.max(
                        torch.tensor(0.0).cuda(), m + dd))**2) / 2 / 20
                    # print("cn_loss:", cn_loss)
        # print("cn_loss:", cn_loss)
        cn_loss_all += cn_loss
        # for name, param in model['G'].named_parameters():
        #     print("name:", name, "param:", param)
        optG.zero_grad()
        cn_loss.backward()
        optG.step()
    print('***********cn_loss:', cn_loss_all / 5)

    'Pass CN through the fine-tuned G; pass S and Q through G2'
    CN = model['G'](class_names_dict)  # CN:[N, 256(hidden_size*2)]
    # Embedding the document
    XS = model['G2'](support)  # XS:[N*K, 256(hidden_size*2)]
    # print("XS:", XS.shape)
    YS = support['label']
    # print('YS:', YS)

    XQ = model['G2'](query)
    YQ = query['label']
    # print('YQ:', YQ)

    YS, YQ = reidx_y(args, YS, YQ)  # remap labels to start from 0

    'Step 2: update the MLP on the support set'
    for _ in range(args.train_iter):

        # Embedding the document
        XS_mlp = model['clf'](XS)  # [N*K, 256(hidden_size*2)] -> [N*K, 256]

        neg_d = neg_dist(XS_mlp, CN)  # [N*K, N]
        # print("neg_d:", neg_d.shape)

        mlp_loss = model['clf'].loss(neg_d, YS)
        # print("mlp_loss:", mlp_loss)

        optCLF.zero_grad()
        mlp_loss.backward(retain_graph=True)
        optCLF.step()

    'Step 3: update G2 on the query set'
    XQ_mlp = model['clf'](XQ)
    neg_d = neg_dist(XQ_mlp, CN)
    q_loss = model['clf'].loss(neg_d, YQ)

    optG2.zero_grad()
    q_loss.backward()
    optG2.step()

    _, pred = torch.max(neg_d, 1)
    acc_q = model['clf'].accuracy(pred, YQ)

    # YQ_d = torch.ones(query['label'].shape, dtype=torch.long).to(query['label'].device)
    # print('YQ', set(YQ.numpy()))

    # XSource, XSource_inputD, _ = model['G'](source)
    # YSource_d = torch.zeros(source['label'].shape, dtype=torch.long).to(source['label'].device)

    # XQ_logitsD = model['D'](XQ_inputD)
    # XSource_logitsD = model['D'](XSource_inputD)
    #
    # d_loss = F.cross_entropy(XQ_logitsD, YQ_d) + F.cross_entropy(XSource_logitsD, YSource_d)
    # d_loss.backward(retain_graph=True)
    # grad['D'].append(get_norm(model['D']))
    # optD.step()
    #
    # # *****************update G****************
    # optG.zero_grad()
    # XQ_logitsD = model['D'](XQ_inputD)
    # XSource_logitsD = model['D'](XSource_inputD)
    # d_loss = F.cross_entropy(XQ_logitsD, YQ_d) + F.cross_entropy(XSource_logitsD, YSource_d)
    #
    # acc, d_acc, loss, _ = model['clf'](XS, YS, XQ, YQ, XQ_logitsD, XSource_logitsD, YQ_d, YSource_d)
    #
    # g_loss = loss - d_loss
    # if args.ablation == "-DAN":
    #     g_loss = loss
    #     print("%%%%%%%%%%%%%%%%%%%This is ablation mode: -DAN%%%%%%%%%%%%%%%%%%%%%%%%%%")
    # g_loss.backward(retain_graph=True)
    # grad['G'].append(get_norm(model['G']))
    # grad['clf'].append(get_norm(model['clf']))
    # optG.step()

    return q_loss, acc_q
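The double Python loop in step 1 scores each off-diagonal pair of normalized distances separately. The same margin term can be computed in one vectorized expression; a sketch assuming neg_dist as above and that the literal 20 is the number of off-diagonal pairs in a 5-way task:

import torch


def class_sep_loss(CN, margin=2.0):
    # Vectorized equivalent of the pairwise hinge above: for every
    # off-diagonal normalized entry dd, add max(0, margin + dd)^2 / 2,
    # averaged over the N*(N-1) pairs (the hard-coded /20 for N=5).
    d = neg_dist(CN, CN)            # [N, N]
    dis = d / torch.mean(d, dim=0)  # same column-mean normalization
    off_diag = ~torch.eye(CN.shape[0], dtype=torch.bool, device=CN.device)
    hinge = torch.clamp(margin + dis[off_diag], min=0.0) ** 2 / 2
    return hinge.sum() / off_diag.sum()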
Example 5
def test_one(task, class_names, model, optG, criterion, args, grad):
    '''
        Fine-tune on the support set of one sampled task and evaluate on the query set.
    '''

    support, query = task
    # print("support, query:", support, query)
    # print("class_names_dict:", class_names_dict)
    '''Build sample pairs'''
    YS = support['label']
    YQ = query['label']

    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()
    # print("sampled_classes:", sampled_classes)

    class_names_dict = {}
    class_names_dict['label'] = class_names['label'][sampled_classes]
    # print("class_names_dict['label']:", class_names_dict['label'])
    class_names_dict['text'] = class_names['text'][sampled_classes]
    class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
    class_names_dict['is_support'] = False
    class_names_dict = utils.to_tensor(class_names_dict,
                                       args.cuda,
                                       exclude_keys=['is_support'])

    YS, YQ = reidx_y(args, YS, YQ)
    # print('YS:', support['label'])
    # print('YQ:', query['label'])
    # print("class_names_dict:", class_names_dict['label'])
    """维度填充"""
    if support['text'].shape[1] != class_names_dict['text'].shape[1]:
        zero = torch.zeros(
            (class_names_dict['text'].shape[0],
             support['text'].shape[1] - class_names_dict['text'].shape[1]),
            dtype=torch.long)
        class_names_dict['text'] = torch.cat(
            (class_names_dict['text'], zero.cuda()), dim=-1)

    support['text'] = torch.cat((support['text'], class_names_dict['text']),
                                dim=0)
    support['text_len'] = torch.cat(
        (support['text_len'], class_names_dict['text_len']), dim=0)
    support['label'] = torch.cat((support['label'], class_names_dict['label']),
                                 dim=0)
    # print("support['text']:", support['text'].shape)
    # print("support['label']:", support['label'])

    text_sample_len = support['text'].shape[0]
    # print("support['text'].shape[0]:", support['text'].shape[0])
    support['text_1'] = support['text'][0].view((1, -1))
    support['text_len_1'] = support['text_len'][0].view(-1)
    support['label_1'] = support['label'][0].view(-1)
    for i in range(text_sample_len):
        # Pair side 1: repeat each row once per row in the set; row 0 was
        # seeded above.
        start = 1 if i == 0 else 0
        for _ in range(start, text_sample_len):
            support['text_1'] = torch.cat(
                (support['text_1'], support['text'][i].view((1, -1))),
                dim=0)
            support['text_len_1'] = torch.cat(
                (support['text_len_1'], support['text_len'][i].view(-1)),
                dim=0)
            support['label_1'] = torch.cat(
                (support['label_1'], support['label'][i].view(-1)), dim=0)

    support['text_2'] = support['text'][0].view((1, -1))
    support['text_len_2'] = support['text_len'][0].view(-1)
    support['label_2'] = support['label'][0].view(-1)
    for i in range(text_sample_len):
        # Pair side 2: tile the whole set against every outer row; row 0
        # was seeded above.
        start = 1 if i == 0 else 0
        for j in range(start, text_sample_len):
            support['text_2'] = torch.cat(
                (support['text_2'], support['text'][j].view((1, -1))),
                dim=0)
            support['text_len_2'] = torch.cat(
                (support['text_len_2'], support['text_len'][j].view(-1)),
                dim=0)
            support['label_2'] = torch.cat(
                (support['label_2'], support['label'][j].view(-1)), dim=0)

    # print("support['text_1']:", support['text_1'].shape, support['text_len_1'].shape, support['label_1'].shape)
    # print("support['text_2']:", support['text_2'].shape, support['text_len_2'].shape, support['label_2'].shape)
    support['label_final'] = support['label_1'].eq(support['label_2']).int()

    support_1 = {}
    support_1['text'] = support['text_1']
    support_1['text_len'] = support['text_len_1']
    support_1['label'] = support['label_1']

    support_2 = {}
    support_2['text'] = support['text_2']
    support_2['text_len'] = support['text_len_2']
    support_2['label'] = support['label_2']
    # print("**************************************")
    # print("1111111", support['label_1'])
    # print("2222222", support['label_2'])
    # print(support['label_final'])
    '''first step'''
    S_out1, S_out2 = model['G'](support_1, support_2)
    loss = criterion(S_out1, S_out2, support['label_final'])
    zero_grad(model['G'].parameters())
    grads = autograd.grad(loss, model['G'].fc.parameters(), allow_unused=True)
    fast_weights, orderd_params = model['G'].cloned_fc_dict(), OrderedDict()
    for (key, val), grad in zip(model['G'].fc.named_parameters(), grads):
        fast_weights[key] = orderd_params[key] = val - args.task_lr * grad
    '''steps remaining'''
    for k in range(args.train_iter - 1):
        S_out1, S_out2 = model['G'](support_1, support_2, fast_weights)
        loss = criterion(S_out1, S_out2, support['label_final'])
        zero_grad(orderd_params.values())
        grads = torch.autograd.grad(loss,
                                    orderd_params.values(),
                                    allow_unused=True)
        # print('grads:', grads)
        # print("orderd_params.items():", orderd_params.items())
        for (key, val), grad in zip(orderd_params.items(), grads):
            if grad is not None:
                fast_weights[key] = orderd_params[
                    key] = val - args.task_lr * grad
    """计算Q上的损失"""
    CN = model['G'].forward_once_with_param(class_names_dict, fast_weights)
    XQ = model['G'].forward_once_with_param(query, fast_weights)
    logits_q = neg_dist(XQ, CN)
    _, pred = torch.max(logits_q, 1)
    acc_q = model['G'].accuracy(pred, YQ)

    return acc_q
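The pair-building loops here (and in Examples 1 and 6) issue one torch.cat per pair, which is quadratic in the number of rows. The same text_1/text_2 tensors can be produced with a few vectorized ops; a behavior-equivalent sketch with a hypothetical build_pairs helper:

import torch


def build_pairs(text, text_len, label):
    # All-vs-all pairing of the S rows of `text`: side 1 repeats each row
    # S times in place, side 2 tiles the whole set S times, so that every
    # row of side 1 meets every row of side 2 exactly once.
    s = text.shape[0]
    text_1 = text.repeat_interleave(s, dim=0)   # [S*S, L]
    text_len_1 = text_len.repeat_interleave(s)  # [S*S]
    label_1 = label.repeat_interleave(s)        # [S*S]
    text_2 = text.repeat(s, 1)                  # [S*S, L]
    text_len_2 = text_len.repeat(s)             # [S*S]
    label_2 = label.repeat(s)                   # [S*S]
    label_final = label_1.eq(label_2).int()     # 1 where the pair matches
    return (text_1, text_len_1, label_1,
            text_2, text_len_2, label_2, label_final)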
Example 6
def train_one(task, class_names, model, optG, criterion, args, grad):
    '''
        Train the model on one sampled task.
    '''
    model['G'].train()
    # model['G2'].train()
    # model['clf'].train()

    support, query = task
    # print("support, query:", support, query)
    # print("class_names_dict:", class_names_dict)
    '''Build sample pairs'''
    YS = support['label']
    YQ = query['label']

    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()
    # print("sampled_classes:", sampled_classes)

    class_names_dict = {}
    class_names_dict['label'] = class_names['label'][sampled_classes]
    # print("class_names_dict['label']:", class_names_dict['label'])
    class_names_dict['text'] = class_names['text'][sampled_classes]
    class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
    class_names_dict['is_support'] = False
    class_names_dict = utils.to_tensor(class_names_dict,
                                       args.cuda,
                                       exclude_keys=['is_support'])

    YS, YQ = reidx_y(args, YS, YQ)
    # print('YS:', support['label'])
    # print('YQ:', query['label'])
    # print("class_names_dict:", class_names_dict['label'])
    """维度填充"""
    if support['text'].shape[1] != class_names_dict['text'].shape[1]:
        zero = torch.zeros(
            (class_names_dict['text'].shape[0],
             support['text'].shape[1] - class_names_dict['text'].shape[1]),
            dtype=torch.long)
        class_names_dict['text'] = torch.cat(
            (class_names_dict['text'], zero.cuda()), dim=-1)

    support['text'] = torch.cat((support['text'], class_names_dict['text']),
                                dim=0)
    support['text_len'] = torch.cat(
        (support['text_len'], class_names_dict['text_len']), dim=0)
    support['label'] = torch.cat((support['label'], class_names_dict['label']),
                                 dim=0)
    # print("support['text']:", support['text'].shape)
    # print("support['label']:", support['label'])

    text_sample_len = support['text'].shape[0]
    # print("support['text'].shape[0]:", support['text'].shape[0])
    support['text_1'] = support['text'][0].view((1, -1))
    support['text_len_1'] = support['text_len'][0].view(-1)
    support['label_1'] = support['label'][0].view(-1)
    for i in range(text_sample_len):
        # Pair side 1: repeat each row once per row in the set; row 0 was
        # seeded above.
        start = 1 if i == 0 else 0
        for _ in range(start, text_sample_len):
            support['text_1'] = torch.cat(
                (support['text_1'], support['text'][i].view((1, -1))),
                dim=0)
            support['text_len_1'] = torch.cat(
                (support['text_len_1'], support['text_len'][i].view(-1)),
                dim=0)
            support['label_1'] = torch.cat(
                (support['label_1'], support['label'][i].view(-1)), dim=0)

    support['text_2'] = support['text'][0].view((1, -1))
    support['text_len_2'] = support['text_len'][0].view(-1)
    support['label_2'] = support['label'][0].view(-1)
    for i in range(text_sample_len):
        # Pair side 2: tile the whole set against every outer row; row 0
        # was seeded above.
        start = 1 if i == 0 else 0
        for j in range(start, text_sample_len):
            support['text_2'] = torch.cat(
                (support['text_2'], support['text'][j].view((1, -1))),
                dim=0)
            support['text_len_2'] = torch.cat(
                (support['text_len_2'], support['text_len'][j].view(-1)),
                dim=0)
            support['label_2'] = torch.cat(
                (support['label_2'], support['label'][j].view(-1)), dim=0)

    # print("support['text_1']:", support['text_1'].shape, support['text_len_1'].shape, support['label_1'].shape)
    # print("support['text_2']:", support['text_2'].shape, support['text_len_2'].shape, support['label_2'].shape)
    support['label_final'] = support['label_1'].eq(support['label_2']).int()

    support_1 = {}
    support_1['text'] = support['text_1']
    support_1['text_len'] = support['text_len_1']
    support_1['label'] = support['label_1']

    support_2 = {}
    support_2['text'] = support['text_2']
    support_2['text_len'] = support['text_len_2']
    support_2['label'] = support['label_2']
    # print("**************************************")
    # print("1111111", support['label_1'])
    # print("2222222", support['label_2'])
    # print(support['label_final'])
    '''first step'''
    S_out1, S_out2 = model['G'](support_1, support_2)
    loss = criterion(S_out1, S_out2, support['label_final'])
    zero_grad(model['G'].parameters())
    grads = autograd.grad(loss, model['G'].fc.parameters(), allow_unused=True)
    fast_weights, orderd_params = model['G'].cloned_fc_dict(), OrderedDict()
    for (key, val), grad in zip(model['G'].fc.named_parameters(), grads):
        fast_weights[key] = orderd_params[key] = val - args.task_lr * grad
    '''steps remaining'''
    for k in range(args.train_iter - 1):
        S_out1, S_out2 = model['G'](support_1, support_2, fast_weights)
        loss = criterion(S_out1, S_out2, support['label_final'])
        zero_grad(orderd_params.values())
        grads = torch.autograd.grad(loss,
                                    orderd_params.values(),
                                    allow_unused=True)
        # print('grads:', grads)
        # print("orderd_params.items():", orderd_params.items())
        for (key, val), grad in zip(orderd_params.items(), grads):
            if grad is not None:
                fast_weights[key] = orderd_params[
                    key] = val - args.task_lr * grad
    """计算Q上的损失"""
    CN = model['G'].forward_once_with_param(class_names_dict, fast_weights)
    XQ = model['G'].forward_once_with_param(query, fast_weights)
    logits_q = neg_dist(XQ, CN)
    q_loss = model['G'].loss(logits_q, YQ)
    _, pred = torch.max(logits_q, 1)
    acc_q = model['G'].accuracy(pred, YQ)

    optG.zero_grad()
    q_loss.backward()
    optG.step()

    # 'Pass CN through the fine-tuned G; pass S and Q through G2'
    # CN = model['G'](class_names_dict)  # CN:[N, 256(hidden_size*2)]
    # # Embedding the document
    # XS = model['G2'](support)  # XS:[N*K, 256(hidden_size*2)]
    # # print("XS:", XS.shape)
    # YS = support['label']
    # # print('YS:', YS)
    #
    # XQ = model['G2'](query)
    # YQ = query['label']
    # # print('YQ:', YQ)
    #
    # YS, YQ = reidx_y(args, YS, YQ)  # remap labels to start from 0
    #
    # 'Step 2: update the MLP on the support set'
    # for _ in range(args.train_iter):
    #
    #     # Embedding the document
    #     XS_mlp = model['clf'](XS)  # [N*K, 256(hidden_size*2)] -> [N*K, 256]
    #
    #     neg_d = neg_dist(XS_mlp, CN)  # [N*K, N]
    #     # print("neg_d:", neg_d.shape)
    #
    #     mlp_loss = model['clf'].loss(neg_d, YS)
    #     # print("mlp_loss:", mlp_loss)
    #
    #     optCLF.zero_grad()
    #     mlp_loss.backward(retain_graph=True)
    #     optCLF.step()
    #
    # 'Step 3: update G2 on the query set'
    # XQ_mlp = model['clf'](XQ)
    # neg_d = neg_dist(XQ_mlp, CN)
    # q_loss = model['clf'].loss(neg_d, YQ)

    # optG2.zero_grad()
    # q_loss.backward()
    # optG2.step()
    #
    # _, pred = torch.max(neg_d, 1)
    # acc_q = model['clf'].accuracy(pred, YQ)

    # YQ_d = torch.ones(query['label'].shape, dtype=torch.long).to(query['label'].device)
    # print('YQ', set(YQ.numpy()))

    # XSource, XSource_inputD, _ = model['G'](source)
    # YSource_d = torch.zeros(source['label'].shape, dtype=torch.long).to(source['label'].device)

    # XQ_logitsD = model['D'](XQ_inputD)
    # XSource_logitsD = model['D'](XSource_inputD)
    #
    # d_loss = F.cross_entropy(XQ_logitsD, YQ_d) + F.cross_entropy(XSource_logitsD, YSource_d)
    # d_loss.backward(retain_graph=True)
    # grad['D'].append(get_norm(model['D']))
    # optD.step()
    #
    # # *****************update G****************
    # optG.zero_grad()
    # XQ_logitsD = model['D'](XQ_inputD)
    # XSource_logitsD = model['D'](XSource_inputD)
    # d_loss = F.cross_entropy(XQ_logitsD, YQ_d) + F.cross_entropy(XSource_logitsD, YSource_d)
    #
    # acc, d_acc, loss, _ = model['clf'](XS, YS, XQ, YQ, XQ_logitsD, XSource_logitsD, YQ_d, YSource_d)
    #
    # g_loss = loss - d_loss
    # if args.ablation == "-DAN":
    #     g_loss = loss
    #     print("%%%%%%%%%%%%%%%%%%%This is ablation mode: -DAN%%%%%%%%%%%%%%%%%%%%%%%%%%")
    # g_loss.backward(retain_graph=True)
    # grad['G'].append(get_norm(model['G']))
    # grad['clf'].append(get_norm(model['clf']))
    # optG.step()

    return q_loss, acc_q
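criterion(S_out1, S_out2, label_final) receives two siamese embeddings plus a 0/1 same-class indicator, the signature of the classic contrastive loss (Example 1's variant also takes a per-pair weight). A sketch of the unweighted form; the Euclidean distance and fixed margin are assumptions:

import torch
import torch.nn.functional as F


def contrastive_loss(out1, out2, label, margin=1.0):
    # label == 1: pull the pair together (penalize the distance).
    # label == 0: push the pair apart until the margin is met.
    d = F.pairwise_distance(out1, out2)                        # [P]
    pos = label * d ** 2
    neg = (1 - label) * torch.clamp(margin - d, min=0.0) ** 2
    return (pos + neg).mean() / 2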
Example 7
def train_one(task, class_names, model, optG, optCLF, args, grad):
    '''
        Train the model on one sampled task.
    '''
    model['G'].train()
    model['clf'].train()

    support, query = task
    # print("support, query:", support, query)
    # print("class_names_dict:", class_names_dict)
    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()
    # print("sampled_classes:", sampled_classes)

    class_names_dict = {}
    class_names_dict['label'] = class_names['label'][sampled_classes]
    # print("class_names_dict['label']:", class_names_dict['label'])
    class_names_dict['text'] = class_names['text'][sampled_classes]
    class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
    class_names_dict['is_support'] = False
    class_names_dict = utils.to_tensor(class_names_dict,
                                       args.cuda,
                                       exclude_keys=['is_support'])

    # Embedding the document
    XS = model['G'](support)  # XS:[N*K, 256(hidden_size*2)]
    # print("XS:", XS.shape)
    YS = support['label']
    # print('YS:', YS)

    CN = model['G'](class_names_dict)  # CN:[N, 256(hidden_size*2)]
    # print("CN:", CN.shape)

    XQ = model['G'](query)
    YQ = query['label']
    # print('YQ:', YQ)

    YS, YQ = reidx_y(args, YS, YQ)

    for _ in range(args.train_iter):

        # Embedding the document
        XS_mlp = model['clf'](XS)  # [N*K, 256(hidden_size*2)] -> [N*K, 128]

        CN_mlp = model['clf'](CN)  # [N, 256(hidden_size*2)] -> [N, 128]

        neg_d = neg_dist(XS_mlp, CN_mlp)  # [N*K, N]
        # print("neg_d:", neg_d.shape)

        mlp_loss = model['clf'].loss(neg_d, YS)
        # print("mlp_loss:", mlp_loss)

        optCLF.zero_grad()
        mlp_loss.backward(retain_graph=True)
        optCLF.step()

    XQ_mlp = model['clf'](XQ)
    CN_mlp = model['clf'](CN)
    neg_d = neg_dist(XQ_mlp, CN_mlp)
    g_loss = model['clf'].loss(neg_d, YQ)

    optG.zero_grad()
    g_loss.backward()
    optG.step()

    _, pred = torch.max(neg_d, 1)
    acc_q = model['clf'].accuracy(pred, YQ)

    # YQ_d = torch.ones(query['label'].shape, dtype=torch.long).to(query['label'].device)
    # print('YQ', set(YQ.numpy()))

    # XSource, XSource_inputD, _ = model['G'](source)
    # YSource_d = torch.zeros(source['label'].shape, dtype=torch.long).to(source['label'].device)

    # XQ_logitsD = model['D'](XQ_inputD)
    # XSource_logitsD = model['D'](XSource_inputD)
    #
    # d_loss = F.cross_entropy(XQ_logitsD, YQ_d) + F.cross_entropy(XSource_logitsD, YSource_d)
    # d_loss.backward(retain_graph=True)
    # grad['D'].append(get_norm(model['D']))
    # optD.step()
    #
    # # *****************update G****************
    # optG.zero_grad()
    # XQ_logitsD = model['D'](XQ_inputD)
    # XSource_logitsD = model['D'](XSource_inputD)
    # d_loss = F.cross_entropy(XQ_logitsD, YQ_d) + F.cross_entropy(XSource_logitsD, YSource_d)
    #
    # acc, d_acc, loss, _ = model['clf'](XS, YS, XQ, YQ, XQ_logitsD, XSource_logitsD, YQ_d, YSource_d)
    #
    # g_loss = loss - d_loss
    # if args.ablation == "-DAN":
    #     g_loss = loss
    #     print("%%%%%%%%%%%%%%%%%%%This is ablation mode: -DAN%%%%%%%%%%%%%%%%%%%%%%%%%%")
    # g_loss.backward(retain_graph=True)
    # grad['G'].append(get_norm(model['G']))
    # grad['clf'].append(get_norm(model['clf']))
    # optG.step()

    return g_loss, acc_q
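None of the snippets show the episodic driver. In this style of few-shot pipeline each episode samples one task and calls train_one, with test_one run periodically on held-out classes; a sketch of such a loop around Example 7 (the sampler and its sample_task method are assumptions):

def run_episodes(sampler, class_names, model, optG, optCLF, args,
                 n_episodes=1000):
    # Hypothetical training driver for Example 7's train_one.
    grad = {'G': [], 'clf': []}
    for ep in range(n_episodes):
        task = sampler.sample_task()  # -> (support, query) dicts
        g_loss, acc_q = train_one(task, class_names, model, optG, optCLF,
                                  args, grad)
        if (ep + 1) % 100 == 0:
            print('episode {}: loss {:.4f}, acc {:.4f}'.format(
                ep + 1, g_loss.item(), float(acc_q)))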