Example #1
    def evaluate_test(self):
        # restore model args
        args = self.args
        # evaluation mode
        self.model.load_state_dict(
            torch.load(osp.join(self.args.save_path, 'max_acc.pth'))['params'])
        self.model.eval()

        if args.test_mode == 'FSL':
            record = np.zeros((10000, 2))  # loss and acc
            label = torch.arange(args.eval_way).repeat(args.eval_query).type(
                torch.LongTensor)
            if torch.cuda.is_available():
                label = label.cuda()
            with torch.no_grad():
                for i, batch in enumerate(self.test_fsl_loader, 1):
                    if torch.cuda.is_available():
                        data, _ = [_.cuda() for _ in batch]
                    else:
                        data = batch[0]

                    p = args.eval_shot * args.eval_way
                    data_shot, data_query = data[:p], data[p:]
                    logits = self.model.forward_fsl(data_shot, data_query)
                    loss = F.cross_entropy(logits, label)
                    acc = count_acc(logits, label)
                    record[i - 1, 0] = loss.item()
                    record[i - 1, 1] = acc
            assert i == record.shape[0]
            vl, _ = compute_confidence_interval(record[:, 0])
            va, vap = compute_confidence_interval(record[:, 1])

            self.trlog['test_acc'] = va
            self.trlog['test_acc_interval'] = vap
            self.trlog['test_loss'] = vl

            print('best epoch {}, best val acc={:.4f} + {:.4f}\n'.format(
                self.trlog['max_acc_epoch'], self.trlog['max_acc'],
                self.trlog['max_acc_interval']))
            print('Test acc={:.4f} + {:.4f}\n'.format(
                self.trlog['test_acc'], self.trlog['test_acc_interval']))

        else:
            record = np.zeros((10000, 5))  # loss, acc, HM, seen acc, unseen acc
            label_unseen_query = torch.arange(
                min(args.eval_way,
                    self.valset.num_class)).repeat(args.eval_query).long()
            if torch.cuda.is_available():
                label_unseen_query = label_unseen_query.cuda()
            with torch.no_grad():
                for i, batch in tqdm(
                        enumerate(
                            zip(self.test_gfsl_loader, self.test_fsl_loader),
                            1)):
                    if torch.cuda.is_available():
                        data_seen = batch[0][0].cuda()
                        data_unseen = batch[1][0].cuda()
                        seen_label = batch[0][1].cuda()
                        unseen_label = batch[1][1].cuda()
                    else:
                        data_seen, data_unseen = batch[0][0], batch[1][0]
                        seen_label, unseen_label = batch[0][1], batch[1][1]
                    p2 = args.eval_shot * args.eval_way
                    data_unseen_shot = data_unseen[:p2]
                    data_unseen_query = data_unseen[p2:]
                    label_unseen_shot, _ = unseen_label[:p2], unseen_label[p2:]
                    whole_query = torch.cat([data_seen, data_unseen_query], 0)
                    whole_label = torch.cat([
                        seen_label,
                        label_unseen_query + self.traintestset.num_class
                    ])
                    logits_s, logits_u = self.model.forward_generalized(
                        data_unseen_shot, whole_query)
                    # compute un-biased accuracy
                    new_logits = torch.cat([logits_s, logits_u], 1)
                    record[i - 1, 0] = F.cross_entropy(new_logits,
                                                       whole_label).item()
                    record[i - 1, 1] = count_acc(new_logits, whole_label)
                    # compute harmonic mean
                    HM_nobias, SA_nobias, UA_nobias = count_acc_harmonic_low_shot_joint(
                        new_logits, whole_label, seen_label.shape[0])
                    record[i - 1, 2:] = np.array(
                        [HM_nobias, SA_nobias, UA_nobias])
                    del logits_s, logits_u, new_logits
                    torch.cuda.empty_cache()

            m_list = []
            p_list = []
            for i in range(5):
                m1, p1 = compute_confidence_interval(record[:, i])
                m_list.append(m1)
                p_list.append(p1)

            self.trlog['test_loss'] = m_list[0]
            self.trlog['test_acc'] = m_list[1]
            self.trlog['test_acc_interval'] = p_list[1]
            self.trlog['test_HM_acc'] = m_list[2]
            self.trlog['test_HM_acc_interval'] = p_list[2]
            self.trlog['test_HMSeen_acc'] = m_list[3]
            self.trlog['test_HMSeen_acc_interval'] = p_list[3]
            self.trlog['test_HMUnseen_acc'] = m_list[4]
            self.trlog['test_HMUnseen_acc_interval'] = p_list[4]

            print('best epoch {}, best val acc={:.4f} + {:.4f}\n'.format(
                self.trlog['max_acc_epoch'], self.trlog['max_acc'],
                self.trlog['max_acc_interval']))
            print('Test HM acc={:.4f} + {:.4f}\n'.format(
                self.trlog['test_HM_acc'], self.trlog['test_HM_acc_interval']))
            print('GFSL {}-way Acc w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, m_list[1], p_list[1]))
            print('GFSL {}-way HM  w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, m_list[2], p_list[2]))
            print('GFSL {}-way HMSeen  w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, m_list[3], p_list[3]))
            print('GFSL {}-way HMUnseen  w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, m_list[4], p_list[4]))
Example #2
    def train(self):
        args = self.args
        # start GFSL training
        for epoch in range(1, args.max_epoch + 1):
            self.train_epoch += 1
            self.model.train()

            tl1 = Averager()
            tl2 = Averager()
            ta = Averager()

            start_tm = time.time()
            for batch in zip(self.train_fsl_loader, self.train_gfsl_loader):
                self.train_step += 1

                if torch.cuda.is_available():
                    support_data = batch[0][0].cuda()
                    support_label = batch[0][1].cuda()
                    query_data = batch[1][0].cuda()
                    query_label = batch[1][1].cuda()
                else:
                    support_data, support_label = batch[0][0], batch[0][1]
                    query_data, query_label = batch[1][0], batch[1][1]

                data_tm = time.time()
                self.dt.add(data_tm - start_tm)

                # each query label is repeated num_tasks times to match logits
                expanded_label = query_label.view(-1, 1).repeat(
                    1, args.num_tasks).view(-1)
                logits = self.model(support_data, query_data, support_label)
                loss = F.cross_entropy(logits, expanded_label)
                tl2.add(loss.item())

                forward_tm = time.time()
                self.ft.add(forward_tm - data_tm)
                acc = count_acc(logits, expanded_label)

                tl1.add(loss.item())
                ta.add(acc)

                self.optimizer.zero_grad()
                loss.backward()
                backward_tm = time.time()
                self.bt.add(backward_tm - forward_tm)

                self.optimizer.step()
                self.lr_scheduler.step()
                optimizer_tm = time.time()
                self.ot.add(optimizer_tm - backward_tm)
                self.try_logging(tl1, tl2, ta)

                # refresh start_tm
                start_tm = time.time()
                del logits, loss
                torch.cuda.empty_cache()

                self.try_evaluate(epoch)

            print('ETA:{}/{}'.format(
                self.timer.measure(),
                self.timer.measure(self.train_epoch / args.max_epoch)))

        torch.save(self.trlog, osp.join(args.save_path, 'trlog'))
        self.save_model('epoch-last')
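
The `Averager` meters (tl1, tl2, ta, and the dt/ft/bt/ot timing accumulators) are also not defined in these snippets. A plausible minimal implementation, assuming they keep a running mean that `item()` reads back, as Example #7 does with `tl.item()`:

class Averager:

    def __init__(self):
        self.n = 0
        self.v = 0.0

    def add(self, x):
        # incremental running-mean update
        self.v = (self.v * self.n + x) / (self.n + 1)
        self.n += 1

    def item(self):
        return self.v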
Example #3
    def train(self):
        args = self.args
        self.model.train()
        if self.args.fix_BN:
            self.model.encoder.eval()
        
        # start FSL training
        label, label_aux = self.prepare_label()
        for epoch in range(1, args.max_epoch + 1):
            self.train_epoch += 1
            self.model.train()
            if self.args.fix_BN:
                self.model.encoder.eval()
            
            tl1 = Averager()
            tl2 = Averager()
            ta = Averager()

            start_tm = time.time()
            for batch in self.train_loader:
                self.train_step += 1

                if torch.cuda.is_available():
                    data, gt_label = [_.cuda() for _ in batch]
                else:
                    data, gt_label = batch[0], batch[1]
               
                data_tm = time.time()
                self.dt.add(data_tm - start_tm)

                # get saved centers
                logits, reg_logits = self.para_model(data)
                loss = F.cross_entropy(logits, label)
                if reg_logits is not None:
                    total_loss = loss + args.balance * F.cross_entropy(
                        reg_logits, label_aux)
                else:
                    total_loss = loss
                    
                tl2.add(loss.item())
                forward_tm = time.time()
                self.ft.add(forward_tm - data_tm)
                acc = count_acc(logits, label)

                tl1.add(total_loss.item())
                ta.add(acc)

                self.optimizer.zero_grad()
                total_loss.backward()
                backward_tm = time.time()
                self.bt.add(backward_tm - forward_tm)

                self.optimizer.step()
                optimizer_tm = time.time()
                self.ot.add(optimizer_tm - backward_tm)    

                # refresh start_tm
                start_tm = time.time()

            self.lr_scheduler.step()
            self.try_evaluate(epoch)

            print('ETA:{}/{}'.format(
                    self.timer.measure(),
                    self.timer.measure(self.train_epoch / args.max_epoch))
            )

        torch.save(self.trlog, osp.join(args.save_path, 'trlog'))
        self.save_model('epoch-last')
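
`prepare_label` is referenced but not shown. A hypothetical sketch, assuming the standard episodic layout where `label` covers the query set of an `args.way`-way task and `label_aux` additionally covers the support shots (the `args.way`, `args.shot`, and `args.query` attribute names are assumptions):

import torch


def prepare_label(self):
    args = self.args
    # query-set labels: one block of args.query labels per class
    label = torch.arange(args.way).repeat(args.query).long()
    # auxiliary labels over shots + queries, paired with reg_logits above
    label_aux = torch.arange(args.way).repeat(args.shot + args.query).long()
    if torch.cuda.is_available():
        label, label_aux = label.cuda(), label_aux.cuda()
    return label, label_aux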
Example #4
def validate(args, model, val_loader, epoch, trlog=None):
    model.eval()
    global writer
    vl_dist, va_dist = Averager(), Averager()
    vl_sim, va_sim = Averager(), Averager()
    if trlog is not None:
        print('[Dist] best epoch {}, current best val acc={:.4f}'.format(
            trlog['max_acc_dist_epoch'], trlog['max_acc_dist']))
        print('[Sim] best epoch {}, current best val acc={:.4f}'.format(
            trlog['max_acc_sim_epoch'], trlog['max_acc_sim']))
    # test performance with Few-Shot
    label = torch.arange(args.num_val_class).repeat(args.query).long()
    if torch.cuda.is_available():
        label = label.cuda()
    with torch.no_grad():
        for i, batch in tqdm(enumerate(val_loader, 1), total=len(val_loader)):
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data, _ = batch
            # 16-way test
            data_shot = data[:args.num_val_class]
            data_query = data[args.num_val_class:]
            logits_dist, logits_sim = model.forward_proto(
                data_shot, data_query, args.num_val_class)
            loss_dist = F.cross_entropy(logits_dist, label)
            acc_dist = count_acc(logits_dist, label)
            loss_sim = F.cross_entropy(logits_sim, label)
            acc_sim = count_acc(logits_sim, label)
            vl_dist.add(loss_dist.item())
            va_dist.add(acc_dist)
            vl_sim.add(loss_sim.item())
            va_sim.add(acc_sim)

    vl_dist = vl_dist.item()
    va_dist = va_dist.item()
    vl_sim = vl_sim.item()
    va_sim = va_sim.item()

    print(
        'epoch {}, val, loss_dist={:.4f} acc_dist={:.4f} loss_sim={:.4f} acc_sim={:.4f}'
        .format(epoch, vl_dist, va_dist, vl_sim, va_sim))

    if trlog is not None:
        writer.add_scalar('data/val_loss_dist', float(vl_dist), epoch)
        writer.add_scalar('data/val_acc_dist', float(va_dist), epoch)
        writer.add_scalar('data/val_loss_sim', float(vl_sim), epoch)
        writer.add_scalar('data/val_acc_sim', float(va_sim), epoch)

        if va_dist > trlog['max_acc_dist']:
            trlog['max_acc_dist'] = va_dist
            trlog['max_acc_dist_epoch'] = epoch
            save_model('max_acc_dist', model, args)
            save_checkpoint(True)

        if va_sim > trlog['max_acc_sim']:
            trlog['max_acc_sim'] = va_sim
            trlog['max_acc_sim_epoch'] = epoch
            save_model('max_acc_sim', model, args)
            save_checkpoint(True)

        trlog['val_loss_dist'].append(vl_dist)
        trlog['val_acc_dist'].append(va_dist)
        trlog['val_loss_sim'].append(vl_sim)
        trlog['val_acc_sim'].append(va_sim)
        return trlog
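
`save_model` is not shown either. A sketch consistent with Example #1, which restores weights via torch.load(osp.join(save_path, 'max_acc.pth'))['params']; the '.pth' suffix and the 'params' key are inferred from that load call, and `save_checkpoint` is left undefined here:

import os.path as osp
import torch


def save_model(name, model, args):
    # store the state dict under the 'params' key that evaluate_test expects
    torch.save(dict(params=model.state_dict()),
               osp.join(args.save_path, name + '.pth'))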
Example #5
    def evaluate_test(self):
        # restore model args
        emb_dim = self.emb_dim
        args = self.args
        weights = torch.load(
            osp.join(self.args.save_path, self.args.weight_name))
        model_weights = weights['params']
        self.missing_keys, self.unexpected_keys = self.model.load_state_dict(
            model_weights, strict=False)
        self.model.eval()

        test_steps = 600

        self.record = np.zeros((test_steps, 2))  # loss and acc
        self.auroc_record = np.zeros((test_steps, 10))  # cols 0-2: prob/SnaTCHer/dist AUROC
        way = args.closed_way
        label = torch.arange(way).repeat(args.eval_query).long()
        if torch.cuda.is_available():
            label = label.cuda()

        for i, batch in tqdm(enumerate(self.test_loader, 1)):
            if i > test_steps:
                break

            if torch.cuda.is_available():
                data, dlabel = [_.cuda() for _ in batch]
            else:
                data, dlabel = batch[0], batch[1]

            self.probe_data = data
            self.probe_dlabel = dlabel

            with torch.no_grad():
                _ = self.para_model(data)
                instance_embs = self.para_model.probe_instance_embs
                support_idx = self.para_model.probe_support_idx
                query_idx = self.para_model.probe_query_idx

                support = instance_embs[support_idx.flatten()].view(
                    *(support_idx.shape + (-1, )))
                query = instance_embs[query_idx.flatten()].view(
                    *(query_idx.shape + (-1, )))
                emb_dim = support.shape[-1]

                support = support[:, :, :way].contiguous()
                # get mean of the support
                bproto = support.mean(dim=1)  # Ntask x NK x d
                proto = bproto

                kquery = query[:, :, :way].contiguous()
                uquery = query[:, :, way:].contiguous()

                # get mean of the support
                proto = self.para_model.slf_attn(proto, proto, proto)
                proto = proto[0]

            klogits = -(kquery.reshape(-1, 1, emb_dim) -
                        proto).pow(2).sum(2) / 64.0
            ulogits = -(uquery.reshape(-1, 1, emb_dim) -
                        proto).pow(2).sum(2) / 64.0

            loss = F.cross_entropy(klogits, label)
            acc = count_acc(klogits, label)
            """ Probability """
            known_prob = F.softmax(klogits, 1).max(1)[0]
            unknown_prob = F.softmax(ulogits, 1).max(1)[0]

            known_scores = (known_prob).cpu().detach().numpy()
            unknown_scores = (unknown_prob).cpu().detach().numpy()
            known_scores = 1 - known_scores
            unknown_scores = 1 - unknown_scores

            auroc = calc_auroc(known_scores, unknown_scores)
            """ Distance """
            kdist = -(klogits.max(1)[0])
            udist = -(ulogits.max(1)[0])
            kdist = kdist.cpu().detach().numpy()
            udist = udist.cpu().detach().numpy()
            dist_auroc = calc_auroc(kdist, udist)
            """ Snatcher """
            with torch.no_grad():
                snatch_known = []
                for j in range(75):
                    pproto = bproto.clone().detach()
                    """ Algorithm 1 Line 1 """
                    c = klogits.argmax(1)[j]
                    """ Algorithm 1 Line 2 """
                    pproto[0][c] = kquery.reshape(-1, emb_dim)[j]
                    """ Algorithm 1 Line 3 """
                    pproto = self.para_model.slf_attn(pproto, pproto,
                                                      pproto)[0]
                    pdiff = (pproto - proto).pow(2).sum(-1).sum() / 64.0
                    """ pdiff: d_SnaTCHer in Algorithm 1 """
                    snatch_known.append(pdiff)

                snatch_unknown = []
                for j in range(ulogits.shape[0]):
                    pproto = bproto.clone().detach()
                    """ Algorithm 1 Line 1 """
                    c = ulogits.argmax(1)[j]
                    """ Algorithm 1 Line 2 """
                    pproto[0][c] = uquery.reshape(-1, emb_dim)[j]
                    """ Algorithm 1 Line 3 """
                    pproto = self.para_model.slf_attn(pproto, pproto,
                                                      pproto)[0]
                    pdiff = (pproto - proto).pow(2).sum(-1).sum() / 64.0
                    """ pdiff: d_SnaTCHer in Algorithm 1 """
                    snatch_unknown.append(pdiff)

                pkdiff = torch.stack(snatch_known)
                pudiff = torch.stack(snatch_unknown)
                pkdiff = pkdiff.cpu().detach().numpy()
                pudiff = pudiff.cpu().detach().numpy()

                snatch_auroc = calc_auroc(pkdiff, pudiff)

            self.record[i - 1, 0] = loss.item()
            self.record[i - 1, 1] = acc
            self.auroc_record[i - 1, 0] = auroc
            self.auroc_record[i - 1, 1] = snatch_auroc
            self.auroc_record[i - 1, 2] = dist_auroc

            if i % 100 == 0:
                # running mean and 95% interval over the first i episodes
                va, vap = compute_confidence_interval(self.record[:i, 1])
                aua, auap = compute_confidence_interval(
                    self.auroc_record[:i, 0])
                sa, sap = compute_confidence_interval(
                    self.auroc_record[:i, 1])
                da, dap = compute_confidence_interval(
                    self.auroc_record[:i, 2])

                print("acc: {:.4f} + {:.4f} Prob: {:.4f} + {:.4f} Dist: {:.4f} + {:.4f} SnaTCHer: {:.4f} + {:.4f}"\
                      .format(va, vap, aua, auap, da, dap, sa, sap))

        return
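
`calc_auroc` takes one score per known query and one per unknown query; every score fed to it above is oriented so that larger means "more likely unknown" (1 - max softmax probability, the negated max logit, and d_SnaTCHer). A minimal sketch under that assumption, treating unknowns as the positive class:

import numpy as np
from sklearn.metrics import roc_auc_score


def calc_auroc(known_scores, unknown_scores):
    # AUROC of unknown-vs-known detection from two score arrays
    y_true = np.concatenate(
        [np.zeros(len(known_scores)), np.ones(len(unknown_scores))])
    y_score = np.concatenate([known_scores, unknown_scores])
    return roc_auc_score(y_true, y_score)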
Example #6
    def evaluate_gfsl(self):
        args = self.args
        label_unseen_query = torch.arange(args.eval_way).repeat(
            args.eval_query).long()
        if torch.cuda.is_available():
            label_unseen_query = label_unseen_query.cuda()

        generalized_few_shot_acc = np.zeros((args.num_eval_episodes, 2))
        generalized_few_shot_delta = np.zeros((args.num_eval_episodes, 4))
        generalized_few_shot_hmeanacc = np.zeros((args.num_eval_episodes, 6))
        generalized_few_shot_hmeanmap = np.zeros((args.num_eval_episodes, 6))
        generalized_few_shot_ausuc = np.zeros((args.num_eval_episodes, 1))
        AUC_record = []

        for i, batch in tqdm(
                enumerate(zip(self.test_gfsl_loader, self.test_fsl_loader),
                          1)):
            if torch.cuda.is_available():
                data_seen = batch[0][0].cuda()
                data_unseen = batch[1][0].cuda()
                seen_label = batch[0][1].cuda()
                unseen_label = batch[1][1].cuda()
            else:
                data_seen, data_unseen = batch[0][0], batch[1][0]
                seen_label, unseen_label = batch[0][1], batch[1][1]
            p2 = args.eval_shot * args.eval_way

            data_unseen_shot = data_unseen[:p2]
            data_unseen_query = data_unseen[p2:]
            label_unseen_shot, _ = unseen_label[:p2], unseen_label[p2:]
            whole_query = torch.cat([data_seen, data_unseen_query], 0)
            whole_label = torch.cat(
                [seen_label, label_unseen_query + self.trainset.num_class])
            if args.model_class in ['CLS', 'Castle', 'ACastle']:
                with torch.no_grad():
                    logits_s, logits_u = self.model.forward_generalized(
                        data_unseen_shot, whole_query)
            elif args.model_class in ['ProtoNet']:
                with torch.no_grad():
                    logits_s, logits_u = self.model.forward_generalized(
                        data_unseen_shot, whole_query, self.model.seen_proto)
            # compute un-biased accuracy
            new_logits = torch.cat([logits_s, logits_u], 1)
            if 'acc' in self.criteria or 'hmeanacc' in self.criteria or 'delta' in self.criteria:
                new_logits_acc_biased = torch.cat(
                    [logits_s - self.best_bias_acc, logits_u], 1)
            if 'hmeanmap' in self.criteria:
                new_logits_map_biased = torch.cat(
                    [logits_s - self.best_bias_map, logits_u], 1)
            # Criterion: Acc
            if 'acc' in self.criteria:
                generalized_few_shot_acc[i - 1, 0] = count_acc(
                    new_logits, whole_label)
                # compute biased accuracy
                generalized_few_shot_acc[i - 1, 1] = count_acc(
                    new_logits_acc_biased, whole_label)

            if 'delta' in self.criteria:
                # compute delta value for un-biased logits
                unbiased_delta1, unbiased_delta2 = count_delta_value(
                    new_logits, whole_label, seen_label.shape[0],
                    self.trainset.num_class)
                # compute delta value for biased logits
                biased_delta1, biased_delta2 = count_delta_value(
                    new_logits_acc_biased, whole_label, seen_label.shape[0],
                    self.trainset.num_class)
                generalized_few_shot_delta[i - 1, :] = np.array([
                    unbiased_delta1, unbiased_delta2, biased_delta1,
                    biased_delta2
                ])

            if 'hmeanacc' in self.criteria:
                # compute harmonic mean
                HM_nobias, SA_nobias, UA_nobias = count_acc_harmonic_low_shot_joint(
                    new_logits, whole_label, seen_label.shape[0])
                HM, SA, UA = count_acc_harmonic_low_shot_joint(
                    new_logits_acc_biased, whole_label, seen_label.shape[0])
                generalized_few_shot_hmeanacc[i - 1, :] = np.array(
                    [HM_nobias, SA_nobias, UA_nobias, HM, SA, UA])

            if 'hmeanmap' in self.criteria:
                # compute harmonic mean
                HM_nobias, SA_nobias, UA_nobias = count_acc_harmonic_MAP(
                    new_logits, whole_label, seen_label.shape[0], 'macro')
                HM, SA, UA = count_acc_harmonic_MAP(new_logits_map_biased,
                                                    whole_label,
                                                    seen_label.shape[0],
                                                    'macro')
                generalized_few_shot_hmeanmap[i - 1, :] = np.array(
                    [HM_nobias, SA_nobias, UA_nobias, HM, SA, UA])

            if 'ausuc' in self.criteria:
                # compute AUSUC
                ausuc_val, temp_auc_record = Compute_AUSUC(
                    logits_s.detach().cpu().numpy(),
                    logits_u.detach().cpu().numpy(),
                    whole_label.cpu().numpy(),
                    np.arange(self.trainset.num_class),
                    self.trainset.num_class + np.arange(args.eval_way))
                generalized_few_shot_ausuc[i - 1, 0] = ausuc_val
                AUC_record.append(temp_auc_record)

            del logits_s, logits_u, new_logits
            torch.cuda.empty_cache()

        self.AUC_record = AUC_record
        print('-'.join([args.model_class, args.model_path]))
        if 'acc' in self.criteria:
            self.trlog['acc_mean'], self.trlog[
                'acc_interval'] = compute_confidence_interval(
                    generalized_few_shot_acc[:, 0])
            self.trlog['acc_biased_mean'], self.trlog[
                'acc_biased_interval'] = compute_confidence_interval(
                    generalized_few_shot_acc[:, 1])
            print('GFSL {}-way Acc w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['acc_mean'],
                self.trlog['acc_interval']))
            print('GFSL {}-way Acc w/  Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['acc_biased_mean'],
                self.trlog['acc_biased_interval']))

        if 'delta' in self.criteria:
            self.trlog['delta1_mean'], self.trlog[
                'delta1_interval'] = compute_confidence_interval(
                    generalized_few_shot_delta[:, 0])
            self.trlog['delta2_mean'], self.trlog[
                'delta2_interval'] = compute_confidence_interval(
                    generalized_few_shot_delta[:, 1])
            self.trlog['delta1_biased_mean'], self.trlog[
                'delta1_biased_interval'] = compute_confidence_interval(
                    generalized_few_shot_delta[:, 2])
            self.trlog['delta2_biased_mean'], self.trlog[
                'delta2_biased_interval'] = compute_confidence_interval(
                    generalized_few_shot_delta[:, 3])
            print('GFSL {}-way Delta1 w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['delta1_mean'],
                self.trlog['delta1_interval']))
            print('GFSL {}-way Delta1 w/  Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['delta1_biased_mean'],
                self.trlog['delta1_biased_interval']))
            print('GFSL {}-way Delta2 w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['delta2_mean'],
                self.trlog['delta2_interval']))
            print('GFSL {}-way Delta2 w/  Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['delta2_biased_mean'],
                self.trlog['delta2_biased_interval']))

        if 'hmeanacc' in self.criteria:
            self.trlog['HM_mean'], self.trlog[
                'HM_interval'] = compute_confidence_interval(
                    generalized_few_shot_hmeanacc[:, 0])
            self.trlog['S2All_mean'], self.trlog[
                'S2All_interval'] = compute_confidence_interval(
                    generalized_few_shot_hmeanacc[:, 1])
            self.trlog['U2All_mean'], self.trlog[
                'U2All_interval'] = compute_confidence_interval(
                    generalized_few_shot_hmeanacc[:, 2])
            self.trlog['HM_biased_mean'], self.trlog[
                'HM_biased_interval'] = compute_confidence_interval(
                    generalized_few_shot_hmeanacc[:, 3])
            self.trlog['S2All_biased_mean'], self.trlog[
                'S2All_biased_interval'] = compute_confidence_interval(
                    generalized_few_shot_hmeanacc[:, 4])
            self.trlog['U2All_biased_mean'], self.trlog[
                'U2All_biased_interval'] = compute_confidence_interval(
                    generalized_few_shot_hmeanacc[:, 5])
            print('GFSL {}-way HM_mean w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['HM_mean'],
                self.trlog['HM_interval']))
            print('GFSL {}-way HM_mean w/  Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['HM_biased_mean'],
                self.trlog['HM_biased_interval']))
            print('GFSL {}-way S2All_mean w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['S2All_mean'],
                self.trlog['S2All_interval']))
            print('GFSL {}-way S2All_mean w/  Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['S2All_biased_mean'],
                self.trlog['S2All_biased_interval']))
            print('GFSL {}-way U2All_mean w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['U2All_mean'],
                self.trlog['U2All_interval']))
            print('GFSL {}-way U2All_mean w/  Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['U2All_biased_mean'],
                self.trlog['U2All_biased_interval']))

        if 'hmeanmap' in self.criteria:
            self.trlog['HM_map_mean'], self.trlog[
                'HM_map_interval'] = compute_confidence_interval(
                    generalized_few_shot_hmeanmap[:, 0])
            self.trlog['S2All_map_mean'], self.trlog[
                'S2All_map_interval'] = compute_confidence_interval(
                    generalized_few_shot_hmeanmap[:, 1])
            self.trlog['U2All_map_mean'], self.trlog[
                'U2All_map_interval'] = compute_confidence_interval(
                    generalized_few_shot_hmeanmap[:, 2])
            self.trlog['HM_map_biased_mean'], self.trlog[
                'HM_map_biased_interval'] = compute_confidence_interval(
                    generalized_few_shot_hmeanmap[:, 3])
            self.trlog['S2All_map_biased_mean'], self.trlog[
                'S2All_map_biased_interval'] = compute_confidence_interval(
                    generalized_few_shot_hmeanmap[:, 4])
            self.trlog['U2All_map_biased_mean'], self.trlog[
                'U2All_map_biased_interval'] = compute_confidence_interval(
                    generalized_few_shot_hmeanmap[:, 5])
            print('GFSL {}-way HM_map_mean w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['HM_map_mean'],
                self.trlog['HM_map_interval']))
            print('GFSL {}-way HM_map_mean w/  Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['HM_map_biased_mean'],
                self.trlog['HM_map_biased_interval']))
            print('GFSL {}-way S2All_map_mean w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['S2All_map_mean'],
                self.trlog['S2All_map_interval']))
            print('GFSL {}-way S2All_map_mean w/  Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['S2All_map_biased_mean'],
                self.trlog['S2All_map_biased_interval']))
            print('GFSL {}-way U2All_map_mean w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['U2All_map_mean'],
                self.trlog['U2All_map_interval']))
            print('GFSL {}-way U2All_map_mean w/  Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['U2All_map_biased_mean'],
                self.trlog['U2All_map_biased_interval']))

        if 'ausuc' in self.criteria:
            self.trlog['AUSUC_mean'], self.trlog[
                'AUSUC_interval'] = compute_confidence_interval(
                    generalized_few_shot_ausuc[:, 0])
            print('GFSL {}-way AUSUC {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['AUSUC_mean'],
                self.trlog['AUSUC_interval']))
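
`count_acc_harmonic_low_shot_joint` returns (HM, SA, UA): the harmonic mean of the seen-class and unseen-class accuracies over the joint logits, given that the first `seen_label.shape[0]` rows are seen-class queries. A minimal sketch under that reading of the call sites:

import torch


def count_acc_harmonic_low_shot_joint(logits, label, n_seen):
    # joint prediction over seen + unseen classes
    correct = (torch.argmax(logits, dim=1) == label).float()
    seen_acc = correct[:n_seen].mean().item()
    unseen_acc = correct[n_seen:].mean().item()
    # harmonic mean of the two accuracies (guard against 0/0)
    hm = 2 * seen_acc * unseen_acc / (seen_acc + unseen_acc + 1e-12)
    return hm, seen_acc, unseen_acc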
Example #7
        model.train()
        tl = Averager()
        ta = Averager()

        for i, batch in enumerate(train_loader, 1):
            global_count = global_count + 1
            if torch.cuda.is_available():
                data, label = [_.cuda() for _ in batch]
                label = label.type(torch.cuda.LongTensor)
            else:
                data, label = batch
                label = label.type(torch.LongTensor)
            logits = model(data)
            loss = criterion(logits, label)
            acc = count_acc(logits, label)
            writer.add_scalar('data/loss', float(loss), global_count)
            writer.add_scalar('data/acc', float(acc), global_count)
            if (i - 1) % 100 == 0:
                print('epoch {}, train {}/{}, loss={:.4f} acc={:.4f}'.format(
                    epoch, i, len(train_loader), loss.item(), acc))

            tl.add(loss.item())
            ta.add(acc)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        tl = tl.item()
        ta = ta.item()
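
The `Timer` behind the ETA prints in Examples #2, #3, and #8 is also assumed. A sketch in which `measure()` formats the elapsed time and `measure(p)` extrapolates the total run time from a progress fraction `p`:

import time


class Timer:

    def __init__(self):
        self.start = time.time()

    def measure(self, p=1.0):
        # elapsed time divided by progress fraction = projected total
        x = (time.time() - self.start) / p
        if x >= 3600:
            return '{:.1f}h'.format(x / 3600)
        if x >= 60:
            return '{}m'.format(round(x / 60))
        return '{}s'.format(round(x))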
Example #8
    def train(self):
        args = self.args
        self.model.train()
        if self.args.fix_BN:
            self.model.encoder.eval()
        
        # start FSL training
        label, label_aux = self.prepare_label()
        for epoch in range(1, args.max_epoch + 1):
            self.train_epoch += 1
            self.model.train()
            if self.args.fix_BN:
                self.model.encoder.eval()
            
            tl1 = Averager()
            tl2 = Averager()
            ta = Averager()

            start_tm = time.time()

            for batch in tqdm(self.train_loader):

                data, gt_label = [_.cuda() for _ in batch]
                data_tm = time.time()
                self.dt.add(data_tm - start_tm)

                # get saved centers
                logits, reg_logits = self.para_model(data)
                logits = logits.view(-1, args.way)
                oh_query = torch.nn.functional.one_hot(label, args.way)

                sims = logits
                temp = (sims * oh_query).sum(-1)
                e_sim_p = temp - self.model.margin
                e_sim_p_pos = F.relu(e_sim_p)
                e_sim_p_neg = F.relu(-e_sim_p)

                l_open_margin = e_sim_p_pos.mean(-1)
                l_open = e_sim_p_neg.mean(-1)
                if reg_logits is not None:
                    loss = F.cross_entropy(logits, label)
                    total_loss = loss + args.balance * F.cross_entropy(
                        reg_logits, label_aux)
                else:
                    loss = F.cross_entropy(logits, label)
                    total_loss = loss
                total_loss = total_loss + args.open_balance * l_open
                tl2.add(loss.item())
                forward_tm = time.time()
                self.ft.add(forward_tm - data_tm)
                acc = count_acc(logits, label)

                tl1.add(total_loss.item())
                ta.add(acc)

                self.optimizer.zero_grad()
                total_loss.backward(retain_graph=True)

                self.optimizer_margin.zero_grad()

                l_open_margin.backward()
                self.optimizer.step()
                self.optimizer_margin.step()

                backward_tm = time.time()
                self.bt.add(backward_tm - forward_tm)

                optimizer_tm = time.time()
                self.ot.add(optimizer_tm - backward_tm)
                # refresh start_tm
                start_tm = time.time()

            print('lr: {:.4f} Total_loss: {:.4f} ce_loss {:.4f} l_open: {:4f} R: {:4f}\n'.format(self.optimizer_margin.param_groups[0]['lr'],\
                total_loss.item(), loss.item(), l_open.item(), self.model.margin.item()))
            self.lr_scheduler.step()
            self.lr_scheduler_margin.step()

            self.try_evaluate(epoch)

            print('ETA:{}/{}'.format(
                    self.timer.measure(),
                    self.timer.measure(self.train_epoch / args.max_epoch))
            )

        torch.save(self.trlog, osp.join(args.save_path, 'trlog'))
        self.save_model('epoch-last')
Example #9
    def open_evaluate(self, data_loader):
        # restore model args
        args = self.args
        # evaluation mode
        self.model.eval()
        record = np.zeros((args.num_test_episodes, 4))  # loss, acc, SnaTCHer AUROC, dist AUROC

        label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
        label = label.type(torch.LongTensor)
        if torch.cuda.is_available():
            label = label.cuda()
        print('Evaluating ... best epoch {}, SnaTCHer={:.4f} + {:.4f}  acc={:.4f} + {:.4f}'.format(
            self.trlog['max_auc_epoch'],
            self.trlog['max_auc'],
            self.trlog['max_auc_interval'],
            self.trlog['acc'],
            self.trlog['acc_interval']))

        with torch.no_grad():
            for i, batch in enumerate(tqdm(data_loader), 1):

                data, _ = [_.cuda() for _ in batch]

                logits = self.para_model(data)
                logits = logits.reshape([-1, args.eval_way + args.open_eval_way, args.way])
                klogits = logits[:, :args.eval_way, :].reshape(-1, args.way)
                ulogits = logits[:, args.eval_way:, :].reshape(-1, args.way)
                loss = F.cross_entropy(klogits, label)
                acc = count_acc(klogits, label)

                """ Distance """
                kdist = -(klogits.max(1)[0])
                udist = -(ulogits.max(1)[0])
                kdist = kdist.cpu().detach().numpy()
                udist = udist.cpu().detach().numpy()
                dist_auroc = calc_auroc(kdist, udist)

                """ Snatcher """
                with torch.no_grad():
                    instance_embs = self.para_model.instance_embs
                    support_idx = self.para_model.support_idx
                    query_idx = self.para_model.query_idx

                    support = instance_embs[support_idx.flatten()].view(*(support_idx.shape + (-1,)))
                    query = instance_embs[query_idx.flatten()].view(*(query_idx.shape + (-1,)))
                    emb_dim = support.shape[-1]

                    support = support[:, :, :args.way].contiguous()
                    # get mean of the support
                    bproto = support.mean(dim=1)  # Ntask x NK x d
                    proto = self.para_model.slf_attn(bproto, bproto, bproto)
                    kquery = query[:, :, :args.way].contiguous()
                    uquery = query[:, :, args.way:].contiguous()
                    snatch_known = []
                    for j in range(75):
                        pproto = bproto.clone().detach()
                        """ Algorithm 1 Line 1 """
                        c = klogits.argmax(1)[j]
                        """ Algorithm 1 Line 2 """
                        pproto[0][c] = kquery.reshape(-1, emb_dim)[j]
                        """ Algorithm 1 Line 3 """
                        pproto = self.para_model.slf_attn(pproto, pproto, pproto)[0]
                        pdiff = (pproto - proto).pow(2).sum(-1).sum() / 64.0
                        """ pdiff: d_SnaTCHer in Algorithm 1 """
                        snatch_known.append(pdiff)

                    snatch_unknown = []
                    for j in range(ulogits.shape[0]):
                        pproto = bproto.clone().detach()
                        """ Algorithm 1 Line 1 """
                        c = ulogits.argmax(1)[j]
                        """ Algorithm 1 Line 2 """
                        pproto[0][c] = uquery.reshape(-1, emb_dim)[j]
                        """ Algorithm 1 Line 3 """
                        pproto = self.para_model.slf_attn(pproto, pproto, pproto)[0]
                        pdiff = (pproto - proto).pow(2).sum(-1).sum() / 64.0
                        """ pdiff: d_SnaTCHer in Algorithm 1 """
                        snatch_unknown.append(pdiff)

                    pkdiff = torch.stack(snatch_known)
                    pudiff = torch.stack(snatch_unknown)
                    pkdiff = pkdiff.cpu().detach().numpy()
                    pudiff = pudiff.cpu().detach().numpy()

                    snatch_auroc = calc_auroc(pkdiff, pudiff)
                record[i - 1, 0] = loss.item()
                record[i - 1, 1] = acc
                record[i - 1, 2] = snatch_auroc
                record[i - 1, 3] = dist_auroc

        vl, _ = compute_confidence_interval(record[:, 0])
        va, vap = compute_confidence_interval(record[:, 1])
        auc_sna, auc_sna_p = compute_confidence_interval(record[:, 2])
        auc_dist, auc_dist_p = compute_confidence_interval(record[:, 3])
        print("acc: {:.4f} + {:.4f} Dist: {:.4f} + {:.4f} SnaTCHer: {:.4f} + {:.4f}" \
              .format(va, vap, auc_dist, auc_dist_p, auc_sna, auc_sna_p))
        # train mode
        self.model.train()
        if self.args.fix_BN:
            self.model.encoder.eval()

        return vl, va, vap, auc_sna, auc_sna_p
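
Taken together, Examples #5 and #9 compare open-set scores under one AUROC convention: the probability score (1 - max softmax), the distance score (negated max logit), and d_SnaTCHer (the displacement of the attention-transformed prototypes after snatching a query into the prototype set). All are oriented so that larger values flag a query as unknown, which is why `calc_auroc` can treat unknown queries as the positive class throughout.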