Beispiel #1
0
    def get_epoch(self):
        """Yield `num_episodes` (support, query) task pairs for one epoch.

        Blocks on `done_queue` until the producer thread hands over the
        next episode, converts it to tensors, optionally attaches BERT
        embeddings and target-classifier weights, and tags each split.
        """
        for _episode in range(self.num_episodes):
            # Block until the sampler thread has produced the next task.
            support_task, query_task = self.done_queue.get()

            # Move every field except 'raw' onto torch tensors (and GPU).
            support_task = utils.to_tensor(support_task, self.args.cuda, ['raw'])
            query_task = utils.to_tensor(query_task, self.args.cuda, ['raw'])

            if 'bert_id' in support_task:
                # Attach contextual embeddings. The +2 presumably accounts
                # for the special tokens wrapped around each sentence —
                # TODO confirm against get_bert.
                support_task['ebd'] = self.get_bert(
                    support_task['bert_id'], support_task['text_len'] + 2)
                query_task['ebd'] = self.get_bert(
                    query_task['bert_id'], query_task['text_len'] + 2)

            if self.args.meta_w_target:
                if self.args.meta_target_entropy:
                    w = stats.get_w_target(
                        support_task, self.data['vocab_size'],
                        self.data['avg_ebd'], self.args.meta_w_target_lam)
                else:
                    # Ridge-regression approximation (faster).
                    w = stats.get_w_target_rr(
                        support_task, self.data['vocab_size'],
                        self.data['avg_ebd'], self.args.meta_w_target_lam)
                # Detach: the target weights act as constants downstream.
                support_task['w_target'] = w.detach()
                query_task['w_target'] = w.detach()

            support_task['is_support'] = True
            query_task['is_support'] = False

            yield support_task, query_task
Beispiel #2
0
    def __getitem__(self, item):
        """Sample `return_count` frames from video #`item`.

        Returns:
            (item, x, y, x_t, y_t) where `x`/`y` stack the last
            `return_count - 1` frame / landmark-render tensor pairs and
            `x_t`/`y_t` are the held-out first sampled pair.
        """
        video = self.videos[item]
        metadata_file = os.path.join(self.root_dir, video, 'metadata.pkl')

        # metadata.pkl holds one (filename, landmarks) entry per frame.
        with open(metadata_file, 'rb') as f_in:
            frame_list = pickle.load(f_in)

        frame_count = len(frame_list)
        remains = self.return_count
        sampled_frames = list()

        # If more frames are requested than the video has, reuse every
        # frame whole multiples of times and randomly fill the remainder.
        while remains > frame_count:
            sampled_frames.extend(range(frame_count))
            remains -= frame_count

        sampled_frames.extend(random.sample(range(frame_count), remains))

        # sanity check
        assert len(sampled_frames) == self.return_count

        x = list()
        y = list()

        for i in sampled_frames:
            f, landmarks = frame_list[i]
            # BUG FIX: copy the landmark array. The random-flip branch
            # below mutates it in place, and the same frame index can be
            # sampled more than once (see the while-loop above); without a
            # copy a repeated frame would see already-flipped coordinates.
            landmarks = landmarks.copy()
            full_path = os.path.join(self.root_dir, video, f)

            img = Image.open(full_path).convert('RGB')
            img = img.resize((self.image_size, self.image_size), Image.LANCZOS)

            if self.random_flip:
                indicator = random.random()
                if indicator > 0.5:
                    # Flip the image and mirror landmark x-coordinates.
                    img = img.transpose(Image.FLIP_LEFT_RIGHT)
                    landmarks[:, 0] = self.image_size - 1 - landmarks[:, 0]

            x.append(to_tensor(img, self.normalize))

            rendered = plot_landmarks(self.image_size, landmarks)
            y.append(to_tensor(rendered, self.normalize))

            # debug
            if self._debug:
                img.save('%d.jpg' % i)
                rendered.save('%d_lm.jpg' % i)
                debug_img = plot_landmarks(self.image_size,
                                           landmarks,
                                           original_image=img)
                debug_img.save('%d_bg.jpg' % i)

        # Hold out the first sampled pair as the target frame.
        x_t = x[0]
        y_t = y[0]

        x = torch.stack(x[1:])  # (return_count - 1) * c * h * w
        y = torch.stack(y[1:])

        return item, x, y, x_t, y_t
Beispiel #3
0
def pre_calculate(train_data, class_names, net, args):
    """Precompute class- and example-level sampling probabilities.

    Embeds each class name and each class's examples with `net.ebd`
    (mean-pooled over sequence length), then converts pairwise distances
    into softmax probabilities used when sampling classes/examples.

    Returns:
        prob_metrix: np.ndarray [C, C-1] — per-class probabilities over
            the *other* classes (row diagonal removed).
        example_prob_metrix: list of C np.ndarrays [1, N_c] — per-class
            probabilities over that class's examples.
    """
    with torch.no_grad():
        all_classes = np.unique(train_data['label'])

        # Probability matrix used when sampling classes.
        train_class_names = {}
        train_class_names['text'] = class_names['text'][all_classes]
        train_class_names['text_len'] = class_names['text_len'][all_classes]
        train_class_names['label'] = class_names['label'][all_classes]
        train_class_names = utils.to_tensor(train_class_names, args.cuda)
        train_class_names_ebd = net.ebd(train_class_names)  # [C, L, 300]
        # Mean-pool word embeddings over each (unpadded) sequence length.
        train_class_names_ebd = torch.sum(
            train_class_names_ebd, dim=1) / train_class_names['text_len'].view(
                (-1, 1))  # [C, 300]
        dist_metrix = -neg_dist(train_class_names_ebd,
                                train_class_names_ebd)  # [C, C]

        # Drop each row's diagonal entry (self-distance), then concatenate
        # once — avoids the O(C^2) repeated torch.cat of the original loop.
        rows = [del_tensor_ele(d, i).view((1, -1))
                for i, d in enumerate(dist_metrix)]
        dist_metrix_nodiag = torch.cat(rows, dim=0)  # [C, C-1]

        prob_metrix = F.softmax(dist_metrix_nodiag, dim=1)  # [C, C-1]
        prob_metrix = prob_metrix.cpu().numpy()

        # Probability matrix used when sampling examples within a class.
        example_prob_metrix = []
        for i, label in enumerate(all_classes):
            mask = train_data['label'] == label  # hoisted: reused 3 times
            train_examples = {}
            train_examples['text'] = train_data['text'][mask]
            train_examples['text_len'] = train_data['text_len'][mask]
            train_examples['label'] = train_data['label'][mask]
            train_examples = utils.to_tensor(train_examples, args.cuda)
            train_examples_ebd = net.ebd(train_examples)
            train_examples_ebd = torch.sum(
                train_examples_ebd, dim=1) / train_examples['text_len'].view(
                    (-1, 1))  # [N, 300]
            # Distance from the class-name embedding to each example.
            example_prob_metrix_one = -neg_dist(
                train_class_names_ebd[i].view((1, -1)), train_examples_ebd)
            example_prob_metrix_one = F.softmax(example_prob_metrix_one,
                                                dim=1)  # [1, N]
            example_prob_metrix_one = example_prob_metrix_one.cpu().numpy()
            example_prob_metrix.append(example_prob_metrix_one)

        return prob_metrix, example_prob_metrix
Beispiel #4
0
    def get_epoch(self):
        """Generate (support, query) pairs for one epoch of episodes.

        Blocks on `done_queue` until the worker thread finishes each
        episode, converts the raw data to tensors, and marks which
        split each dict belongs to.
        """
        for _episode in range(self.num_episodes):
            # Wait for the sampler thread to hand over the next task.
            episode = self.done_queue.get()
            support_set, query_set = episode

            # Tensor conversion, skipping the 'raw' field.
            support_set = utils.to_tensor(support_set, self.args.cuda, ['raw'])
            query_set = utils.to_tensor(query_set, self.args.cuda, ['raw'])

            support_set['is_support'] = True
            query_set['is_support'] = False

            yield support_set, query_set
Beispiel #5
0
def test_one(task, class_names, model, optCLF, args, grad):
    '''
        Adapt the classifier head on one sampled task and return the
        accuracy it reaches on the task's query set.
    '''
    support, query = task

    # Classes present in this episode's support set.
    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()

    # Gather the class-name inputs for exactly these classes.
    class_names_dict = {
        'label': class_names['label'][sampled_classes],
        'text': class_names['text'][sampled_classes],
        'text_len': class_names['text_len'][sampled_classes],
        'is_support': False,
    }
    class_names_dict = utils.to_tensor(
        class_names_dict, args.cuda, exclude_keys=['is_support'])

    # Encode documents and class names once; only the head adapts below.
    XS = model['G'](support)           # [N*K, hidden]
    YS = support['label']

    CN = model['G'](class_names_dict)  # [N, hidden]

    XQ = model['G'](query)
    YQ = query['label']

    # Re-index episode labels to 0..N-1.
    YS, YQ = reidx_y(args, YS, YQ)

    # Inner loop: fit the classifier head to the support set.
    for _step in range(args.test_iter):
        XS_mlp = model['clf'](XS)
        CN_mlp = model['clf'](CN)

        neg_d = neg_dist(XS_mlp, CN_mlp)  # [N*K, N]
        mlp_loss = model['clf'].loss(neg_d, YS)

        optCLF.zero_grad()
        mlp_loss.backward(retain_graph=True)
        optCLF.step()

    # Evaluate on the query set with the adapted head.
    XQ_mlp = model['clf'](XQ)
    CN_mlp = model['clf'](CN)
    neg_d = neg_dist(XQ_mlp, CN_mlp)

    pred = torch.max(neg_d, 1)[1]
    acc_q = model['clf'].accuracy(pred, YQ)

    return acc_q
Beispiel #6
0
def main(model_file):
    """Fine-tune the meta-trained few-shot talking-head model per person.

    For each person in `config.test_dataset`: sample `finetune_T` frames of
    that person plus one driving frame from a *different* person, compute
    the person embedding, project it into the generator's AdaIN parameters
    and the discriminator's per-person weight, fine-tune G and D on the
    sampled frames, and save before/after generations for comparison.

    Args:
        model_file: path to a checkpoint produced by meta-training.
    """
    run_id = datetime.now().strftime('%Y%m%d_%H%M_finetune')

    output_path = os.path.join('output', run_id)
    os.makedirs(output_path, exist_ok=True)  # idiomatic exists-or-create

    print('The ID of this run: ' + run_id)
    print('Output directory: ' + output_path)

    all_people = os.listdir(config.test_dataset)
    people_count = len(all_people)
    print('People count: %d' % people_count)

    for i, person in enumerate(all_people):
        print('Progress: %d/%d' % (i, people_count))

        # T training images should come from the same video
        xx = sample_frames(person, config.finetune_T)

        # Draw the driving target frame from a *different* person.
        person_t = person
        while person_t == person:
            (person_t, ) = random.sample(all_people, 1)

        xx_t = sample_frames(person_t, 1)

        # Index 0 is the target frame; the rest are fine-tuning frames.
        xx_all = xx_t + xx

        x = list()
        y = list()

        detector = get_detector('cuda')

        for filename in xx_all:
            img = Image.open(filename).convert('RGB')
            img = img.resize((config.input_size, config.input_size),
                             Image.LANCZOS)
            x.append(to_tensor(img, config.input_normalize))

            # Render the facial landmarks as the conditioning image.
            arr = np.array(img)
            landmarks = extract_landmark(detector, arr)

            rendered = plot_landmarks(config.input_size, landmarks)
            y.append(to_tensor(rendered, config.input_normalize))

        # Free the landmark detector before training starts.
        del detector
        torch.set_grad_enabled(True)

        x_t = torch.unsqueeze(x[0], dim=0)
        y_t = torch.unsqueeze(y[0], dim=0)
        y_t = y_t.cuda()

        x = torch.stack(x[1:])  # finetune_T * c * h * w
        y = torch.stack(y[1:])

        # sanity check
        assert x.size(0) == config.finetune_T

        # Load models. NOTE(review): torch.load unpickles the checkpoint —
        # only load files from a trusted source.
        save_data = torch.load(model_file)
        _, _, _, G_state_dict, E_state_dict, D_state_dict = save_data[:6]

        G = Generator(config.G_config, config.input_normalize)
        G = G.eval()
        G = G.cuda()

        E = Embedder(config.E_config, config.embedding_dim)
        E = E.eval()
        E = E.cuda()

        D = Discriminator(config.V_config, config.embedding_dim)
        D = D.eval()
        D = D.cuda()

        with torch.no_grad():
            # Person embedding: average the embedder output over frames.
            E.load_state_dict(E_state_dict)
            E_input = torch.cat((x, y), dim=1)
            E_input = E_input.cuda()
            e_hat = E(E_input)
            e_hat = e_hat.view(1, -1, config.embedding_dim)
            e_hat_mean = torch.mean(e_hat, dim=1, keepdim=False)
            del E

            # Project the embedding into the generator's AdaIN parameters
            # and store them directly in the state dict before loading.
            P = G_state_dict['P.weight']
            adain = torch.matmul(e_hat_mean, torch.transpose(P, 0, 1))
            del G_state_dict['P.weight']
            adain = adain.view(1, -1, 2)
            assert adain.size(1) == G.adain_param_count
            G_state_dict['adain'] = adain.data
            G.load_state_dict(G_state_dict)

            # Replace the discriminator's per-person embedding lookup with
            # a single fine-tunable weight w = w0 + e_hat.
            del D_state_dict['embedding.weight']
            w0 = D_state_dict['w0']
            w = w0 + e_hat_mean
            del D_state_dict['w0']
            D_state_dict['w'] = w.data
            D.load_state_dict(D_state_dict)

            # Generation *before* fine-tuning, kept for comparison.
            x_hat_0 = G(y_t)
            x_hat_0_img = to_pil_image(x_hat_0, config.input_normalize)
            del x_hat_0

        G = G.train()
        set_grad_enabled(G, True)
        D = D.train()
        set_grad_enabled(D, True)

        # Perceptual + adversarial + feature-matching loss; itself frozen.
        L_EG = Loss_EG_finetune(config.vgg19_layers, config.vggface_layers,
                                config.vgg19_weight_file,
                                config.vggface_weight_file,
                                config.vgg19_loss_weight,
                                config.vggface_loss_weight,
                                config.fm_loss_weight, config.input_normalize)

        L_EG = L_EG.eval()
        L_EG = L_EG.cuda()
        set_grad_enabled(L_EG, False)

        optim_EG = optim.Adam(G.parameters(),
                              lr=config.lr_EG,
                              betas=config.adam_betas)
        optim_D = optim.Adam(D.parameters(),
                             lr=config.lr_D,
                             betas=config.adam_betas)

        # dataset
        dataset = TensorDataset(x, y)
        dataloader = DataLoader(dataset,
                                batch_size=config.finetune_batch_size,
                                shuffle=config.dataset_shuffle,
                                num_workers=config.num_worker,
                                pin_memory=True,
                                drop_last=False)

        # finetune
        for epoch in range(config.finetune_epoch):
            for _, (xx, yy) in enumerate(dataloader):
                xx = xx.cuda()
                yy = yy.cuda()

                optim_EG.zero_grad()
                optim_D.zero_grad()

                x_hat = G(yy)

                d_output = D(torch.cat((xx, yy), dim=1))
                d_output_hat = D(torch.cat((x_hat, yy), dim=1))

                # Intermediate activations feed the feature-matching loss;
                # the last element is the realism score.
                d_features = d_output[:-1]
                d_features_hat = d_output_hat[:-1]
                d_score = d_output[-1]
                d_score_hat = d_output_hat[-1]

                l_eg, l_vgg19, l_vggface, l_cnt, l_adv, l_fm = \
                    L_EG(xx, x_hat, d_features, d_features_hat, d_score_hat)

                l_d = Loss_DSC(d_score_hat, d_score)
                loss = l_eg + l_d
                loss.backward()
                optim_EG.step()
                optim_D.step()

                # train D again
                optim_D.zero_grad()
                x_hat = x_hat.detach()  # do not need to train the generator

                d_output = D(torch.cat((xx, yy), dim=1))
                d_output_hat = D(torch.cat((x_hat, yy), dim=1))

                d_score = d_output[-1]
                d_score_hat = d_output_hat[-1]

                l_d2 = Loss_DSC(d_score_hat, d_score)
                l_d2.backward()
                optim_D.step()

        # Generation *after* fine-tuning.
        with torch.no_grad():
            x_hat_1 = G(y_t)
            x_hat_1_img = to_pil_image(x_hat_1, config.input_normalize)
            del x_hat_1

        # Save the fine-tuning frames side by side.
        training_img = Image.new(
            'RGB', (config.finetune_T * config.input_size, config.input_size))
        # BUG FIX: iterate over finetune_T frames — `x` holds exactly
        # finetune_T tensors (asserted above) and the canvas is finetune_T
        # tiles wide; the original used config.metatrain_T, which breaks
        # whenever the two settings differ.
        for j in range(config.finetune_T):
            img = to_pil_image(x[j], config.input_normalize)
            training_img.paste(img, (j * config.input_size, 0))

        training_img.save(os.path.join(output_path, 't_%d.jpg' % i))

        x_t_img = to_pil_image(x_t, config.input_normalize)
        y_t_img = to_pil_image(y_t, config.input_normalize)

        # Layout: before | after | ground truth | landmark render.
        output_img = Image.new('RGB',
                               (4 * config.input_size, config.input_size))
        output_img.paste(x_hat_0_img, (0, 0))
        output_img.paste(x_hat_1_img, (config.input_size, 0))
        output_img.paste(x_t_img, (2 * config.input_size, 0))
        output_img.paste(y_t_img, (3 * config.input_size, 0))

        output_img.save(os.path.join(output_path, 'o_%d.jpg' % i))
Beispiel #7
0
def test_one(task, class_names, model, optG, criterion, args, grad):
    '''
        Train the model on one sampled task.

        MAML-style adaptation: builds every (support-text, class-name) pair,
        takes `args.test_iter` fast-weight gradient steps on the pairwise
        `criterion`, then classifies the query set with the adapted weights
        and returns the query accuracy.

        NOTE(review): `optG` and `grad` are accepted but never used here —
        presumably kept for signature compatibility with the caller.
    '''
    model['G'].eval()

    support, query = task
    # print("support, query:", support, query)
    # print("class_names_dict:", class_names_dict)
    '''分样本对'''
    # (above: "split into sample pairs")
    YS = support['label']
    YQ = query['label']

    # Classes appearing in this episode's support set.
    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()
    # print("sampled_classes:", sampled_classes)

    # Class-name inputs restricted to the sampled classes.
    class_names_dict = {}
    class_names_dict['label'] = class_names['label'][sampled_classes]
    # print("class_names_dict['label']:", class_names_dict['label'])
    class_names_dict['text'] = class_names['text'][sampled_classes]
    class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
    class_names_dict['is_support'] = False
    class_names_dict = utils.to_tensor(class_names_dict,
                                       args.cuda,
                                       exclude_keys=['is_support'])

    # Re-index episode labels to 0..N-1.
    YS, YQ = reidx_y(args, YS, YQ)
    # print('YS:', support['label'])
    # print('YQ:', query['label'])
    # print("class_names_dict:", class_names_dict['label'])
    """维度填充"""
    # (above: "dimension padding") Zero-pad whichever text matrix is
    # shorter so support texts and class-name texts share one seq length.
    if support['text'].shape[1] > class_names_dict['text'].shape[1]:
        zero = torch.zeros(
            (class_names_dict['text'].shape[0],
             support['text'].shape[1] - class_names_dict['text'].shape[1]),
            dtype=torch.long)
        class_names_dict['text'] = torch.cat(
            (class_names_dict['text'], zero.cuda()), dim=-1)
    elif support['text'].shape[1] < class_names_dict['text'].shape[1]:
        zero = torch.zeros(
            (support['text'].shape[0],
             class_names_dict['text'].shape[1] - support['text'].shape[1]),
            dtype=torch.long)
        support['text'] = torch.cat((support['text'], zero.cuda()), dim=-1)

    # Append the class-name rows to the support set so they form pairs too.
    support['text'] = torch.cat((support['text'], class_names_dict['text']),
                                dim=0)
    support['text_len'] = torch.cat(
        (support['text_len'], class_names_dict['text_len']), dim=0)
    support['label'] = torch.cat((support['label'], class_names_dict['label']),
                                 dim=0)
    # print("support['text']:", support['text'].shape)
    # print("support['label']:", support['label'])

    # Build the LEFT side of each pair (text_1): every support row repeated
    # once per sampled class (rows in outer order, classes in inner order).
    text_sample_len = support['text'].shape[0]
    # print("support['text'].shape[0]:", support['text'].shape[0])
    support['text_1'] = support['text'][0].view((1, -1))
    support['text_len_1'] = support['text_len'][0].view(-1)
    support['label_1'] = support['label'][0].view(-1)
    for i in range(text_sample_len):
        if i == 0:
            # Row 0 was already inserted above, so add only N-1 more copies.
            for j in range(1, len(sampled_classes)):
                support['text_1'] = torch.cat(
                    (support['text_1'], support['text'][i].view((1, -1))),
                    dim=0)
                support['text_len_1'] = torch.cat(
                    (support['text_len_1'], support['text_len'][i].view(-1)),
                    dim=0)
                support['label_1'] = torch.cat(
                    (support['label_1'], support['label'][i].view(-1)), dim=0)
        else:
            for j in range(len(sampled_classes)):
                support['text_1'] = torch.cat(
                    (support['text_1'], support['text'][i].view((1, -1))),
                    dim=0)
                support['text_len_1'] = torch.cat(
                    (support['text_len_1'], support['text_len'][i].view(-1)),
                    dim=0)
                support['label_1'] = torch.cat(
                    (support['label_1'], support['label'][i].view(-1)), dim=0)

    # Build the RIGHT side of each pair (text_2): the class-name rows tiled
    # so each support row above is matched against every class name.
    support['text_2'] = class_names_dict['text'][0].view((1, -1))
    support['text_len_2'] = class_names_dict['text_len'][0].view(-1)
    support['label_2'] = class_names_dict['label'][0].view(-1)
    for i in range(text_sample_len):
        if i == 0:
            # Entry 0 was already inserted above; add classes 1..N-1.
            for j in range(1, len(sampled_classes)):
                support['text_2'] = torch.cat(
                    (support['text_2'], class_names_dict['text'][j].view(
                        (1, -1))),
                    dim=0)
                support['text_len_2'] = torch.cat(
                    (support['text_len_2'],
                     class_names_dict['text_len'][j].view(-1)),
                    dim=0)
                support['label_2'] = torch.cat(
                    (support['label_2'],
                     class_names_dict['label'][j].view(-1)),
                    dim=0)
        else:
            for j in range(len(sampled_classes)):
                support['text_2'] = torch.cat(
                    (support['text_2'], class_names_dict['text'][j].view(
                        (1, -1))),
                    dim=0)
                support['text_len_2'] = torch.cat(
                    (support['text_len_2'],
                     class_names_dict['text_len'][j].view(-1)),
                    dim=0)
                support['label_2'] = torch.cat(
                    (support['label_2'],
                     class_names_dict['label'][j].view(-1)),
                    dim=0)

    # print("support['text_1']:", support['text_1'].shape, support['text_len_1'].shape, support['label_1'].shape)
    # print("support['text_2']:", support['text_2'].shape, support['text_len_2'].shape, support['label_2'].shape)
    # Siamese target: 1 where the paired texts share a label, else 0.
    support['label_final'] = support['label_1'].eq(support['label_2']).int()

    # Repackage the two pair sides as standalone model inputs.
    support_1 = {}
    support_1['text'] = support['text_1']
    support_1['text_len'] = support['text_len_1']
    support_1['label'] = support['label_1']

    support_2 = {}
    support_2['text'] = support['text_2']
    support_2['text_len'] = support['text_len_2']
    support_2['label'] = support['label_2']
    # print("**************************************")
    # print("1111111", support['label_1'])
    # print("2222222", support['label_2'])
    # print(support['label_final'])
    '''first step'''
    # First adaptation step, taken on the model's original parameters.
    S_out1, S_out2 = model['G'](support_1, support_2)

    supp_, que_ = model['G'](support, query)
    loss_weight = get_weight_of_test_support(supp_, que_, args)

    loss = criterion(S_out1, S_out2, support['label_final'], loss_weight)
    # print("s_1_loss:", loss)
    zero_grad(model['G'].parameters())

    # For each adapted sub-module, compute gradients and build the fast
    # weights theta' = theta - task_lr * grad (MAML inner update).
    grads_fc = autograd.grad(loss,
                             model['G'].fc.parameters(),
                             allow_unused=True,
                             retain_graph=True)
    fast_weights_fc, orderd_params_fc = model['G'].cloned_fc_dict(
    ), OrderedDict()
    for (key, val), grad in zip(model['G'].fc.named_parameters(), grads_fc):
        fast_weights_fc[key] = orderd_params_fc[
            key] = val - args.task_lr * grad

    grads_conv11 = autograd.grad(loss,
                                 model['G'].conv11.parameters(),
                                 allow_unused=True,
                                 retain_graph=True)
    fast_weights_conv11, orderd_params_conv11 = model['G'].cloned_conv11_dict(
    ), OrderedDict()
    for (key, val), grad in zip(model['G'].conv11.named_parameters(),
                                grads_conv11):
        fast_weights_conv11[key] = orderd_params_conv11[
            key] = val - args.task_lr * grad

    grads_conv12 = autograd.grad(loss,
                                 model['G'].conv12.parameters(),
                                 allow_unused=True,
                                 retain_graph=True)
    fast_weights_conv12, orderd_params_conv12 = model['G'].cloned_conv12_dict(
    ), OrderedDict()
    for (key, val), grad in zip(model['G'].conv12.named_parameters(),
                                grads_conv12):
        fast_weights_conv12[key] = orderd_params_conv12[
            key] = val - args.task_lr * grad

    # Last grad call: no retain_graph, so the graph is freed afterwards.
    grads_conv13 = autograd.grad(loss,
                                 model['G'].conv13.parameters(),
                                 allow_unused=True)
    fast_weights_conv13, orderd_params_conv13 = model['G'].cloned_conv13_dict(
    ), OrderedDict()
    for (key, val), grad in zip(model['G'].conv13.named_parameters(),
                                grads_conv13):
        fast_weights_conv13[key] = orderd_params_conv13[
            key] = val - args.task_lr * grad

    fast_weights = {}
    fast_weights['fc'] = fast_weights_fc
    fast_weights['conv11'] = fast_weights_conv11
    fast_weights['conv12'] = fast_weights_conv12
    fast_weights['conv13'] = fast_weights_conv13
    '''steps remaining'''
    # Remaining inner-loop steps, now forwarding with the fast weights and
    # differentiating through the previously adapted parameters.
    for k in range(args.test_iter - 1):
        S_out1, S_out2 = model['G'](support_1, support_2, fast_weights)

        supp_, que_ = model['G'](support, query, fast_weights)
        loss_weight = get_weight_of_test_support(supp_, que_, args)

        loss = criterion(S_out1, S_out2, support['label_final'], loss_weight)
        # print("train_iter: {} s_loss:{}".format(k, loss))
        zero_grad(orderd_params_fc.values())
        zero_grad(orderd_params_conv11.values())
        zero_grad(orderd_params_conv12.values())
        zero_grad(orderd_params_conv13.values())
        grads_fc = torch.autograd.grad(loss,
                                       orderd_params_fc.values(),
                                       allow_unused=True,
                                       retain_graph=True)
        grads_conv11 = torch.autograd.grad(loss,
                                           orderd_params_conv11.values(),
                                           allow_unused=True,
                                           retain_graph=True)
        grads_conv12 = torch.autograd.grad(loss,
                                           orderd_params_conv12.values(),
                                           allow_unused=True,
                                           retain_graph=True)
        grads_conv13 = torch.autograd.grad(loss,
                                           orderd_params_conv13.values(),
                                           allow_unused=True)

        # allow_unused=True can yield None grads: skip those parameters.
        for (key, val), grad in zip(orderd_params_fc.items(), grads_fc):
            if grad is not None:
                fast_weights['fc'][key] = orderd_params_fc[
                    key] = val - args.task_lr * grad

        for (key, val), grad in zip(orderd_params_conv11.items(),
                                    grads_conv11):
            if grad is not None:
                fast_weights['conv11'][key] = orderd_params_conv11[
                    key] = val - args.task_lr * grad

        for (key, val), grad in zip(orderd_params_conv12.items(),
                                    grads_conv12):
            if grad is not None:
                fast_weights['conv12'][key] = orderd_params_conv12[
                    key] = val - args.task_lr * grad

        for (key, val), grad in zip(orderd_params_conv13.items(),
                                    grads_conv13):
            if grad is not None:
                fast_weights['conv13'][key] = orderd_params_conv13[
                    key] = val - args.task_lr * grad
    """计算Q上的损失"""
    # (above: "compute the result on Q") Classify the query set with the
    # adapted weights: nearest class-name embedding by (scaled) distance.
    CN = model['G'].forward_once_with_param(class_names_dict, fast_weights)
    XQ = model['G'].forward_once_with_param(query, fast_weights)
    logits_q = pos_dist(XQ, CN)
    logits_q = dis_to_level(logits_q)
    _, pred = torch.max(logits_q, 1)
    acc_q = model['G'].accuracy(pred, YQ)

    return acc_q
Beispiel #8
0
def test(test_data,
         class_names,
         optG,
         optCLF,
         model,
         args,
         num_episodes,
         verbose=True):
    '''
        Evaluate the model on a bag of sampled tasks. Return the mean accuracy
        and its std.
    '''
    model['G'].train()
    model['G2'].train()
    model['clf'].train()

    acc = []
    for ep in range(num_episodes):
        # Sample an N-way episode and look up its class-name inputs.
        sampled_classes, source_classes = task_sampler(test_data, args)
        class_names_dict = {}
        class_names_dict['label'] = class_names['label'][sampled_classes]
        class_names_dict['text'] = class_names['text'][sampled_classes]
        class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
        class_names_dict['is_support'] = False

        train_gen = ParallelSampler(test_data, args, sampled_classes,
                                    source_classes, args.train_episodes)

        sampled_tasks = train_gen.get_epoch()
        class_names_dict = utils.to_tensor(class_names_dict,
                                           args.cuda,
                                           exclude_keys=['is_support'])

        grad = {'clf': [], 'G': []}

        if not args.notqdm:
            sampled_tasks = tqdm(sampled_tasks,
                                 total=train_gen.num_episodes,
                                 ncols=80,
                                 leave=False,
                                 desc=colored('Training on train', 'yellow'))

        # A None task signals the sampler is exhausted for this episode.
        for task in sampled_tasks:
            if task is None:
                break
            q_acc = test_one(task, class_names_dict, model, optG, optCLF, args,
                             grad)
            acc.append(q_acc.cpu().item())

    acc = np.array(acc)

    if verbose:
        # FIX: the original branched on args.embedding != 'mlada' but both
        # branches printed the identical message — collapsed to one print.
        print("{}, {:s} {:>7.4f}, {:s} {:>7.4f}".format(
            datetime.datetime.now(),
            colored("test acc mean", "blue"),
            np.mean(acc),
            colored("test std", "blue"),
            np.std(acc),
        ),
              flush=True)

    return np.mean(acc), np.std(acc)
Beispiel #9
0
def train(train_data, val_data, model, class_names, args):
    '''
        Train the model
        Use val_data to do early stopping
    '''
    # creating a tmp directory to save the models
    out_dir = os.path.abspath(
        os.path.join(os.path.curdir, "tmp-runs", str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    optG = torch.optim.Adam(grad_param(model, ['G']), lr=args.meta_lr)
    optG2 = torch.optim.Adam(grad_param(model, ['G2']), lr=args.task_lr)
    optCLF = torch.optim.Adam(grad_param(model, ['clf']), lr=args.task_lr)

    if args.lr_scheduler == 'ReduceLROnPlateau':
        schedulerG = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optG, 'max', patience=args.patience // 2, factor=0.1, verbose=True)
        schedulerCLF = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optCLF,
            'max',
            patience=args.patience // 2,
            factor=0.1,
            verbose=True)

    elif args.lr_scheduler == 'ExponentialLR':
        schedulerG = torch.optim.lr_scheduler.ExponentialLR(
            optG, gamma=args.ExponentialLR_gamma)
        schedulerCLF = torch.optim.lr_scheduler.ExponentialLR(
            optCLF, gamma=args.ExponentialLR_gamma)

    print("{}, Start training".format(datetime.datetime.now()), flush=True)

    # train_gen = ParallelSampler(train_data, args, args.train_episodes)
    # train_gen_val = ParallelSampler_Test(train_data, args, args.val_episodes)
    # val_gen = ParallelSampler_Test(val_data, args, args.val_episodes)

    # sampled_classes, source_classes = task_sampler(train_data, args)
    acc = 0
    loss = 0
    for ep in range(args.train_epochs):

        sampled_classes, source_classes = task_sampler(train_data, args)
        class_names_dict = {}
        class_names_dict['label'] = class_names['label'][sampled_classes]
        class_names_dict['text'] = class_names['text'][sampled_classes]
        class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
        class_names_dict['is_support'] = False

        train_gen = ParallelSampler(train_data, args, sampled_classes,
                                    source_classes, args.train_episodes)

        sampled_tasks = train_gen.get_epoch()
        class_names_dict = utils.to_tensor(class_names_dict,
                                           args.cuda,
                                           exclude_keys=['is_support'])

        grad = {'clf': [], 'G': []}

        if not args.notqdm:
            sampled_tasks = tqdm(sampled_tasks,
                                 total=train_gen.num_episodes,
                                 ncols=80,
                                 leave=False,
                                 desc=colored('Training on train', 'yellow'))

        for task in sampled_tasks:
            if task is None:
                break
            q_loss, q_acc = train_one(task, class_names_dict, model, optG,
                                      optG2, optCLF, args, grad)
            acc += q_acc
            loss += q_loss

        if ep % 100 == 0:
            print("--------[TRAIN] ep:" + str(ep) + ", loss:" +
                  str(q_loss.item()) + ", acc:" + str(q_acc.item()) +
                  "-----------")

        if (ep % 200 == 0) and (ep != 0):
            acc = acc / args.train_episodes / 200
            loss = loss / args.train_episodes / 200
            print("--------[TRAIN] ep:" + str(ep) + ", mean_loss:" +
                  str(loss.item()) + ", mean_acc:" + str(acc.item()) +
                  "-----------")

            net = copy.deepcopy(model)
            acc, std = test(train_data, class_names, optG, optCLF, net, args,
                            args.test_epochs, False)
            print(
                "[TRAIN] {}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} ".format(
                    datetime.datetime.now(),
                    "ep",
                    ep,
                    colored("train", "red"),
                    colored("acc:", "blue"),
                    acc,
                    std,
                ),
                flush=True)
            acc = 0
            loss = 0

            # Evaluate validation accuracy
            cur_acc, cur_std = test(val_data, class_names, optG, optCLF, net,
                                    args, args.test_epochs, False)
            print(("[EVAL] {}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}, "
                   "{:s} {:s}{:>7.4f}, {:s}{:>7.4f}").format(
                       datetime.datetime.now(),
                       "ep",
                       ep,
                       colored("val  ", "cyan"),
                       colored("acc:", "blue"),
                       cur_acc,
                       cur_std,
                       colored("train stats", "cyan"),
                       colored("G_grad:", "blue"),
                       np.mean(np.array(grad['G'])),
                       colored("clf_grad:", "blue"),
                       np.mean(np.array(grad['clf'])),
                   ),
                  flush=True)

            # Update the current best model if val acc is better
            if cur_acc > best_acc:
                best_acc = cur_acc
                best_path = os.path.join(out_dir, str(ep))

                # save current model
                print("{}, Save cur best model to {}".format(
                    datetime.datetime.now(), best_path))

                torch.save(model['G'].state_dict(), best_path + '.G')
                torch.save(model['G2'].state_dict(), best_path + '.G2')
                torch.save(model['clf'].state_dict(), best_path + '.clf')

                sub_cycle = 0
            else:
                sub_cycle += 1

            # Break if the val acc hasn't improved in the past patience epochs
            if sub_cycle == args.patience:
                break

            if args.lr_scheduler == 'ReduceLROnPlateau':
                schedulerG.step(cur_acc)
                schedulerCLF.step(cur_acc)

            elif args.lr_scheduler == 'ExponentialLR':
                schedulerG.step()
                schedulerCLF.step()

    print("{}, End of training. Restore the best weights".format(
        datetime.datetime.now()),
          flush=True)

    # restore the best saved model
    model['G'].load_state_dict(torch.load(best_path + '.G'))
    model['G2'].load_state_dict(torch.load(best_path + '.G2'))
    model['clf'].load_state_dict(torch.load(best_path + '.clf'))

    if args.save:
        # save the current model
        out_dir = os.path.abspath(
            os.path.join(os.path.curdir, "saved-runs",
                         str(int(time.time() * 1e7))))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        best_path = os.path.join(out_dir, 'best')

        print("{}, Save best model to {}".format(datetime.datetime.now(),
                                                 best_path),
              flush=True)

        torch.save(model['G'].state_dict(), best_path + '.G')
        torch.save(model['clf'].state_dict(), best_path + '.clf')

        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))

    return optG, optCLF
def test_one(task, class_names, model, optG, criterion, args, grad):
    '''
        Evaluate the model on one sampled task.

        Adapts a copy of model['G']'s fc weights on support sentence pairs
        with `criterion` (MAML-style inner loop), then classifies the query
        set against the class-name embeddings computed with the adapted
        weights and returns the query accuracy.

        Args:
            task: (support, query) dicts with 'text', 'text_len', 'label'.
            class_names: dict with per-class 'label', 'text', 'text_len'.
            model: dict of modules; only model['G'] is used here.
            optG: unused (kept for a uniform call signature).
            criterion: pair loss over the two encoder outputs.
            args: hyper-parameters (task_lr, train_iter, cuda, ...).
            grad: unused (kept for a uniform call signature).

        Returns:
            acc_q: query-set accuracy (tensor scalar).
    '''

    support, query = task
    # Split into sample pairs
    YS = support['label']
    YQ = query['label']

    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()

    # Select the class-name entries for the classes present in this task.
    class_names_dict = {}
    class_names_dict['label'] = class_names['label'][sampled_classes]
    class_names_dict['text'] = class_names['text'][sampled_classes]
    class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
    class_names_dict['is_support'] = False
    class_names_dict = utils.to_tensor(class_names_dict,
                                       args.cuda,
                                       exclude_keys=['is_support'])

    YS, YQ = reidx_y(args, YS, YQ)

    # Dimension padding: right-pad the class-name texts with zeros so they
    # match the support texts' sequence length and can be concatenated.
    if support['text'].shape[1] != class_names_dict['text'].shape[1]:
        zero = torch.zeros(
            (class_names_dict['text'].shape[0],
             support['text'].shape[1] - class_names_dict['text'].shape[1]),
            dtype=torch.long)
        class_names_dict['text'] = torch.cat(
            (class_names_dict['text'], zero.cuda()), dim=-1)

    # Append the class-name texts to the support set.
    support['text'] = torch.cat((support['text'], class_names_dict['text']),
                                dim=0)
    support['text_len'] = torch.cat(
        (support['text_len'], class_names_dict['text_len']), dim=0)
    support['label'] = torch.cat((support['label'], class_names_dict['label']),
                                 dim=0)

    # Build the full Cartesian product of support pairs (i, j). The
    # original implementation did this with O(n^2) torch.cat calls;
    # repeat_interleave / repeat yield the identical ordering (side 1:
    # each row repeated n times in a row; side 2: the whole set tiled
    # n times) in one vectorized op each.
    text_sample_len = support['text'].shape[0]
    support['text_1'] = support['text'].repeat_interleave(
        text_sample_len, dim=0)
    support['text_len_1'] = support['text_len'].repeat_interleave(
        text_sample_len)
    support['label_1'] = support['label'].repeat_interleave(text_sample_len)

    support['text_2'] = support['text'].repeat(text_sample_len, 1)
    support['text_len_2'] = support['text_len'].repeat(text_sample_len)
    support['label_2'] = support['label'].repeat(text_sample_len)

    # Pair target: 1 if the two sentences share a label, else 0.
    support['label_final'] = support['label_1'].eq(support['label_2']).int()

    support_1 = {
        'text': support['text_1'],
        'text_len': support['text_len_1'],
        'label': support['label_1'],
    }
    support_2 = {
        'text': support['text_2'],
        'text_len': support['text_len_2'],
        'label': support['label_2'],
    }

    # First inner-loop step: gradients w.r.t. the fc layer only.
    S_out1, S_out2 = model['G'](support_1, support_2)
    loss = criterion(S_out1, S_out2, support['label_final'])
    zero_grad(model['G'].parameters())
    grads = torch.autograd.grad(loss,
                                model['G'].fc.parameters(),
                                allow_unused=True)
    fast_weights, orderd_params = model['G'].cloned_fc_dict(), OrderedDict()
    # `g` (not `grad`) to avoid shadowing the `grad` parameter.
    for (key, val), g in zip(model['G'].fc.named_parameters(), grads):
        fast_weights[key] = orderd_params[key] = val - args.task_lr * g

    # Remaining inner-loop steps, taken on the fast weights.
    for k in range(args.train_iter - 1):
        S_out1, S_out2 = model['G'](support_1, support_2, fast_weights)
        loss = criterion(S_out1, S_out2, support['label_final'])
        zero_grad(orderd_params.values())
        grads = torch.autograd.grad(loss,
                                    orderd_params.values(),
                                    allow_unused=True)
        for (key, val), g in zip(orderd_params.items(), grads):
            if g is not None:
                fast_weights[key] = orderd_params[
                    key] = val - args.task_lr * g

    # Classify the query set by negative distance to the class-name
    # embeddings computed with the adapted (fast) weights.
    CN = model['G'].forward_once_with_param(class_names_dict, fast_weights)
    XQ = model['G'].forward_once_with_param(query, fast_weights)
    logits_q = neg_dist(XQ, CN)
    _, pred = torch.max(logits_q, 1)
    acc_q = model['G'].accuracy(pred, YQ)

    return acc_q
def train_one(task, class_names, model, optG, criterion, args, grad):
    '''
        Train the model on one sampled task.

        MAML-style episode: adapts a copy of model['G']'s fc weights on
        support sentence pairs with `criterion` (inner loop), computes the
        query loss against class-name embeddings under the adapted weights,
        and takes one outer-loop step on optG.

        Args:
            task: (support, query) dicts with 'text', 'text_len', 'label'.
            class_names: dict with per-class 'label', 'text', 'text_len'.
            model: dict of modules; only model['G'] is trained here.
            optG: outer-loop optimizer for model['G'].
            criterion: pair loss over the two encoder outputs.
            args: hyper-parameters (task_lr, train_iter, cuda, ...).
            grad: unused here (kept for a uniform call signature).

        Returns:
            (q_loss, acc_q): query loss and accuracy (tensor scalars).
    '''
    model['G'].train()

    support, query = task
    # Split into sample pairs
    YS = support['label']
    YQ = query['label']

    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()

    # Select the class-name entries for the classes present in this task.
    class_names_dict = {}
    class_names_dict['label'] = class_names['label'][sampled_classes]
    class_names_dict['text'] = class_names['text'][sampled_classes]
    class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
    class_names_dict['is_support'] = False
    class_names_dict = utils.to_tensor(class_names_dict,
                                       args.cuda,
                                       exclude_keys=['is_support'])

    YS, YQ = reidx_y(args, YS, YQ)

    # Dimension padding: right-pad the class-name texts with zeros so they
    # match the support texts' sequence length and can be concatenated.
    if support['text'].shape[1] != class_names_dict['text'].shape[1]:
        zero = torch.zeros(
            (class_names_dict['text'].shape[0],
             support['text'].shape[1] - class_names_dict['text'].shape[1]),
            dtype=torch.long)
        class_names_dict['text'] = torch.cat(
            (class_names_dict['text'], zero.cuda()), dim=-1)

    # Append the class-name texts to the support set.
    support['text'] = torch.cat((support['text'], class_names_dict['text']),
                                dim=0)
    support['text_len'] = torch.cat(
        (support['text_len'], class_names_dict['text_len']), dim=0)
    support['label'] = torch.cat((support['label'], class_names_dict['label']),
                                 dim=0)

    # Build the full Cartesian product of support pairs (i, j). The
    # original implementation did this with O(n^2) torch.cat calls;
    # repeat_interleave / repeat yield the identical ordering (side 1:
    # each row repeated n times in a row; side 2: the whole set tiled
    # n times) in one vectorized op each.
    text_sample_len = support['text'].shape[0]
    support['text_1'] = support['text'].repeat_interleave(
        text_sample_len, dim=0)
    support['text_len_1'] = support['text_len'].repeat_interleave(
        text_sample_len)
    support['label_1'] = support['label'].repeat_interleave(text_sample_len)

    support['text_2'] = support['text'].repeat(text_sample_len, 1)
    support['text_len_2'] = support['text_len'].repeat(text_sample_len)
    support['label_2'] = support['label'].repeat(text_sample_len)

    # Pair target: 1 if the two sentences share a label, else 0.
    support['label_final'] = support['label_1'].eq(support['label_2']).int()

    support_1 = {
        'text': support['text_1'],
        'text_len': support['text_len_1'],
        'label': support['label_1'],
    }
    support_2 = {
        'text': support['text_2'],
        'text_len': support['text_len_2'],
        'label': support['label_2'],
    }

    # First inner-loop step: gradients w.r.t. the fc layer only.
    S_out1, S_out2 = model['G'](support_1, support_2)
    loss = criterion(S_out1, S_out2, support['label_final'])
    zero_grad(model['G'].parameters())
    grads = torch.autograd.grad(loss,
                                model['G'].fc.parameters(),
                                allow_unused=True)
    fast_weights, orderd_params = model['G'].cloned_fc_dict(), OrderedDict()
    # `g` (not `grad`) to avoid shadowing the `grad` parameter.
    for (key, val), g in zip(model['G'].fc.named_parameters(), grads):
        fast_weights[key] = orderd_params[key] = val - args.task_lr * g

    # Remaining inner-loop steps, taken on the fast weights.
    for k in range(args.train_iter - 1):
        S_out1, S_out2 = model['G'](support_1, support_2, fast_weights)
        loss = criterion(S_out1, S_out2, support['label_final'])
        zero_grad(orderd_params.values())
        grads = torch.autograd.grad(loss,
                                    orderd_params.values(),
                                    allow_unused=True)
        for (key, val), g in zip(orderd_params.items(), grads):
            if g is not None:
                fast_weights[key] = orderd_params[
                    key] = val - args.task_lr * g

    # Query loss under the adapted (fast) weights: classify by negative
    # distance to the class-name embeddings.
    CN = model['G'].forward_once_with_param(class_names_dict, fast_weights)
    XQ = model['G'].forward_once_with_param(query, fast_weights)
    logits_q = neg_dist(XQ, CN)
    q_loss = model['G'].loss(logits_q, YQ)
    _, pred = torch.max(logits_q, 1)
    acc_q = model['G'].accuracy(pred, YQ)

    # Outer-loop update on the original (slow) weights.
    optG.zero_grad()
    q_loss.backward()
    optG.step()

    return q_loss, acc_q
Beispiel #12
0
def train_one(task, class_names, model, optG, optCLF, args, grad):
    '''
        Train the model on one sampled task.

        Encodes support, query and class-name texts with model['G'], fits
        the classifier head model['clf'] on the support set for
        args.train_iter steps (optCLF), then updates the encoder with the
        query loss (optG).

        Args:
            task: (support, query) dicts with 'text', 'text_len', 'label'.
            class_names: dict with per-class 'label', 'text', 'text_len'.
            model: dict of modules with keys 'G' and 'clf'.
            optG: optimizer for the encoder model['G'].
            optCLF: optimizer for the classifier head model['clf'].
            args: hyper-parameters (train_iter, cuda, ...).
            grad: unused here (kept for a uniform call signature).

        Returns:
            (g_loss, acc_q): query loss and accuracy (tensor scalars).
    '''
    model['G'].train()
    model['clf'].train()

    support, query = task
    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()

    # Select the class-name entries for the classes present in this task.
    class_names_dict = {}
    class_names_dict['label'] = class_names['label'][sampled_classes]
    class_names_dict['text'] = class_names['text'][sampled_classes]
    class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
    class_names_dict['is_support'] = False
    class_names_dict = utils.to_tensor(class_names_dict,
                                       args.cuda,
                                       exclude_keys=['is_support'])

    # Embed documents and class names once; the embeddings are reused
    # across all classifier-head steps below.
    XS = model['G'](support)  # [N*K, hidden_size*2]
    YS = support['label']

    CN = model['G'](class_names_dict)  # [N, hidden_size*2]

    XQ = model['G'](query)
    YQ = query['label']

    YS, YQ = reidx_y(args, YS, YQ)

    # Fit the classifier head on the support set.
    for _ in range(args.train_iter):

        XS_mlp = model['clf'](XS)  # [N*K, hidden_size*2] -> [N*K, 128]

        CN_mlp = model['clf'](CN)  # [N, hidden_size*2] -> [N, 128]

        neg_d = neg_dist(XS_mlp, CN_mlp)  # [N*K, N]

        mlp_loss = model['clf'].loss(neg_d, YS)

        optCLF.zero_grad()
        # retain_graph: XS/CN come from model['G']'s graph and are reused
        # on every iteration and for the query loss below.
        mlp_loss.backward(retain_graph=True)
        optCLF.step()

    # Update the encoder with the query loss against the fitted head.
    XQ_mlp = model['clf'](XQ)
    CN_mlp = model['clf'](CN)
    neg_d = neg_dist(XQ_mlp, CN_mlp)
    g_loss = model['clf'].loss(neg_d, YQ)

    optG.zero_grad()
    g_loss.backward()
    optG.step()

    _, pred = torch.max(neg_d, 1)
    acc_q = model['clf'].accuracy(pred, YQ)

    return g_loss, acc_q