Пример #1
0
    def train(self):
        """Run the pre-train phase.

        For every epoch in ``1..args.pre_max_epoch`` this method:

        1. trains ``self.model`` (``mode = 'pre'``) on ``self.train_loader``
           with a cross-entropy loss,
        2. calls ``self.val_orig`` on ``self.valset`` (``mode = 'origval'``)
           and prints the first entry of its second return value,
        3. runs validation on ``self.val_loader`` (``mode = 'preval'``), where
           each batch is split into ``shot * way`` leading samples and the
           remaining query samples,
        4. saves the best model as ``'max_acc'``, a checkpoint every 10
           epochs, and the log dict ``trlog`` under ``args.save_path``.

        Per-step training loss/accuracy and per-epoch validation
        loss/accuracy are written to tensorboardX.
        """
        # Set the pretrain log (key order kept as in the saved log format)
        trlog = {
            'args': vars(self.args),
            'train_loss': [],
            'val_loss': [],
            'train_acc': [],
            'val_acc': [],
            'max_acc': 0.0,
            'max_acc_epoch': 0,
        }

        # Set the timer
        timer = Timer()
        # Global step counter: x-axis of the per-step tensorboardX records
        global_count = 0
        # Set tensorboardX
        writer = SummaryWriter(comment=self.args.save_path)

        use_cuda = torch.cuda.is_available()
        long_type = torch.cuda.LongTensor if use_cuda else torch.LongTensor

        # try/finally so buffered tensorboardX events are flushed even when
        # training is interrupted by an exception.
        try:
            for epoch in range(1, self.args.pre_max_epoch + 1):
                print('Epoch {}'.format(epoch))
                # Set the model to train mode
                self.model.train()
                self.model.mode = 'pre'
                # Averagers recording training losses and accuracies
                train_loss_averager = Averager()
                train_acc_averager = Averager()

                # Using tqdm to read samples from train loader
                tqdm_gen = tqdm.tqdm(self.train_loader)
                for batch in tqdm_gen:
                    global_count += 1
                    if use_cuda:
                        # Move both tensors once instead of transferring the
                        # label a second time from the CPU copy.
                        data, label = [t.cuda() for t in batch]
                    else:
                        data, label = batch[0], batch[1]
                    label = label.type(long_type)

                    logits = self.model(data)
                    loss = F.cross_entropy(logits, label)
                    # Calculate train accuracy
                    acc = count_acc(logits, label)
                    # Write the tensorboardX records
                    writer.add_scalar('data/loss', float(loss), global_count)
                    writer.add_scalar('data/acc', float(acc), global_count)
                    # Accumulate loss and accuracy for this epoch
                    train_loss_averager.add(loss.item())
                    train_acc_averager.add(acc)
                    # Loss backwards and optimizer updates
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                # Collapse the averagers to their current values
                train_loss_averager = train_loss_averager.item()
                train_acc_averager = train_acc_averager.item()

                # Start the original evaluation
                self.model.eval()
                self.model.mode = 'origval'
                _, valid_results = self.val_orig(self.valset.X_val,
                                                 self.valset.y_val)
                print('validation accuracy ', valid_results[0])

                # Start validation for this epoch, set model to eval mode
                self.model.eval()
                self.model.mode = 'preval'

                # Averagers recording validation losses and accuracies
                val_loss_averager = Averager()
                val_acc_averager = Averager()

                # Generate the labels for the query and shot parts
                label = torch.arange(self.args.way).repeat(
                    self.args.val_query).type(long_type)
                label_shot = torch.arange(self.args.way).repeat(
                    self.args.shot).type(long_type)

                # Run meta-validation
                p = self.args.shot * self.args.way
                for batch in self.val_loader:
                    if use_cuda:
                        data, _ = [t.cuda() for t in batch]
                    else:
                        data = batch[0]
                    data_shot, data_query = data[:p], data[p:]
                    logits = self.model((data_shot, label_shot, data_query))
                    loss = F.cross_entropy(logits, label)
                    acc = count_acc(logits, label)
                    val_loss_averager.add(loss.item())
                    val_acc_averager.add(acc)

                # Collapse validation averagers
                val_loss_averager = val_loss_averager.item()
                val_acc_averager = val_acc_averager.item()
                # Write the tensorboardX records
                writer.add_scalar('data/val_loss', float(val_loss_averager),
                                  epoch)
                writer.add_scalar('data/val_acc', float(val_acc_averager),
                                  epoch)

                # Update best saved model
                if val_acc_averager > trlog['max_acc']:
                    trlog['max_acc'] = val_acc_averager
                    trlog['max_acc_epoch'] = epoch
                    self.save_model('max_acc')
                # Save model every 10 epochs
                if epoch % 10 == 0:
                    self.save_model('epoch' + str(epoch))

                # Update the logs
                trlog['train_loss'].append(train_loss_averager)
                trlog['train_acc'].append(train_acc_averager)
                trlog['val_loss'].append(val_loss_averager)
                trlog['val_acc'].append(val_acc_averager)

                # Save log
                torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

                if epoch % 10 == 0:
                    # Progress is measured against pre_max_epoch, the bound
                    # of this loop (not max_epoch, which is a different arg).
                    print('Running Time: {}, Estimated Time: {}'.format(
                        timer.measure(),
                        timer.measure(epoch / self.args.pre_max_epoch)))
        finally:
            writer.close()
Пример #2
0
    def train(self):
        """Run the meta-train phase with a validation pass after each epoch.

        Each batch is split at ``way * shot``: the leading part (inputs and
        labels) plus the remaining inputs are fed to ``self.model``, and the
        remaining labels are the targets for ``self.CD`` and for the
        segmentation metrics (pixel accuracy and mean IoU).

        Per epoch, the averaged train/val loss, accuracy and IoU are written
        to tensorboardX and appended to the log dict, which is saved to
        ``args.save_path/trlog``. The model is saved as ``'max_iou'`` whenever
        the validation IoU improves, and as ``'epoch<N>'`` every 10 epochs.
        """
        history = {
            'args': vars(self.args),
            'train_loss': [],
            'val_loss': [],
            'train_acc': [],
            'val_acc': [],
            'train_iou': [],
            'val_iou': [],
            'max_iou': 0.0,
            'max_iou_epoch': 0,
        }

        stopwatch = Timer()
        board = SummaryWriter(comment=self.args.save_path)

        def unpack_episode(batch):
            """Return (shot inputs, shot labels, query inputs, query labels)."""
            if torch.cuda.is_available():
                inputs, targets, _ = [item.cuda() for item in batch]
            else:
                inputs, targets = batch[0], batch[1]
            split = self.args.way * self.args.shot
            return (inputs[:split], targets[:split],
                    inputs[split:], targets[split:])

        def score(logits, target):
            """Return (pixel accuracy, mean IoU) for this batch only."""
            # Metrics are reset first so the values describe a single batch.
            self._reset_metrics()
            self._update_seg_metrics(
                *eval_metrics(logits, target, self.args.way))
            pix_acc, mean_iou, _ = self._get_seg_metrics(
                self.args.way).values()
            return pix_acc, mean_iou

        for epoch in range(1, self.args.max_epoch + 1):
            # NOTE(review): the scheduler is stepped before any optimizer
            # step of the epoch; on PyTorch >= 1.1 this order skips the first
            # LR value and emits a warning -- confirm against the target
            # PyTorch version before moving it.
            self.lr_scheduler.step()
            self.model.train()

            loss_meter = Averager()
            acc_meter = Averager()
            iou_meter = Averager()

            progress = tqdm.tqdm(self.train_loader)
            self._reset_metrics()
            for batch in progress:
                shot_x, shot_y, query_x, query_y = unpack_episode(batch)
                logits = self.model((shot_x, shot_y, query_x))

                # Meta-train loss
                loss = self.CD(logits, query_y)

                # Meta-train metrics
                pix_acc, mean_iou = score(logits, query_y)
                loss_meter.add(loss.item())
                acc_meter.add(pix_acc)
                iou_meter.add(mean_iou)

                # Show the running averages on the progress bar
                progress.set_description('Epoch {}, Loss={:.4f} Acc={:.4f} IoU={:.4f}'.format(
                    epoch, loss_meter.item(), acc_meter.item() * 100.0,
                    iou_meter.item()))

                # Backward pass and parameter update
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            train_loss = loss_meter.item()
            train_acc = acc_meter.item()
            train_iou = iou_meter.item()

            board.add_scalar('data/train_loss (Meta)', float(train_loss), epoch)
            board.add_scalar('data/train_acc (Meta)', float(train_acc) * 100.0, epoch)
            board.add_scalar('data/train_iou (Meta)', float(train_iou), epoch)

            # Validation for this epoch
            self.model.eval()
            loss_meter = Averager()
            acc_meter = Averager()
            iou_meter = Averager()

            # Best result so far (before this epoch's validation)
            print('Best Val Epoch {}, Best Val IoU={:.4f}'.format(
                history['max_iou_epoch'], history['max_iou']))

            for batch in self.val_loader:
                shot_x, shot_y, query_x, query_y = unpack_episode(batch)
                logits = self.model((shot_x, shot_y, query_x))

                loss = self.CD(logits, query_y)
                pix_acc, mean_iou = score(logits, query_y)

                loss_meter.add(loss.item())
                acc_meter.add(pix_acc)
                iou_meter.add(mean_iou)

            val_loss = loss_meter.item()
            val_acc = acc_meter.item()
            val_iou = iou_meter.item()

            board.add_scalar('data/val_loss (Meta)', float(val_loss), epoch)
            board.add_scalar('data/val_acc (Meta)', float(val_acc) * 100.0, epoch)
            board.add_scalar('data/val_iou (Meta)', float(val_iou), epoch)

            print('Epoch {}, Val, Loss={:.4f} Acc={:.4f} IoU={:.4f}'.format(
                epoch, val_loss, val_acc * 100.0, val_iou))

            # Keep the checkpoint with the best validation IoU
            if val_iou > history['max_iou']:
                history['max_iou'] = val_iou
                history['max_iou_epoch'] = epoch
                self.save_model('max_iou')
            # Periodic checkpoint every 10 epochs
            if epoch % 10 == 0:
                self.save_model('epoch' + str(epoch))

            for key, value in (('train_loss', train_loss),
                               ('train_acc', train_acc),
                               ('val_loss', val_loss),
                               ('val_acc', val_acc),
                               ('train_iou', train_iou),
                               ('val_iou', val_iou)):
                history[key].append(value)

            # Persist the log after every epoch
            torch.save(history, osp.join(self.args.save_path, 'trlog'))

            print('Running Time: {}, Estimated Time: {}'.format(
                stopwatch.measure(),
                stopwatch.measure(epoch / self.args.max_epoch)))

        board.close()
Пример #3
0
    def eval(self):
        """Run the meta-evaluate (test) phase and save qualitative results.

        Builds the ``'test'`` dataset, sampler and loader (stored on
        ``self``), loads weights from ``args.eval_weights`` or, when that is
        ``None``, from ``args.save_path/max_iou.pth``, then iterates over the
        loader. Each batch is split at ``teshot * way`` into a shot part and a
        query part; segmentation metrics are updated from the query logits.

        For every query sample three images are written, numbered by a
        running counter and prefixed with ``args.save_image_dir``:
        ``<n>a.jpg`` (input), ``<n>b.png`` (label) and ``<n>c.png`` (argmax
        prediction). Labels and predictions are scaled by ``way - 1``.

        A one-line summary of the last metric values is printed at the end.
        Returns ``None``.
        """
        # The training log is intentionally not loaded here: nothing in this
        # method reads it, and requiring it made evaluation with
        # ``eval_weights`` fail when no log file exists.

        # Load meta-test set
        self.test_set = mDataset('test', self.args)
        self.sampler = CategoriesSampler(
            self.test_set.labeln, self.args.num_batch, self.args.way,
            self.args.teshot + self.args.test_query, self.args.teshot)
        self.loader = DataLoader(dataset=self.test_set,
                                 batch_sampler=self.sampler,
                                 num_workers=8,
                                 pin_memory=True)

        # Load model for meta-test phase
        if self.args.eval_weights is not None:
            weights_path = self.args.eval_weights
        else:
            weights_path = osp.join(self.args.save_path, 'max_iou' + '.pth')
        self.model.load_state_dict(torch.load(weights_path)['params'])
        # Set model to eval mode
        self.model.eval()

        # Set accuracy averager
        ave_acc = Averager()

        use_cuda = torch.cuda.is_available()
        to_pil = transforms.ToPILImage()
        # Class indices are divided by (way - 1) before conversion to an
        # image; guard against way == 1, where the divisor would be zero.
        scale = 1.0 * max(self.args.way - 1, 1)
        p = self.args.teshot * self.args.way

        # Start meta-test
        # NOTE(review): metrics are reset once here and updated per batch, so
        # the values read back are presumably cumulative -- confirm against
        # _update_seg_metrics.
        self._reset_metrics()
        count = 1
        num_batches = 0
        pixAcc = mIoU = None
        for batch in self.loader:
            if use_cuda:
                data, labels, _ = [t.cuda() for t in batch]
            else:
                data = batch[0]
                labels = batch[1]
            data_shot, data_query = data[:p], data[p:]
            label_shot, label = labels[:p], labels[p:]
            logits = self.model((data_shot, label_shot, data_query))
            seg_metrics = eval_metrics(logits, label, self.args.way)
            self._update_seg_metrics(*seg_metrics)
            pixAcc, mIoU, _ = self._get_seg_metrics(self.args.way).values()

            ave_acc.add(pixAcc)
            num_batches += 1

            # Save test image, ground-truth image and predicted image
            for j, query in enumerate(data_query):
                x1 = query.detach().cpu()
                y1 = label[j].detach().cpu()
                z1 = logits[j].detach().cpu()

                x = to_pil(x1).convert("RGB")
                y = to_pil(y1 / scale).convert("LA")
                # argmax over axis 0 picks the predicted index per pixel
                im = torch.tensor(np.argmax(np.array(z1), axis=0) / scale)
                im = im.type(torch.FloatTensor)
                z = to_pil(im).convert("LA")

                prefix = self.args.save_image_dir + str(count)
                x.save(prefix + 'a.jpg')
                y.save(prefix + 'b.png')
                z.save(prefix + 'c.png')
                count = count + 1

        # Report the metrics instead of silently discarding them
        if num_batches:
            print('Test, Avg Acc={:.4f} Acc={:.4f} IoU={:.4f}'.format(
                ave_acc.item() * 100.0, pixAcc * 100.0, mIoU))
Пример #4
0
    def eval(self, gradcam=False, rise=False, test_on_val=False):
        """Run the meta-eval phase over 600 sampled episodes.

        Args:
            gradcam: when true, every 5th episode also computes GradCAM,
                GradCAM++ and guided-backprop outputs for each query sample
                and saves them.
            rise: when true, every 5th episode also runs ``ScoreCam`` on each
                query sample.
            test_on_val: evaluate on the ``'val'`` split instead of ``'test'``
                and add a ``_val`` suffix to saved file names.

        Every 5th episode additionally saves intermediate tensors (norms and,
        with ``args.rep_vec`` / ``args.cross_att``, similarity maps and the
        episode data) under ``../results/<exp_id>/``.

        Returns:
            The first value returned by ``compute_confidence_interval`` on the
            per-episode accuracies.

        Raises:
            ValueError: if ``gradcam`` or ``rise`` is requested together with
                ``args.cross_att``. That forward pass does not return the
                fast weights both visualisations need (previously this
                surfaced as a ``NameError`` on the 5th episode).
        """
        if self.args.cross_att and (gradcam or rise):
            raise ValueError(
                'gradcam/rise need the fast weights, which the cross_att '
                'forward pass does not return')

        # Load the logs (optional: only used for the final summary print)
        log_path = osp.join(self.args.save_path, 'trlog')
        trlog = torch.load(log_path) if os.path.exists(log_path) else None

        torch.manual_seed(1)
        np.random.seed(1)

        # Number of episodes; shared by the sampler and the accuracy record
        num_episodes = 600

        # Load meta-test set
        test_set = Dataset('val' if test_on_val else 'test', self.args)
        sampler = CategoriesSampler(test_set.label, num_episodes,
                                    self.args.way,
                                    self.args.shot + self.args.val_query)
        loader = DataLoader(test_set,
                            batch_sampler=sampler,
                            num_workers=8,
                            pin_memory=True)

        # Set test accuracy recorder
        test_acc_record = np.zeros((num_episodes, ))

        # Load model for meta-test phase
        if self.args.eval_weights is not None:
            weights = self.addOrRemoveModule(
                self.model,
                torch.load(self.args.eval_weights)['params'])
            self.model.load_state_dict(weights)
        else:
            self.model.load_state_dict(
                torch.load(osp.join(self.args.save_path,
                                    'max_acc' + '.pth'))['params'])
        # Set model to eval mode
        self.model.eval()

        # Set accuracy averager
        ave_acc = Averager()

        # Generate labels
        use_cuda = torch.cuda.is_available()
        long_type = torch.cuda.LongTensor if use_cuda else torch.LongTensor
        label = torch.arange(self.args.way).repeat(
            self.args.val_query).type(long_type)
        label_shot = torch.arange(self.args.way).repeat(
            self.args.shot).type(long_type)

        if gradcam:
            self.model.layer3 = self.model.encoder.layer3
            model_dict = dict(type="resnet",
                              arch=self.model,
                              layer_name='layer3')
            grad_cam = GradCAM(model_dict, True)
            grad_cam_pp = GradCAMpp(model_dict, True)
            self.model.features = self.model.encoder
            guided = GuidedBackprop(self.model)
        if rise:
            self.model.layer3 = self.model.encoder.layer3
            score_mod = ScoreCam(self.model)

        suff = "_val" if test_on_val else ""
        result_dir = "../results/{}".format(self.args.exp_id)
        # torch.save does not create missing directories
        os.makedirs(result_dir, exist_ok=True)

        def _dump(obj, tag, episode):
            """Save ``obj`` as ``<result_dir>/<model_id>_<tag><episode><suff>.th``."""
            torch.save(
                obj, "{}/{}_{}{}{}.th".format(result_dir, self.args.model_id,
                                              tag, episode, suff))

        k = self.args.way * self.args.shot
        # Accuracy of the previous episode; printed by the periodic report
        acc = 0.0

        # Start meta-test
        for i, batch in enumerate(loader, 1):
            if use_cuda:
                data, _ = [t.cuda() for t in batch]
            else:
                data = batch[0]
            data_shot, data_query = data[:k], data[k:]
            episode = (data_shot, label_shot, data_query)

            if i % 5 == 0:
                if self.args.rep_vec or self.args.cross_att:
                    # ``acc`` here is still the accuracy of episode i - 1
                    print('batch {}: {:.2f}({:.2f})'.format(
                        i,
                        ave_acc.item() * 100, acc * 100))

                    if self.args.cross_att:
                        label_one_hot = self.one_hot(label).to(label.device)
                        (_, _, logits, simMapQuer, simMapShot, normQuer,
                         normShot) = self.model(episode,
                                                ytest=label_one_hot,
                                                retSimMap=True)
                    else:
                        (logits, simMapQuer, simMapShot, normQuer, normShot,
                         fast_weights) = self.model(episode, retSimMap=True)

                    _dump(simMapQuer, 'simMapQuer', i)
                    _dump(simMapShot, 'simMapShot', i)
                    _dump(data_query, 'dataQuer', i)
                    _dump(data_shot, 'dataShot', i)
                    _dump(normQuer, 'normQuer', i)
                    _dump(normShot, 'normShot', i)
                else:
                    logits, normQuer, normShot, fast_weights = self.model(
                        episode, retFastW=True, retNorm=True)
                    _dump(normQuer, 'normQuer', i)
                    _dump(normShot, 'normShot', i)

                if gradcam:
                    print("Saving gradmaps", i)
                    allMasks, allMasks_pp, allMaps = [], [], []
                    for l in range(len(data_query)):
                        # Slice (not index) to keep the leading batch axis
                        sample = data_query[l:l + 1]
                        allMasks.append(grad_cam(sample, fast_weights, None))
                        allMasks_pp.append(
                            grad_cam_pp(sample, fast_weights, None))
                        allMaps.append(
                            guided.generate_gradients(sample, fast_weights))

                    _dump(torch.cat(allMasks, dim=0), 'gradcamQuer', i)
                    _dump(torch.cat(allMasks_pp, dim=0), 'gradcamppQuer', i)
                    _dump(torch.cat(allMaps, dim=0), 'guidedQuer', i)

                if rise:
                    print("Saving risemaps", i)
                    # NOTE(review): these maps are computed but never saved
                    # or returned -- decide on an output file name and
                    # persist them, or drop the computation.
                    allScore = [
                        score_mod(data_query[l:l + 1], fast_weights)
                        for l in range(len(data_query))
                    ]

            else:
                if self.args.cross_att:
                    label_one_hot = self.one_hot(label).to(label.device)
                    _, _, logits = self.model(episode, ytest=label_one_hot)
                else:
                    logits = self.model(episode)

            acc = count_acc(logits, label)
            ave_acc.add(acc)
            test_acc_record[i - 1] = acc

        # Calculate the confidence interval, update the logs
        m, pm = compute_confidence_interval(test_acc_record)
        if trlog is not None:
            print('Val Best Epoch {}, Acc {:.4f}, Test Acc {:.4f}'.format(
                trlog['max_acc_epoch'], trlog['max_acc'], ave_acc.item()))
        print('Test Acc {:.4f} + {:.4f}'.format(m, pm))

        return m
Пример #5
0
    def train(self, trial):
        """Run the meta-train phase, reporting validation accuracy to ``trial``.

        Args:
            trial: object whose ``report(value, step)`` method is called with
                the epoch's validation accuracy after every epoch (presumably
                an Optuna trial -- confirm against the caller).

        Each batch is split at ``shot * way`` into a shot part and a query
        part. The loss is cross-entropy, or ``self.crossAttLoss`` when
        ``args.cross_att`` is set, optionally blended with a KL distillation
        term against ``self.teacher`` when ``args.distill_id`` is set.

        With ``args.hard_tasks`` the worst-classified class of each episode is
        collected; once ``way`` classes are collected, one extra optimisation
        step is taken on a batch built from
        ``self.train_sampler.hardBatch(...)`` and the collection is cleared.

        The best model is saved as ``'max_acc'`` and the log dict is saved to
        ``args.save_path/trlog`` after every epoch.
        """
        # Set the meta-train log (key order kept as in the saved log format)
        trlog = {
            'args': vars(self.args),
            'train_loss': [],
            'val_loss': [],
            'train_acc': [],
            'val_acc': [],
            'max_acc': 0.0,
            'max_acc_epoch': 0,
        }

        # Set the timer
        timer = Timer()
        # Global step counter: x-axis of the per-step tensorboardX records
        global_count = 0
        # Set tensorboardX
        writer = SummaryWriter(comment=self.args.save_path)

        use_cuda = torch.cuda.is_available()
        long_type = torch.cuda.LongTensor if use_cuda else torch.LongTensor

        # Generate the labels for train set of the episodes
        label_shot = torch.arange(self.args.way).repeat(
            self.args.shot).type(long_type)

        worstClasses = []

        # try/finally so buffered tensorboardX events are flushed even when
        # training is interrupted by an exception.
        try:
            for epoch in range(1, self.args.max_epoch + 1):
                # NOTE(review): the scheduler is stepped before any optimizer
                # step of the epoch; on PyTorch >= 1.1 this order skips the
                # first LR value -- confirm against the target version.
                self.lr_scheduler.step()
                # Set the model to train mode
                self.model.train()
                # Averagers recording training losses and accuracies
                train_loss_averager = Averager()
                train_acc_averager = Averager()

                # Labels for the query part during meta-train updates
                label = torch.arange(self.args.way).repeat(
                    self.args.train_query).type(long_type)

                p = self.args.shot * self.args.way

                # Using tqdm to read samples from train loader
                tqdm_gen = tqdm.tqdm(self.train_loader)
                for batch in tqdm_gen:
                    global_count += 1
                    if use_cuda:
                        data, targ = [t.cuda() for t in batch]
                    else:
                        data, targ = batch
                    data_shot, data_query = data[:p], data[p:]

                    # Output logits for model
                    if self.args.cross_att:
                        label_one_hot = self.one_hot(label).to(label.device)
                        ytest, cls_scores, logits = self.model(
                            (data_shot, label_shot, data_query),
                            ytest=label_one_hot)
                        pids = label_shot
                        loss = self.crossAttLoss(ytest, cls_scores, label,
                                                 pids)
                        logits = logits[0]
                    else:
                        logits = self.model(
                            (data_shot, label_shot, data_query))
                        # Calculate meta-train loss
                        loss = F.cross_entropy(logits, label)

                    if self.args.distill_id:
                        teachLogits = self.teacher(
                            (data_shot, label_shot, data_query))
                        temp = self.args.kl_temp
                        kl = F.kl_div(F.log_softmax(logits / temp, dim=1),
                                      F.softmax(teachLogits / temp, dim=1),
                                      reduction="batchmean")
                        # KL term is multiplied by temp**2 and blended with
                        # the task loss using kl_interp as the mixing weight
                        loss = (kl * self.args.kl_interp * temp * temp +
                                loss * (1 - self.args.kl_interp))

                    acc = count_acc(logits, label)
                    # Write the tensorboardX records
                    writer.add_scalar('data/loss', float(loss), global_count)
                    writer.add_scalar('data/acc', float(acc), global_count)
                    # Print loss and accuracy for this step
                    tqdm_gen.set_description(
                        'Epoch {}, Loss={:.4f} Acc={:.4f}'.format(
                            epoch, loss.item(), acc))

                    # Add loss and accuracy for the averagers
                    train_loss_averager.add(loss.item())
                    train_acc_averager.add(acc)

                    # Loss backwards and optimizer updates
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    if self.args.hard_tasks:
                        if len(worstClasses) == self.args.way:
                            inds = self.train_sampler.hardBatch(worstClasses)
                            # Build the hard episode from the sampled indices
                            # and train on it. The previous code built this
                            # list but then re-sliced the regular ``data``,
                            # so the hard batch was never used.
                            # Assumes ``self.trainset[idx][0]`` is a tensor or
                            # array, as the default DataLoader collation
                            # requires -- TODO confirm.
                            hard_data = torch.stack([
                                torch.as_tensor(self.trainset[idx][0])
                                for idx in inds
                            ])
                            if use_cuda:
                                hard_data = hard_data.cuda()
                            data_shot = hard_data[:p]
                            data_query = hard_data[p:]
                            logits = self.model(
                                (data_shot, label_shot, data_query))
                            loss = F.cross_entropy(logits, label)
                            self.optimizer.zero_grad()
                            loss.backward()
                            self.optimizer.step()
                            worstClasses = []
                        else:
                            # Per-class accuracy: correct predictions viewed
                            # as (train_query, way) and averaged over dim 0
                            error_mat = (logits.argmax(dim=1) == label).view(
                                self.args.train_query, self.args.way)
                            worst = error_mat.float().mean(dim=0).argmin()
                            # Look up the dataset label at that position
                            worst_trueInd = targ[worst]
                            worstClasses.append(worst_trueInd)

                # Collapse the averagers to their current values
                train_loss_averager = train_loss_averager.item()
                train_acc_averager = train_acc_averager.item()

                # Start validation for this epoch, set model to eval mode
                self.model.eval()

                # Averagers recording validation losses and accuracies
                val_loss_averager = Averager()
                val_acc_averager = Averager()

                # Labels for the query part during meta-val for this epoch
                label = torch.arange(self.args.way).repeat(
                    self.args.val_query).type(long_type)

                # Print previous information
                if epoch % 10 == 0:
                    print('Best Epoch {}, Best Val Acc={:.4f}'.format(
                        trlog['max_acc_epoch'], trlog['max_acc']))

                # Run meta-validation
                for batch in self.val_loader:
                    if use_cuda:
                        data, _ = [t.cuda() for t in batch]
                    else:
                        data = batch[0]
                    data_shot, data_query = data[:p], data[p:]

                    if self.args.cross_att:
                        label_one_hot = self.one_hot(label).to(label.device)
                        ytest, cls_scores, logits = self.model(
                            (data_shot, label_shot, data_query),
                            ytest=label_one_hot)
                        pids = label_shot
                        loss = self.crossAttLoss(ytest, cls_scores, label,
                                                 pids)
                        logits = logits[0]
                    else:
                        logits = self.model(
                            (data_shot, label_shot, data_query))
                        loss = F.cross_entropy(logits, label)
                    acc = count_acc(logits, label)

                    val_loss_averager.add(loss.item())
                    val_acc_averager.add(acc)

                # Collapse validation averagers
                val_loss_averager = val_loss_averager.item()
                val_acc_averager = val_acc_averager.item()
                # Write the tensorboardX records
                writer.add_scalar('data/val_loss', float(val_loss_averager),
                                  epoch)
                writer.add_scalar('data/val_acc', float(val_acc_averager),
                                  epoch)
                # Print loss and accuracy for this epoch
                print('Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(
                    epoch, val_loss_averager, val_acc_averager))

                # Update best saved model
                if val_acc_averager > trlog['max_acc']:
                    trlog['max_acc'] = val_acc_averager
                    trlog['max_acc_epoch'] = epoch
                    self.save_model('max_acc')

                # Update the logs
                trlog['train_loss'].append(train_loss_averager)
                trlog['train_acc'].append(train_acc_averager)
                trlog['val_loss'].append(val_loss_averager)
                trlog['val_acc'].append(val_acc_averager)

                # Save log
                torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

                if epoch % 10 == 0:
                    print('Running Time: {}, Estimated Time: {}'.format(
                        timer.measure(),
                        timer.measure(epoch / self.args.max_epoch)))

                trial.report(val_acc_averager, epoch)
        finally:
            writer.close()
Пример #6
0
    def eval(self, num_episodes=600):
        """The function for the meta-eval phase.

        Loads the saved training log and the model weights, runs
        ``num_episodes`` test episodes and prints the accuracy statistics.

        Weights come from ``self.args.eval_weights`` when it is set, otherwise
        from the ``max_acc.pth`` checkpoint under ``self.args.save_path``.

        Args:
            num_episodes: Number of test episodes. It is passed to
                ``CategoriesSampler`` and also sizes the per-episode accuracy
                record, so the two can never disagree. Defaults to 600, the
                value that used to be hard-coded in both places.
        """
        # Load the logs
        trlog = torch.load(osp.join(self.args.save_path, 'trlog'))

        # Single source of truth for where tensors live during evaluation
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load meta-test set
        test_set = Dataset('test', self.args)
        sampler = CategoriesSampler(test_set.label, num_episodes,
                                    self.args.way,
                                    self.args.shot + self.args.val_query)
        loader = DataLoader(test_set,
                            batch_sampler=sampler,
                            num_workers=8,
                            pin_memory=True)

        # Set test accuracy recorder (one slot per episode)
        test_acc_record = np.zeros((num_episodes, ))

        # Load model for meta-test phase
        if self.args.eval_weights is not None:
            weights_path = self.args.eval_weights
        else:
            weights_path = osp.join(self.args.save_path, 'max_acc' + '.pth')
        # map_location='cpu' lets a checkpoint written on a GPU machine be
        # read on a CPU-only one; load_state_dict then copies the values into
        # the model's own parameters wherever they live.
        checkpoint = torch.load(weights_path, map_location='cpu')
        self.model.load_state_dict(checkpoint['params'])
        # Set model to eval mode
        self.model.eval()

        # Set accuracy averager
        ave_acc = Averager()

        # Generate labels: class ids 0..way-1 tiled once per query / per shot
        label = torch.arange(self.args.way, dtype=torch.long,
                             device=device).repeat(self.args.val_query)
        label_shot = torch.arange(self.args.way, dtype=torch.long,
                                  device=device).repeat(self.args.shot)

        # The first way*shot samples of each batch form the support set
        k = self.args.way * self.args.shot

        # Start meta-test
        for i, batch in enumerate(loader, 1):
            # Only the images are needed; the dataset labels in batch[1] are
            # replaced by the episode-local labels built above.
            data = batch[0].to(device)
            data_shot, data_query = data[:k], data[k:]
            logits = self.model((data_shot, label_shot, data_query))
            acc = count_acc(logits, label)
            ave_acc.add(acc)
            test_acc_record[i - 1] = acc
            if i % 100 == 0:
                print('batch {}: {:.2f}({:.2f})'.format(
                    i,
                    ave_acc.item() * 100, acc * 100))

        # Calculate the confidence interval, update the logs
        m, pm = compute_confidence_interval(test_acc_record)
        print('Val Best Epoch {}, Acc {:.4f}, Test Acc {:.4f}'.format(
            trlog['max_acc_epoch'], trlog['max_acc'], ave_acc.item()))
        print('Test Acc {:.4f} + {:.4f}'.format(m, pm))
Пример #7
0
    def train(self):
        """The function for the meta-train phase.

        For every epoch in ``1..args.max_epoch`` this:

        1. runs one optimisation step per batch of ``self.train_loader``,
        2. steps the learning-rate scheduler,
        3. evaluates on ``self.val_loader`` with the model in eval mode,
        4. saves the model as ``max_acc`` when validation accuracy improves,
           and as ``epoch<N>`` every 10 epochs,
        5. appends the epoch averages to the log dict and saves it to
           ``<args.save_path>/trlog``.

        Per-step and per-epoch scalars are also written to the summary writer,
        which is closed even when training is interrupted by an exception.
        """
        # Set the meta-train log
        trlog = {
            'args': vars(self.args),
            'train_loss': [],
            'val_loss': [],
            'train_acc': [],
            'val_acc': [],
            'max_acc': 0.0,
            'max_acc_epoch': 0,
        }

        # Set the timer
        timer = Timer()
        # Set global count to zero
        global_count = 0

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Episode-local labels (class ids 0..way-1 tiled per shot / per query).
        # They depend only on args, so they are built once, not every epoch.
        label_shot = torch.arange(self.args.way, dtype=torch.long,
                                  device=device).repeat(self.args.shot)
        train_label = torch.arange(self.args.way, dtype=torch.long,
                                   device=device).repeat(self.args.train_query)
        val_label = torch.arange(self.args.way, dtype=torch.long,
                                 device=device).repeat(self.args.val_query)

        # The first shot*way samples of each batch form the support set
        p = self.args.shot * self.args.way

        # Set tensorboardX
        writer = SummaryWriter(comment=self.args.save_path)
        try:
            # Start meta-train
            for epoch in range(1, self.args.max_epoch + 1):
                # Set the model to train mode
                self.model.train()
                # Set averager classes to record training losses and accuracies
                train_loss_averager = Averager()
                train_acc_averager = Averager()

                # Using tqdm to read samples from train loader
                tqdm_gen = tqdm.tqdm(self.train_loader)
                for batch in tqdm_gen:
                    # Update global count number
                    global_count += 1
                    # Only the images are used; dataset labels are replaced
                    # by the episode-local labels built above.
                    data = batch[0].to(device)
                    data_shot, data_query = data[:p], data[p:]
                    # Output logits for model
                    logits = self.model((data_shot, label_shot, data_query))
                    # Calculate meta-train loss
                    loss = F.cross_entropy(logits, train_label)
                    # Calculate meta-train accuracy
                    acc = count_acc(logits, train_label)
                    # Write the tensorboardX records
                    writer.add_scalar('data/loss', float(loss), global_count)
                    writer.add_scalar('data/acc', float(acc), global_count)
                    # Print loss and accuracy for this step
                    tqdm_gen.set_description(
                        'Epoch {}, Loss={:.4f} Acc={:.4f}'.format(
                            epoch, loss.item(), acc))

                    # Add loss and accuracy for the averagers
                    train_loss_averager.add(loss.item())
                    train_acc_averager.add(acc)

                    # Loss backwards and optimizer updates
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                # Update learning rate
                self.lr_scheduler.step()

                # Collapse the averagers into this epoch's mean values
                train_loss = train_loss_averager.item()
                train_acc = train_acc_averager.item()

                # Start validation for this epoch, set model to eval mode
                self.model.eval()

                # Set averager classes to record validation losses and accuracies
                val_loss_averager = Averager()
                val_acc_averager = Averager()

                # Print previous information
                if epoch % 10 == 0:
                    print('Best Epoch {}, Best Val Acc={:.4f}'.format(
                        trlog['max_acc_epoch'], trlog['max_acc']))

                # Run meta-validation
                for batch in self.val_loader:
                    data = batch[0].to(device)
                    data_shot, data_query = data[:p], data[p:]
                    logits = self.model((data_shot, label_shot, data_query))
                    loss = F.cross_entropy(logits, val_label)
                    acc = count_acc(logits, val_label)

                    val_loss_averager.add(loss.item())
                    val_acc_averager.add(acc)

                # Update validation averagers
                val_loss = val_loss_averager.item()
                val_acc = val_acc_averager.item()
                # Write the tensorboardX records
                writer.add_scalar('data/val_loss', float(val_loss), epoch)
                writer.add_scalar('data/val_acc', float(val_acc), epoch)
                # Print loss and accuracy for this epoch
                print('Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(
                    epoch, val_loss, val_acc))

                # Update best saved model
                if val_acc > trlog['max_acc']:
                    trlog['max_acc'] = val_acc
                    trlog['max_acc_epoch'] = epoch
                    self.save_model('max_acc')
                # Save model every 10 epochs
                if epoch % 10 == 0:
                    self.save_model('epoch' + str(epoch))

                # Update the logs
                trlog['train_loss'].append(train_loss)
                trlog['train_acc'].append(train_acc)
                trlog['val_loss'].append(val_loss)
                trlog['val_acc'].append(val_acc)

                # Save log
                torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

                if epoch % 10 == 0:
                    print('Running Time: {}, Estimated Time: {}'.format(
                        timer.measure(),
                        timer.measure(epoch / self.args.max_epoch)))
        finally:
            # Close the writer on every exit path so records written before a
            # crash or Ctrl-C are not left in an unclosed writer.
            writer.close()
Пример #8
0
    def eval(self, num_episodes=20):
        """The function for the meta-eval phase.

        Loads the saved training log and the model weights, runs
        ``num_episodes`` test episodes, and prints accuracy, macro F1 and
        macro ROC-AUC together with the values returned by
        ``compute_confidence_interval`` for each metric.

        Weights come from ``self.args.eval_weights`` when it is set, otherwise
        from the ``max_acc.pth`` checkpoint under ``self.args.save_path``.

        Note that ROC-AUC is computed from the binarized arg-max predictions,
        not from the raw logits or probabilities.

        Args:
            num_episodes: Number of test episodes. It is passed to
                ``CategoriesSampler`` and also sizes all three per-episode
                records. Defaults to 20, the value that used to be hard-coded
                in each of those places.
        """
        # Load the logs
        trlog = torch.load(osp.join(self.args.save_path, 'trlog'))

        # Single source of truth for where tensors live during evaluation
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load meta-test set
        test_set = Dataset('test', self.args)
        sampler = CategoriesSampler(test_set.label, num_episodes,
                                    self.args.way,
                                    self.args.shot + self.args.val_query)
        loader = DataLoader(test_set,
                            batch_sampler=sampler,
                            num_workers=8,
                            pin_memory=True)

        # Set test metric recorders (one slot per episode)
        test_acc_record = np.zeros((num_episodes, ))
        test_f1_record = np.zeros((num_episodes, ))
        test_auc_record = np.zeros((num_episodes, ))

        # Load model for meta-test phase
        if self.args.eval_weights is not None:
            weights_path = self.args.eval_weights
        else:
            weights_path = osp.join(self.args.save_path, 'max_acc' + '.pth')
        # map_location='cpu' lets a checkpoint written on a GPU machine be
        # read on a CPU-only one; load_state_dict then copies the values into
        # the model's own parameters wherever they live.
        checkpoint = torch.load(weights_path, map_location='cpu')
        self.model.load_state_dict(checkpoint['params'])
        # Set model to eval mode
        self.model.eval()

        # Set accuracy averager
        ave_acc = Averager()

        # Generate labels: class ids 0..way-1 tiled once per query / per shot
        label = torch.arange(self.args.way, dtype=torch.long,
                             device=device).repeat(self.args.val_query)
        label_shot = torch.arange(self.args.way, dtype=torch.long,
                                  device=device).repeat(self.args.shot)

        # Ground-truth query labels as numpy; identical for every episode
        Y = label.cpu().numpy()
        # The targets never change, so fit the binarizer and transform the
        # targets once instead of once per episode.
        binarizer = LabelBinarizer()
        binarizer.fit(Y)
        Y_binarized = binarizer.transform(Y)

        # The first way*shot samples of each batch form the support set
        k = self.args.way * self.args.shot

        # Start meta-test
        for i, batch in enumerate(loader, 1):
            # Only the images are needed; the dataset labels in batch[1] are
            # replaced by the episode-local labels built above.
            data = batch[0].to(device)
            data_shot, data_query = data[:k], data[k:]
            logits = self.model((data_shot, label_shot, data_query))
            acc = count_acc(logits, label)
            predicted = np.argmax(logits.detach().cpu().numpy(), axis=1)
            f1 = f1_score(Y, predicted, average='macro')
            auc = roc_auc_score(Y_binarized,
                                binarizer.transform(predicted),
                                average='macro')
            ave_acc.add(acc)
            test_acc_record[i - 1] = acc
            test_f1_record[i - 1] = f1
            test_auc_record[i - 1] = auc

            # Progress is reported every 100 episodes, so nothing is printed
            # here when num_episodes is below 100 (as with the default).
            if i % 100 == 0:
                print('batch {}: {:.2f}({:.2f})'.format(
                    i,
                    ave_acc.item() * 100, acc * 100))

        # Calculate the confidence interval, update the logs
        m, pm = compute_confidence_interval(test_acc_record)
        f1_m, f1_pm = compute_confidence_interval(test_f1_record)
        auc_m, auc_pm = compute_confidence_interval(test_auc_record)

        print('Val Best Epoch {}, Acc {:.4f}, Test Acc {:.4f}'.format(
            trlog['max_acc_epoch'], trlog['max_acc'], ave_acc.item()))
        print('Test Acc {:.4f} + {:.4f}'.format(m, pm))
        print('Test f1 {:.4f} + {:.4f}'.format(f1_m, f1_pm))
        print('Test auc {:.4f} + {:.4f}'.format(auc_m, auc_pm))