Beispiel #1
0
class MetaTrainer(object):
    """The class that contains the code for the meta-train phase and meta-eval phase."""
    def __init__(self, args):
        """Prepare checkpoint folders, episodic data loaders, the MTL model,
        its optimizer/scheduler, and load the pretrained encoder weights.

        Args:
          args: parsed command-line arguments; must carry the dataset/model
            settings and hyperparameters referenced below (shot, way,
            train_query, val_query, meta_lr1, meta_lr2, init_weights, ...).
            ``args.save_path`` is set here as a side effect.
        """
        # Set the folder to save the records and checkpoints
        log_base_dir = './logs/'
        if not osp.exists(log_base_dir):
            os.mkdir(log_base_dir)
        meta_base_dir = osp.join(log_base_dir, 'meta')
        if not osp.exists(meta_base_dir):
            os.mkdir(meta_base_dir)
        # Encode the hyperparameters into the directory name so each run gets
        # a unique, self-describing checkpoint folder
        save_path1 = '_'.join([args.dataset, args.model_type, 'MTL'])
        save_path2 = 'shot' + str(args.shot) + '_way' + str(args.way) + '_query' + str(args.train_query) + \
            '_step' + str(args.step_size) + '_gamma' + str(args.gamma) + '_lr1' + str(args.meta_lr1) + '_lr2' + str(args.meta_lr2) + \
            '_batch' + str(args.num_batch) + '_maxepoch' + str(args.max_epoch) + \
            '_baselr' + str(args.base_lr) + '_updatestep' + str(args.update_step) + \
            '_stepsize' + str(args.step_size) + '_' + args.meta_label
        args.save_path = meta_base_dir + '/' + save_path1 + '_' + save_path2
        ensure_path(args.save_path)

        # Set args to be shareable in the class
        self.args = args

        # Load meta-train set; each sampled episode holds
        # way * (shot + train_query) images
        self.trainset = Dataset('train', self.args)
        self.train_sampler = CategoriesSampler(
            self.trainset.label, self.args.num_batch, self.args.way,
            self.args.shot + self.args.train_query)
        self.train_loader = DataLoader(dataset=self.trainset,
                                       batch_sampler=self.train_sampler,
                                       num_workers=8,
                                       pin_memory=True)

        # Load meta-val set; fixed at 600 episodes per validation pass
        self.valset = Dataset('val', self.args)
        self.val_sampler = CategoriesSampler(
            self.valset.label, 600, self.args.way,
            self.args.shot + self.args.val_query)
        self.val_loader = DataLoader(dataset=self.valset,
                                     batch_sampler=self.val_sampler,
                                     num_workers=8,
                                     pin_memory=True)

        # Build meta-transfer learning model
        self.model = MtlLearner(self.args)

        # Set optimizer: the (partially frozen) encoder trains at meta_lr1,
        # the base learner at meta_lr2
        self.optimizer = torch.optim.Adam([{'params': filter(lambda p: p.requires_grad, self.model.encoder.parameters())}, \
            {'params': self.model.base_learner.parameters(), 'lr': self.args.meta_lr2}], lr=self.args.meta_lr1)
        # Set learning rate scheduler (decay by gamma every step_size epochs)
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=self.args.step_size,
            gamma=self.args.gamma)

        # load pretrained model without FC classifier: either from an
        # explicit weights file, or from the pretrain phase's default
        # save location (reconstructed from the pretrain hyperparameters)
        self.model_dict = self.model.state_dict()
        if self.args.init_weights is not None:
            pretrained_dict = torch.load(self.args.init_weights)['params']
        else:
            pre_base_dir = osp.join(log_base_dir, 'pre')
            pre_save_path1 = '_'.join([args.dataset, args.model_type])
            pre_save_path2 = 'batchsize' + str(args.pre_batch_size) + '_lr' + str(args.pre_lr) + '_gamma' + str(args.pre_gamma) + '_step' + \
                str(args.pre_step_size) + '_maxepoch' + str(args.pre_max_epoch)
            pre_save_path = pre_base_dir + '/' + pre_save_path1 + '_' + pre_save_path2
            pretrained_dict = torch.load(osp.join(pre_save_path,
                                                  'max_acc.pth'))['params']
        # The checkpoint stores bare encoder weights; prefix the keys so they
        # line up with MtlLearner's 'encoder.' submodule naming
        pretrained_dict = {
            'encoder.' + k: v
            for k, v in pretrained_dict.items()
        }
        # Keep only keys that exist in this model (drops the pretrain FC head)
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in self.model_dict
        }
        # Debug output: shows which weights were actually transferred
        print(pretrained_dict.keys())
        self.model_dict.update(pretrained_dict)
        self.model.load_state_dict(self.model_dict)

        # Set model to GPU
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()

    def save_model(self, name):
        """The function to save checkpoints.
        Args:
          name: the name for saved checkpoint
        """
        torch.save(dict(params=self.model.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def train(self):
        """The function for the meta-train phase.

        Runs args.max_epoch epochs of episodic training, validates after
        every epoch on 600 episodes, keeps the best model as 'max_acc.pth',
        and logs everything to tensorboardX plus a pickled 'trlog' dict.
        """

        # Set the meta-train log
        trlog = {}
        trlog['args'] = vars(self.args)
        trlog['train_loss'] = []
        trlog['val_loss'] = []
        trlog['train_acc'] = []
        trlog['val_acc'] = []
        trlog['max_acc'] = 0.0
        trlog['max_acc_epoch'] = 0

        # Set the timer
        timer = Timer()
        # Set global count to zero (step index across all epochs for tensorboard)
        global_count = 0
        # Set tensorboardX
        writer = SummaryWriter(comment=self.args.save_path)

        # Generate the labels for train set of the episodes
        # (pattern 0..way-1 repeated shot times — assumes the sampler orders
        # episode images the same way; TODO confirm against CategoriesSampler)
        label_shot = torch.arange(self.args.way).repeat(self.args.shot)
        if torch.cuda.is_available():
            label_shot = label_shot.type(torch.cuda.LongTensor)
        else:
            label_shot = label_shot.type(torch.LongTensor)

        # Start meta-train
        for epoch in range(1, self.args.max_epoch + 1):
            # Update learning rate
            # NOTE(review): stepping the scheduler before the optimizer is the
            # pre-1.1 PyTorch ordering; on newer PyTorch this warns and shifts
            # the decay schedule by one epoch — confirm the intended version
            self.lr_scheduler.step()
            # Set the model to train mode
            self.model.train()
            # Set averager classes to record training losses and accuracies
            train_loss_averager = Averager()
            train_acc_averager = Averager()

            # Generate the labels for test set of the episodes during meta-train updates
            label = torch.arange(self.args.way).repeat(self.args.train_query)
            if torch.cuda.is_available():
                label = label.type(torch.cuda.LongTensor)
            else:
                label = label.type(torch.LongTensor)

            # Using tqdm to read samples from train loader
            tqdm_gen = tqdm.tqdm(self.train_loader)
            for i, batch in enumerate(tqdm_gen, 1):
                # Update global count number
                global_count = global_count + 1
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                # First shot*way images are the support set, the rest queries
                p = self.args.shot * self.args.way
                data_shot, data_query = data[:p], data[p:]
                # Output logits for model
                logits = self.model((data_shot, label_shot, data_query))
                # Calculate meta-train loss
                loss = F.cross_entropy(logits, label)
                # Calculate meta-train accuracy
                acc = count_acc(logits, label)
                # Write the tensorboardX records
                writer.add_scalar('data/loss', float(loss), global_count)
                writer.add_scalar('data/acc', float(acc), global_count)
                # Print loss and accuracy for this step
                tqdm_gen.set_description(
                    'Epoch {}, Loss={:.4f} Acc={:.4f}'.format(
                        epoch, loss.item(), acc))

                # Add loss and accuracy for the averagers
                train_loss_averager.add(loss.item())
                train_acc_averager.add(acc)

                # Loss backwards and optimizer updates
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Update the averagers (rebinds the names to plain floats)
            train_loss_averager = train_loss_averager.item()
            train_acc_averager = train_acc_averager.item()

            # Start validation for this epoch, set model to eval mode
            self.model.eval()

            # Set averager classes to record validation losses and accuracies
            val_loss_averager = Averager()
            val_acc_averager = Averager()

            # Generate the labels for test set of the episodes during meta-val for this epoch
            label = torch.arange(self.args.way).repeat(self.args.val_query)
            if torch.cuda.is_available():
                label = label.type(torch.cuda.LongTensor)
            else:
                label = label.type(torch.LongTensor)

            # Print previous information
            if epoch % 10 == 0:
                print('Best Epoch {}, Best Val Acc={:.4f}'.format(
                    trlog['max_acc_epoch'], trlog['max_acc']))
            # Run meta-validation
            # NOTE(review): no torch.no_grad() here — validation builds graphs
            # it never backprops through; confirm whether that is intentional
            for i, batch in enumerate(self.val_loader, 1):
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                p = self.args.shot * self.args.way
                data_shot, data_query = data[:p], data[p:]
                logits = self.model((data_shot, label_shot, data_query))
                loss = F.cross_entropy(logits, label)
                acc = count_acc(logits, label)

                val_loss_averager.add(loss.item())
                val_acc_averager.add(acc)

            # Update validation averagers
            val_loss_averager = val_loss_averager.item()
            val_acc_averager = val_acc_averager.item()
            # Write the tensorboardX records
            writer.add_scalar('data/val_loss', float(val_loss_averager), epoch)
            writer.add_scalar('data/val_acc', float(val_acc_averager), epoch)
            # Print loss and accuracy for this epoch
            print('Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(
                epoch, val_loss_averager, val_acc_averager))

            # Update best saved model
            if val_acc_averager > trlog['max_acc']:
                trlog['max_acc'] = val_acc_averager
                trlog['max_acc_epoch'] = epoch
                self.save_model('max_acc')
            # Save model every 10 epochs
            if epoch % 10 == 0:
                self.save_model('epoch' + str(epoch))

            # Update the logs
            trlog['train_loss'].append(train_loss_averager)
            trlog['train_acc'].append(train_acc_averager)
            trlog['val_loss'].append(val_loss_averager)
            trlog['val_acc'].append(val_acc_averager)

            # Save log
            torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

            if epoch % 10 == 0:
                print('Running Time: {}, Estimated Time: {}'.format(
                    timer.measure(),
                    timer.measure(epoch / self.args.max_epoch)))

        writer.close()

    def eval(self):
        """The function for the meta-eval phase.

        Loads the best (or explicitly given) checkpoint, evaluates on 600
        test episodes, and prints mean accuracy with a confidence interval.
        """
        # Load the logs
        trlog = torch.load(osp.join(self.args.save_path, 'trlog'))

        # Load meta-test set (600 episodes, val_query queries per class)
        test_set = Dataset('test', self.args)
        sampler = CategoriesSampler(test_set.label, 600, self.args.way,
                                    self.args.shot + self.args.val_query)
        loader = DataLoader(test_set,
                            batch_sampler=sampler,
                            num_workers=8,
                            pin_memory=True)

        # Set test accuracy recorder (one accuracy per episode)
        test_acc_record = np.zeros((600, ))

        # Load model for meta-test phase
        if self.args.eval_weights is not None:
            self.model.load_state_dict(
                torch.load(self.args.eval_weights)['params'])
        else:
            self.model.load_state_dict(
                torch.load(osp.join(self.args.save_path,
                                    'max_acc' + '.pth'))['params'])
        # Set model to eval mode
        self.model.eval()

        # Set accuracy averager
        ave_acc = Averager()

        # Generate labels (query labels, then support labels)
        label = torch.arange(self.args.way).repeat(self.args.val_query)
        if torch.cuda.is_available():
            label = label.type(torch.cuda.LongTensor)
        else:
            label = label.type(torch.LongTensor)
        label_shot = torch.arange(self.args.way).repeat(self.args.shot)
        if torch.cuda.is_available():
            label_shot = label_shot.type(torch.cuda.LongTensor)
        else:
            label_shot = label_shot.type(torch.LongTensor)

        # Start meta-test
        for i, batch in enumerate(loader, 1):
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
            k = self.args.way * self.args.shot
            data_shot, data_query = data[:k], data[k:]
            logits = self.model((data_shot, label_shot, data_query))
            acc = count_acc(logits, label)
            ave_acc.add(acc)
            test_acc_record[i - 1] = acc
            if i % 100 == 0:
                print('batch {}: {:.2f}({:.2f})'.format(
                    i,
                    ave_acc.item() * 100, acc * 100))

        # Calculate the confidence interval, update the logs
        m, pm = compute_confidence_interval(test_acc_record)
        print('Val Best Epoch {}, Acc {:.4f}, Test Acc {:.4f}'.format(
            trlog['max_acc_epoch'], trlog['max_acc'], ave_acc.item()))
        print('Test Acc {:.4f} + {:.4f}'.format(m, pm))
Beispiel #2
0
class PreTrainer(object):
    """The class that contains the code for the pretrain phase."""
    def __init__(self, args):
        """Prepare checkpoint folders, loaders, the pretrain model and its
        SGD optimizer/scheduler.

        Args:
          args: parsed command-line arguments; must carry the pretrain
            hyperparameters (pre_batch_size, pre_lr, pre_gamma,
            pre_step_size, pre_max_epoch, ...) plus the episodic settings
            (way, shot, val_query) used for meta-validation.
            ``args.save_path`` is set here as a side effect.
        """
        # Set the folder to save the records and checkpoints
        log_base_dir = './logs/'
        if not osp.exists(log_base_dir):
            os.mkdir(log_base_dir)
        pre_base_dir = osp.join(log_base_dir, 'pre')
        if not osp.exists(pre_base_dir):
            os.mkdir(pre_base_dir)
        # Encode the hyperparameters into the directory name so each run
        # gets a unique, self-describing checkpoint folder
        save_path1 = '_'.join([args.dataset, args.model_type])
        save_path2 = 'batchsize' + str(args.pre_batch_size) + '_lr' + str(args.pre_lr) + '_gamma' + str(args.pre_gamma) + '_step' + \
            str(args.pre_step_size) + '_maxepoch' + str(args.pre_max_epoch)
        args.save_path = pre_base_dir + '/' + save_path1 + '_' + save_path2
        ensure_path(args.save_path)

        # Set args to be shareable in the class
        self.args = args

        # Load pretrain set: plain mini-batch classification with augmentation
        self.trainset = Dataset('train', self.args, train_aug=True)
        self.train_loader = DataLoader(dataset=self.trainset,
                                       batch_size=args.pre_batch_size,
                                       shuffle=True,
                                       num_workers=8,
                                       pin_memory=True)

        # Load meta-val set: 600 episodes per validation pass
        self.valset = Dataset('val', self.args)
        self.val_sampler = CategoriesSampler(
            self.valset.label, 600, self.args.way,
            self.args.shot + self.args.val_query)
        self.val_loader = DataLoader(dataset=self.valset,
                                     batch_sampler=self.val_sampler,
                                     num_workers=8,
                                     pin_memory=True)

        # Set pretrain class number (size of the pretrain FC head)
        num_class_pretrain = self.trainset.num_class

        # Build pretrain model
        self.model = MtlLearner(self.args,
                                mode='pre',
                                num_cls=num_class_pretrain)

        # Set optimizer: encoder and FC head share the same learning rate
        self.optimizer = torch.optim.SGD([{'params': self.model.encoder.parameters(), 'lr': self.args.pre_lr}, \
            {'params': self.model.pre_fc.parameters(), 'lr': self.args.pre_lr}], \
                momentum=self.args.pre_custom_momentum, nesterov=True, weight_decay=self.args.pre_custom_weight_decay)
        # Set learning rate scheduler (decay by pre_gamma every pre_step_size epochs)
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.args.pre_step_size, \
            gamma=self.args.pre_gamma)

        # Set model to GPU
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()

    def save_model(self, name):
        """The function to save checkpoints.

        Only the encoder weights are saved — the pretrain FC head is
        discarded, since the meta phase loads encoder weights alone.

        Args:
          name: the name for saved checkpoint
        """
        torch.save(dict(params=self.model.encoder.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def train(self):
        """The function for the pre-train phase.

        Runs args.pre_max_epoch epochs of standard classification training,
        meta-validates after every epoch on 600 episodes, keeps the best
        encoder as 'max_acc.pth', and logs to tensorboardX plus a pickled
        'trlog' dict.
        """

        # Set the pretrain log
        trlog = {}
        trlog['args'] = vars(self.args)
        trlog['train_loss'] = []
        trlog['val_loss'] = []
        trlog['train_acc'] = []
        trlog['val_acc'] = []
        trlog['max_acc'] = 0.0
        trlog['max_acc_epoch'] = 0

        # Set the timer
        timer = Timer()
        # Set global count to zero (step index across all epochs for tensorboard)
        global_count = 0
        # Set tensorboardX
        writer = SummaryWriter(comment=self.args.save_path)

        # Start pretrain
        for epoch in range(1, self.args.pre_max_epoch + 1):
            # Update learning rate
            self.lr_scheduler.step()
            # Set the model to train mode
            self.model.train()
            self.model.mode = 'pre'
            # Set averager classes to record training losses and accuracies
            train_loss_averager = Averager()
            train_acc_averager = Averager()

            # Using tqdm to read samples from train loader
            tqdm_gen = tqdm.tqdm(self.train_loader)
            for i, batch in enumerate(tqdm_gen, 1):
                # Update global count number
                global_count = global_count + 1
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                label = batch[1]
                if torch.cuda.is_available():
                    label = label.type(torch.cuda.LongTensor)
                else:
                    label = label.type(torch.LongTensor)
                # Output logits for model
                logits = self.model(data)
                # Calculate train loss
                loss = F.cross_entropy(logits, label)
                # Calculate train accuracy
                acc = count_acc(logits, label)
                # Write the tensorboardX records
                writer.add_scalar('data/loss', float(loss), global_count)
                writer.add_scalar('data/acc', float(acc), global_count)
                # Print loss and accuracy for this step
                tqdm_gen.set_description(
                    'Epoch {}, Loss={:.4f} Acc={:.4f}'.format(
                        epoch, loss.item(), acc))

                # Add loss and accuracy for the averagers
                train_loss_averager.add(loss.item())
                train_acc_averager.add(acc)

                # Loss backwards and optimizer updates
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Update the averagers (rebinds the names to plain floats)
            train_loss_averager = train_loss_averager.item()
            train_acc_averager = train_acc_averager.item()

            # Start validation for this epoch, set model to eval mode
            self.model.eval()
            self.model.mode = 'preval'

            # Set averager classes to record validation losses and accuracies
            val_loss_averager = Averager()
            val_acc_averager = Averager()

            # Generate the labels for test (query labels, then support labels)
            label = torch.arange(self.args.way).repeat(self.args.val_query)
            if torch.cuda.is_available():
                label = label.type(torch.cuda.LongTensor)
            else:
                label = label.type(torch.LongTensor)
            label_shot = torch.arange(self.args.way).repeat(self.args.shot)
            if torch.cuda.is_available():
                label_shot = label_shot.type(torch.cuda.LongTensor)
            else:
                label_shot = label_shot.type(torch.LongTensor)

            # Print previous information
            if epoch % 10 == 0:
                print('Best Epoch {}, Best Val acc={:.4f}'.format(
                    trlog['max_acc_epoch'], trlog['max_acc']))
            # Run meta-validation
            for i, batch in enumerate(self.val_loader, 1):
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                # First shot*way images are the support set, the rest queries
                p = self.args.shot * self.args.way
                data_shot, data_query = data[:p], data[p:]
                logits = self.model((data_shot, label_shot, data_query))
                loss = F.cross_entropy(logits, label)
                acc = count_acc(logits, label)
                val_loss_averager.add(loss.item())
                val_acc_averager.add(acc)

            # Update validation averagers
            val_loss_averager = val_loss_averager.item()
            val_acc_averager = val_acc_averager.item()
            # Write the tensorboardX records
            writer.add_scalar('data/val_loss', float(val_loss_averager), epoch)
            writer.add_scalar('data/val_acc', float(val_acc_averager), epoch)
            # Print loss and accuracy for this epoch
            print('Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(
                epoch, val_loss_averager, val_acc_averager))

            # Update best saved model
            if val_acc_averager > trlog['max_acc']:
                trlog['max_acc'] = val_acc_averager
                trlog['max_acc_epoch'] = epoch
                self.save_model('max_acc')
            # Save model every 10 epochs
            if epoch % 10 == 0:
                self.save_model('epoch' + str(epoch))

            # Update the logs
            trlog['train_loss'].append(train_loss_averager)
            trlog['train_acc'].append(train_acc_averager)
            trlog['val_loss'].append(val_loss_averager)
            trlog['val_acc'].append(val_acc_averager)

            # Save log
            torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

            if epoch % 10 == 0:
                # Bug fix: the ETA must divide by pre_max_epoch (this loop's
                # bound), not max_epoch (the meta-phase epoch count)
                print('Running Time: {}, Estimated Time: {}'.format(
                    timer.measure(),
                    timer.measure(epoch / self.args.pre_max_epoch)))
        writer.close()
Beispiel #3
0
class PreTrainer(object):
    """The class that contains the code for the pretrain phase (dataset is
    selected dynamically from ``args.dataset``)."""
    def __init__(self, args):
        """Prepare checkpoint folders, pick the dataset class, build loaders,
        the pretrain model and its SGD optimizer/scheduler.

        Args:
          args: parsed command-line arguments; must carry the pretrain
            hyperparameters (pre_batch_size, pre_lr, pre_gamma,
            pre_step_size, pre_max_epoch, ...) plus the episodic settings
            (way, shot, val_query). ``args.save_path`` is set here.

        Raises:
          ValueError: if ``args.dataset`` is not one of the supported names.
        """
        # Set the folder to save the records and checkpoints; encode the
        # hyperparameters into the directory name for a unique run folder
        log_base_dir = './logs/'
        if not osp.exists(log_base_dir):
            os.mkdir(log_base_dir)
        pre_base_dir = osp.join(log_base_dir, 'pre')
        if not osp.exists(pre_base_dir):
            os.mkdir(pre_base_dir)
        save_path1 = '_'.join([args.dataset, args.model_type])
        save_path2 = 'batchsize' + str(args.pre_batch_size) + '_lr' + str(args.pre_lr) + '_gamma' + str(args.pre_gamma) + '_step' + \
            str(args.pre_step_size) + '_maxepoch' + str(args.pre_max_epoch)
        args.save_path = pre_base_dir + '/' + save_path1 + '_' + save_path2
        ensure_path(args.save_path)

        # Set args to be shareable in the class
        self.args = args

        # Select the dataset class for the requested benchmark
        if self.args.dataset == 'MiniImageNet':
            from dataloader.mini_imagenet import MiniImageNet as Dataset
        elif self.args.dataset == 'TieredImageNet':
            from dataloader.tiered_imagenet import TieredImageNet as Dataset
        elif self.args.dataset == 'FC100':
            from dataloader.fewshotcifar import FewshotCifar as Dataset
        else:
            raise ValueError('Please set correct dataset.')

        # Load pretrain set: plain mini-batch classification with augmentation
        self.trainset = Dataset('train', self.args, train_aug=True)
        self.train_loader = DataLoader(dataset=self.trainset,
                                       batch_size=args.pre_batch_size,
                                       shuffle=True,
                                       num_workers=8,
                                       pin_memory=True)

        # Episodic validation set: 600 episodes per pass
        # NOTE(review): this variant validates on the 'test' split rather
        # than 'val' — confirm this is intended
        self.valset = Dataset('test', self.args)
        self.val_sampler = CategoriesSampler(
            self.valset.label, 600, self.args.way,
            self.args.shot + self.args.val_query)
        self.val_loader = DataLoader(dataset=self.valset,
                                     batch_sampler=self.val_sampler,
                                     num_workers=8,
                                     pin_memory=True)

        # Set pretrain class number (size of the pretrain FC head)
        num_class_pretrain = self.trainset.num_class

        # Build pretrain model
        self.model = MtlLearner(self.args,
                                mode='pre',
                                num_cls=num_class_pretrain)

        # Set optimizer: encoder and FC head share the same learning rate
        self.optimizer = torch.optim.SGD([{'params': self.model.encoder.parameters(), 'lr': self.args.pre_lr}, \
            {'params': self.model.pre_fc.parameters(), 'lr': self.args.pre_lr}], \
                momentum=self.args.pre_custom_momentum, nesterov=True, weight_decay=self.args.pre_custom_weight_decay)

        # Learning rate scheduler (decay by pre_gamma every pre_step_size epochs)
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.args.pre_step_size, \
            gamma=self.args.pre_gamma)

        # Set model to GPU
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()

    def save_model(self, name):
        """Save an encoder-only checkpoint (the pretrain FC head is
        discarded, since the meta phase loads encoder weights alone).

        Args:
          name: the name for saved checkpoint
        """
        torch.save(dict(params=self.model.encoder.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def train(self):
        """The function for the pre-train phase.

        Runs args.pre_max_epoch epochs of standard classification training,
        meta-validates after every epoch on 600 episodes, keeps the best
        encoder as 'max_acc.pth' (periodic and last-epoch checkpoints too),
        and logs to tensorboardX plus a pickled 'trlog' dict.
        """
        # Set the pretrain log
        trlog = {}
        trlog['args'] = vars(self.args)
        trlog['train_loss'] = []
        trlog['val_loss'] = []
        trlog['train_acc'] = []
        trlog['val_acc'] = []
        trlog['max_acc'] = 0.0
        trlog['max_acc_epoch'] = 0

        timer = Timer()
        # Step index across all epochs, used for tensorboard x-axis
        global_count = 0
        writer = SummaryWriter(comment=self.args.save_path)

        for epoch in range(1, self.args.pre_max_epoch + 1):
            # Update learning rate, then set the model to pretrain mode
            self.lr_scheduler.step()
            self.model.train()
            self.model.mode = 'pre'
            # Averagers for training loss (tl) and accuracy (ta)
            tl = Averager()
            ta = Averager()

            # Using tqdm to read samples from train loader
            tqdm_gen = tqdm.tqdm(self.train_loader)
            for i, batch in enumerate(tqdm_gen, 1):
                global_count = global_count + 1
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                label = batch[1]
                if torch.cuda.is_available():
                    label = label.type(torch.cuda.LongTensor)
                else:
                    label = label.type(torch.LongTensor)
                # Forward, loss, accuracy, tensorboard records
                logits = self.model(data)
                loss = F.cross_entropy(logits, label)
                acc = count_acc(logits, label)
                writer.add_scalar('data/loss', float(loss), global_count)
                writer.add_scalar('data/acc', float(acc), global_count)
                tqdm_gen.set_description(
                    'Epoch {}, Loss={:.4f} Acc={:.4f}'.format(
                        epoch, loss.item(), acc))

                tl.add(loss.item())
                ta.add(acc)

                # Loss backwards and optimizer updates
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Rebind averagers to plain floats for logging
            tl = tl.item()
            ta = ta.item()

            # Start validation for this epoch, set model to eval mode
            self.model.eval()
            self.model.mode = 'preval'

            # Averagers for validation loss (vl) and accuracy (va)
            vl = Averager()
            va = Averager()

            # Episode labels: query labels first, then support labels
            label = torch.arange(self.args.way).repeat(self.args.val_query)
            if torch.cuda.is_available():
                label = label.type(torch.cuda.LongTensor)
            else:
                label = label.type(torch.LongTensor)
            label_shot = torch.arange(self.args.way).repeat(self.args.shot)
            if torch.cuda.is_available():
                label_shot = label_shot.type(torch.cuda.LongTensor)
            else:
                label_shot = label_shot.type(torch.LongTensor)

            # Print the best result so far, then run meta-validation
            print('Best Epoch {}, Best Val acc={:.4f}'.format(
                trlog['max_acc_epoch'], trlog['max_acc']))
            for i, batch in enumerate(self.val_loader, 1):
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                # First shot*way images are the support set, the rest queries
                p = self.args.shot * self.args.way
                data_shot, data_query = data[:p], data[p:]
                logits = self.model((data_shot, label_shot, data_query))
                loss = F.cross_entropy(logits, label)
                acc = count_acc(logits, label)
                vl.add(loss.item())
                va.add(acc)

            vl = vl.item()
            va = va.item()
            writer.add_scalar('data/val_loss', float(vl), epoch)
            writer.add_scalar('data/val_acc', float(va), epoch)
            print('Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(
                epoch, vl, va))

            # Update best saved model
            if va > trlog['max_acc']:
                trlog['max_acc'] = va
                trlog['max_acc_epoch'] = epoch
                self.save_model('max_acc')
            # Save model every 20 epochs
            if epoch % 20 == 0:
                self.save_model('epoch' + str(epoch))

            # Update the logs
            trlog['train_loss'].append(tl)
            trlog['train_acc'].append(ta)
            trlog['val_loss'].append(vl)
            trlog['val_acc'].append(va)

            # Save log
            torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

            # Keep a last-epoch checkpoint plus the optimizer state so
            # training can resume
            if epoch > self.args.pre_max_epoch - 2:
                self.save_model('epoch-last')
                torch.save(
                    self.optimizer.state_dict(),
                    osp.join(self.args.save_path, 'optimizer_latest.pth'))

            # Bug fix: the ETA must divide by pre_max_epoch (this loop's
            # bound), not max_epoch (the meta-phase epoch count)
            print('Running Time: {}, Estimated Time: {}'.format(
                timer.measure(), timer.measure(epoch / self.args.pre_max_epoch)))
        writer.close()
Beispiel #4
0
class MetaTrainer(object):
    """The class that contains the code for the meta-train phase and meta-eval phase."""
    def __init__(self, args):
        """Build save paths, data loaders, model, optimizer and scheduler.

        Args:
          args: parsed command-line namespace. It is mutated here to carry
            `save_path` (checkpoint/log directory) and `save_image_dir`
            (where eval() writes prediction images).
        """
        # Set the folder to save the records and checkpoints
        save_image_dir = '../results1/'
        if not osp.exists(save_image_dir):
            os.mkdir(save_image_dir)

        log_base_dir = '../logs1/'
        if not osp.exists(log_base_dir):
            os.mkdir(log_base_dir)
        meta_base_dir = osp.join(log_base_dir, 'meta')
        if not osp.exists(meta_base_dir):
            os.mkdir(meta_base_dir)
        save_path1 = '_'.join([args.dataset, args.model_type, 'MTL'])
        # NOTE(review): the '_shot' component is built from args.train_query,
        # not args.shot — looks like a copy/paste slip, but changing it would
        # relocate existing checkpoint directories (this __init__ resumes from
        # files under save_path), so it is deliberately left as-is.
        save_path2 = '_mtype' + str(args.mtype) + '_shot' + str(args.train_query) + '_way' + str(args.way) + '_query' + str(args.train_query) + \
            '_step' + str(args.step_size) + '_gamma' + str(args.gamma) + '_lr' + str(args.meta_lr) + \
            '_batch' + str(args.num_batch) + '_maxepoch' + str(args.max_epoch) + \
            '_baselr' + str(args.base_lr) + '_updatestep' + str(args.update_step) + \
            '_stepsize' + str(args.step_size) + '_' + args.meta_label
        args.save_path = meta_base_dir + '/' + save_path1 + '_' + save_path2
        args.save_image_dir = save_image_dir
        ensure_path(args.save_path)

        # Set args to be shareable in the class
        self.args = args

        # Load meta-train set
        self.trainset = Dataset('train', self.args)
        self.train_sampler = CategoriesSampler(self.trainset.labeln,
                                               self.args.num_batch,
                                               self.args.way + 1,
                                               self.args.train_query,
                                               self.args.test_query)
        self.train_loader = DataLoader(dataset=self.trainset,
                                       batch_sampler=self.train_sampler,
                                       num_workers=8,
                                       pin_memory=True)

        # Load meta-val set (only when validation data is requested)
        if (self.args.valdata == 'Yes'):
            self.valset = Dataset('val', self.args)
            self.val_sampler = CategoriesSampler(self.valset.labeln,
                                                 self.args.num_batch,
                                                 self.args.way + 1,
                                                 self.args.train_query,
                                                 self.args.test_query)
            self.val_loader = DataLoader(dataset=self.valset,
                                         batch_sampler=self.val_sampler,
                                         num_workers=8,
                                         pin_memory=True)

        # Build meta-transfer learning model
        self.model = MtlLearner(self.args)
        # Candidate segmentation losses; train() currently uses FocalLoss only.
        self.CD = CE_DiceLoss()
        self.FL = FocalLoss()
        self.LS = LovaszSoftmax()

        # Set model to GPU
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()

        # Set optimizer: only trainable encoder parameters are updated.
        self.optimizer = torch.optim.Adam([{
            'params':
            filter(lambda p: p.requires_grad, self.model.encoder.parameters())
        }],
                                          lr=self.args.meta_lr)

        # Set learning rate scheduler
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=self.args.step_size,
            gamma=self.args.gamma)

        # Load pretrained model/optimizer/scheduler state to resume training.
        # NOTE(review): the 'epoch24' checkpoint name is hard-coded; change it
        # together with `initial_epoch` in train() when resuming from a
        # different epoch. Path should be changed accordingly.
        self.model.load_state_dict(
            torch.load(osp.join(self.args.save_path,
                                'epoch24' + '.pth'))['params'])
        self.optimizer.load_state_dict(
            torch.load(osp.join(self.args.save_path,
                                'epoch24' + '_o.pth'))['params_o'])
        self.lr_scheduler.load_state_dict(
            torch.load(osp.join(self.args.save_path,
                                'epoch24' + '_s.pth'))['params_s'])

        self.model_dict = self.model.state_dict()
        self.optimizer_dict = self.optimizer.state_dict()
        self.lr_scheduler_dict = self.lr_scheduler.state_dict()

        # Total trainable model parameters (for logging only)
        pytorch_total_params = sum(p.numel() for p in self.model.parameters()
                                   if p.requires_grad)
        print("Total Trainable Parameters in the Model: " +
              str(pytorch_total_params))

    def _reset_metrics(self):
        """Zero the running pixel/intersection/union counters."""
        self.total_inter, self.total_union = 0, 0
        self.total_correct, self.total_label = 0, 0

    def _update_seg_metrics(self, correct, labeled, inter, union):
        """Accumulate one batch's segmentation counts into the running totals."""
        self.total_correct += correct
        self.total_label += labeled
        self.total_inter += inter
        self.total_union += union

    def _get_seg_metrics(self, n_class):
        """Return pixel accuracy, mean IoU and per-class IoU from the counters.

        Args:
          n_class: number of classes used to key the per-class IoU dict.
        Returns:
          dict with "Pixel_Accuracy", "Mean_IoU" and "Class_IoU" entries,
          each rounded to 3 decimals.
        """
        self.n_class = n_class
        # np.spacing(1) is a tiny epsilon that guards against division by
        # zero when no pixels/regions have been accumulated yet.
        pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label)
        IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
        mIoU = IoU.mean()
        return {
            "Pixel_Accuracy": np.round(pixAcc, 3),
            "Mean_IoU": np.round(mIoU, 3),
            "Class_IoU": dict(zip(range(self.n_class), np.round(IoU, 3)))
        }

    def save_model(self, name):
        """The function to save checkpoints.
        Args:
          name: the name for saved checkpoint
        Saves three files under save_path: model params ('<name>.pth'),
        optimizer state ('<name>_o.pth') and scheduler state ('<name>_s.pth').
        """
        torch.save(dict(params=self.model.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))
        torch.save(dict(params_o=self.optimizer.state_dict()),
                   osp.join(self.args.save_path, name + '_o.pth'))
        torch.save(dict(params_s=self.lr_scheduler.state_dict()),
                   osp.join(self.args.save_path, name + '_s.pth'))

    def train(self):
        """The function for the meta-train phase (with optional meta-val)."""

        # Set the meta-train log.
        # Change when resuming training: must match the checkpoint loaded in
        # __init__ ('epoch24' -> resume at epoch 25).
        initial_epoch = 25

        trlog = {}
        trlog['args'] = vars(self.args)
        trlog['train_loss'] = []
        trlog['train_acc'] = []
        trlog['train_iou'] = []

        # Set the meta-val log
        trlog['val_loss'] = []
        trlog['val_acc'] = []
        trlog['val_iou'] = []

        # NOTE(review): best-so-far values are hard-coded from the run being
        # resumed; reset to 0.0 / 0 when training from scratch.
        trlog['max_iou'] = 0.2856
        trlog['max_iou_epoch'] = 4

        # Set the timer
        timer = Timer()
        # Set global count to zero
        global_count = 0
        # Set tensorboardX
        writer = SummaryWriter(comment=self.args.save_path)

        K = self.args.way + 1  #included Background as class
        N = self.args.train_query
        Q = self.args.test_query

        # Start meta-train
        for epoch in range(initial_epoch, self.args.max_epoch + 1):
            print(
                '----------------------------------------------------------------------------------------------------------------------------------------------------------'
            )

            # Update learning rate.
            # NOTE(review): since PyTorch 1.1 scheduler.step() is expected
            # after the epoch's optimizer steps; kept here (before them) to
            # preserve the original schedule.
            self.lr_scheduler.step()

            # Set the model to train mode
            self.model.train()
            # Set averager classes to record training losses and accuracies
            train_loss_averager = Averager()
            train_acc_averager = Averager()
            train_iou_averager = Averager()

            # Using tqdm to read samples from train loader
            tqdm_gen = tqdm.tqdm(self.train_loader)

            for i, batch in enumerate(tqdm_gen, 1):
                # Update global count number
                global_count = global_count + 1
                if torch.cuda.is_available():
                    data, labels, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                    labels = batch[1]

                # First K*N samples form the support (train) set of the task,
                # the remainder is the query (test) set.
                p = K * N
                im_train, im_test = data[:p], data[p:]

                #Adjusting labels for each meta task
                labels = downlabel(labels, K)
                out_train, out_test = labels[:p], labels[p:]

                if (torch.cuda.is_available()):
                    im_train = im_train.cuda()
                    im_test = im_test.cuda()
                    out_train = out_train.cuda()
                    out_test = out_test.cuda()

                #Reshaping train set ouput
                Ytr = out_train.reshape(-1)
                Ytr = onehot(Ytr, K)  #One hot encoding for loss

                Yte = out_test.reshape(out_test.shape[0], -1)
                if (torch.cuda.is_available()):
                    Ytr = Ytr.cuda()
                    Yte = Yte.cuda()

                # Output logits for model
                Gte = self.model(im_train, Ytr, im_test, Yte)
                GteT = torch.transpose(Gte, 1, 2)

                # Calculate meta-train loss (Focal loss; the CE-Dice and
                # Lovasz alternatives are kept for experimentation)
                #loss = self.CD(GteT,Yte)
                loss = self.FL(GteT, Yte)
                #loss = self.LS(GteT,Yte)

                self._reset_metrics()
                # Calculate meta-train accuracy
                seg_metrics = eval_metrics(GteT, Yte, K)
                self._update_seg_metrics(*seg_metrics)
                pixAcc, mIoU, _ = self._get_seg_metrics(K).values()

                # Print loss and accuracy for this step
                tqdm_gen.set_description(
                    'Epoch {}, Loss={:.4f} Acc={:.4f} IoU={:.4f}'.format(
                        epoch, loss.item(), pixAcc * 100.0, mIoU))

                # Add loss and accuracy for the averagers
                # Calculate the running averages
                train_loss_averager.add(loss.item())
                train_acc_averager.add(pixAcc)
                train_iou_averager.add(mIoU)

                # Loss backwards and optimizer updates
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Update the averagers
            train_loss_averager = train_loss_averager.item()
            train_acc_averager = train_acc_averager.item()
            train_iou_averager = train_iou_averager.item()

            #Adding to Tensorboard
            writer.add_scalar('data/train_loss (Meta)',
                              float(train_loss_averager), epoch)
            writer.add_scalar('data/train_acc (Meta)',
                              float(train_acc_averager) * 100.0, epoch)
            writer.add_scalar('data/train_iou (Meta)',
                              float(train_iou_averager), epoch)

            # Update best saved model if validation set is not present and save it
            if (self.args.valdata == 'No'):
                if train_iou_averager > trlog['max_iou']:
                    print("New Best!")
                    trlog['max_iou'] = train_iou_averager
                    trlog['max_iou_epoch'] = epoch
                    self.save_model('max_iou')

                # Save model every 2 epochs
                if epoch % 2 == 0:
                    self.save_model('epoch' + str(epoch))

            # Update the logs
            trlog['train_loss'].append(train_loss_averager)
            trlog['train_acc'].append(train_acc_averager)
            trlog['train_iou'].append(train_iou_averager)

            if epoch % 1 == 0:
                print('Running Time: {}, Estimated Time: {}'.format(
                    timer.measure(),
                    timer.measure(epoch / self.args.max_epoch)))
                print('Epoch:{}, Average Loss: {:.4f}, Average mIoU: {:.4f}'.
                      format(epoch, train_loss_averager, train_iou_averager))

            # The meta-val phase.
            if (self.args.valdata == 'Yes'):
                # Start meta-val
                # Set the model to val mode
                self.model.eval()

                # Set averager classes to record validation losses and accuracies
                val_loss_averager = Averager()
                val_acc_averager = Averager()
                val_iou_averager = Averager()

                # Using tqdm to read samples from val loader
                tqdm_gen = tqdm.tqdm(self.val_loader)

                for i, batch in enumerate(tqdm_gen, 1):
                    # Update global count number
                    global_count = global_count + 1
                    if torch.cuda.is_available():
                        data, labels, _ = [_.cuda() for _ in batch]
                    else:
                        data = batch[0]
                        labels = batch[1]

                    p = K * N
                    im_train, im_test = data[:p], data[p:]

                    #Adjusting labels for each meta task
                    labels = downlabel(labels, K)
                    out_train, out_test = labels[:p], labels[p:]

                    if (torch.cuda.is_available()):
                        im_train = im_train.cuda()
                        im_test = im_test.cuda()
                        out_train = out_train.cuda()
                        out_test = out_test.cuda()

                    #Reshaping val set ouput
                    Ytr = out_train.reshape(-1)
                    Ytr = onehot(Ytr, K)  #One hot encoding for loss

                    Yte = out_test.reshape(out_test.shape[0], -1)
                    if (torch.cuda.is_available()):
                        Ytr = Ytr.cuda()
                        Yte = Yte.cuda()

                    # Output logits for model
                    Gte = self.model(im_train, Ytr, im_test, Yte)
                    GteT = torch.transpose(Gte, 1, 2)

                    # BUG FIX: compute the validation loss for this batch.
                    # Previously no loss was computed here, so the stale
                    # `loss` left over from the last *training* batch was
                    # printed, averaged and logged as the validation loss.
                    loss = self.FL(GteT, Yte)

                    self._reset_metrics()
                    # Calculate meta-val accuracy
                    seg_metrics = eval_metrics(GteT, Yte, K)
                    self._update_seg_metrics(*seg_metrics)
                    pixAcc, mIoU, _ = self._get_seg_metrics(K).values()

                    # Print loss and accuracy for this step
                    tqdm_gen.set_description(
                        'Epoch {}, Val Loss={:.4f} Val Acc={:.4f} Val IoU={:.4f}'
                        .format(epoch, loss.item(), pixAcc * 100.0, mIoU))

                    # Add loss and accuracy for the averagers
                    # Calculate the running averages
                    val_loss_averager.add(loss.item())
                    val_acc_averager.add(pixAcc)
                    val_iou_averager.add(mIoU)

                # Update the averagers
                val_loss_averager = val_loss_averager.item()
                val_acc_averager = val_acc_averager.item()
                val_iou_averager = val_iou_averager.item()

                #Adding to Tensorboard
                writer.add_scalar('data/val_loss (Meta)',
                                  float(val_loss_averager), epoch)
                writer.add_scalar('data/val_acc (Meta)',
                                  float(val_acc_averager) * 100.0, epoch)
                writer.add_scalar('data/val_iou (Meta)',
                                  float(val_iou_averager), epoch)

                # Update best saved model
                if val_iou_averager > trlog['max_iou']:
                    print("New Best (Validation)")
                    trlog['max_iou'] = val_iou_averager
                    trlog['max_iou_epoch'] = epoch
                    self.save_model('max_iou')

                # Save model every 2 epochs
                if epoch % 2 == 0:
                    self.save_model('epoch' + str(epoch))

                # Update the logs
                trlog['val_loss'].append(val_loss_averager)
                trlog['val_acc'].append(val_acc_averager)
                trlog['val_iou'].append(val_iou_averager)

                if epoch % 1 == 0:
                    print('Running Time: {}, Estimated Time: {}'.format(
                        timer.measure(),
                        timer.measure(epoch / self.args.max_epoch)))
                    print(
                        'Epoch:{}, Average Val Loss: {:.4f}, Average Val mIoU: {:.4f}'
                        .format(epoch, val_loss_averager, val_iou_averager))

            # Save log
            torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

        print(
            '----------------------------------------------------------------------------------------------------------------------------------------------------------'
        )
        writer.close()

    def eval(self):
        """The function for the meta-evaluate (test) phase."""
        # Load the logs (kept for parity with the original; not used below)
        trlog = torch.load(osp.join(self.args.save_path, 'trlog'))

        # Load meta-test set
        self.test_set = Dataset('test', self.args)
        self.sampler = CategoriesSampler(self.test_set.labeln,
                                         self.args.num_batch,
                                         self.args.way + 1,
                                         self.args.train_query,
                                         self.args.test_query)
        self.loader = DataLoader(dataset=self.test_set,
                                 batch_sampler=self.sampler,
                                 num_workers=8,
                                 pin_memory=True)

        # Load model for meta-test phase: explicit weights if given,
        # otherwise the best (max IoU) checkpoint from training.
        if self.args.eval_weights is not None:
            self.model.load_state_dict(
                torch.load(self.args.eval_weights)['params'])
        else:
            self.model.load_state_dict(
                torch.load(osp.join(self.args.save_path,
                                    'max_iou' + '.pth'))['params'])
        # Set model to eval mode
        self.model.eval()

        # Set accuracy(IoU) averager
        ave_acc = Averager()

        # Start meta-test
        K = self.args.way + 1  # classes per task, background included
        N = self.args.train_query
        Q = self.args.test_query

        count = 1
        for i, batch in enumerate(self.loader, 1):
            if torch.cuda.is_available():
                data, labels, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
                labels = batch[1]

            p = K * N
            im_train, im_test = data[:p], data[p:]

            #Adjusting labels for each meta task
            labels = downlabel(labels, K)
            out_train, out_test = labels[:p], labels[p:]

            if (torch.cuda.is_available()):
                im_train = im_train.cuda()
                im_test = im_test.cuda()
                out_train = out_train.cuda()
                out_test = out_test.cuda()

            #Reshaping train set ouput
            Ytr = out_train.reshape(-1)
            Ytr = onehot(Ytr, K)  #One hot encoding for loss

            Yte = out_test.reshape(out_test.shape[0], -1)

            if (torch.cuda.is_available()):
                Ytr = Ytr.cuda()
                Yte = Yte.cuda()
            # Output logits for model
            Gte = self.model(im_train, Ytr, im_test, Yte)
            GteT = torch.transpose(Gte, 1, 2)

            # Calculate meta-test accuracy
            self._reset_metrics()
            seg_metrics = eval_metrics(GteT, Yte, K)
            self._update_seg_metrics(*seg_metrics)
            pixAcc, mIoU, _ = self._get_seg_metrics(K).values()

            ave_acc.add(mIoU)

            #Saving Test Image, Ground Truth Image and Predicted Image
            for j in range(K * Q):

                x1 = im_test[j].detach().cpu()
                y1 = out_test[j].detach().cpu()
                z1 = GteT[j].detach().cpu()
                # Predicted class per pixel, then reshape the flat pixel
                # vector back to a square (m x m) mask.
                z1 = torch.argmax(z1, axis=0)

                m = int(math.sqrt(z1.shape[0]))
                z2 = z1.reshape(m, m)

                x = transforms.ToPILImage()(x1).convert("RGB")
                y = Image.fromarray(decode_segmap(y1, K))
                z = Image.fromarray(decode_segmap(z2, K))

                # 'a' = input image, 'b' = ground truth, 'c' = prediction
                px = self.args.save_image_dir + str(count) + 'a.jpg'
                py = self.args.save_image_dir + str(count) + 'b.png'
                pz = self.args.save_image_dir + str(count) + 'c.png'
                x.save(px)
                y.save(py)
                z.save(pz)
                count = count + 1

        # Test mIoU
        ave_acc = ave_acc.item()
        print("=============================================================")
        print('Average Test mIoU: {:.4f}'.format(ave_acc))
        print("Images Saved!")
        print("=============================================================")
Beispiel #5
0
class PreTrainer(object):
    """The class that contains the code for the pre-train phase."""
    def __init__(self, args):
        """Build save paths, data loaders, model, losses and optimizer.

        Args:
          args: parsed command-line namespace; mutated here to carry
            `save_path` (checkpoint/log directory).
        """
        # Set the folder to save the records and checkpoints
        log_base_dir = '../logs/'
        if not osp.exists(log_base_dir):
            os.mkdir(log_base_dir)
        pre_base_dir = osp.join(log_base_dir, 'pre')
        if not osp.exists(pre_base_dir):
            os.mkdir(pre_base_dir)
        save_path1 = '_'.join([args.dataset, args.model_type])
        save_path2 = 'batchsize' + str(args.pre_batch_size) + '_lr' + str(args.pre_lr) + '_gamma' + str(args.pre_gamma) + '_step' + \
            str(args.pre_step_size) + '_maxepoch' + str(args.pre_max_epoch)
        args.save_path = pre_base_dir + '/' + save_path1 + '_' + save_path2
        ensure_path(args.save_path)

        # Set args to be shareable in the class
        self.args = args

        # Load pretrain set
        self.trainset = Dataset('train', self.args)
        self.train_loader = DataLoader(dataset=self.trainset,
                                       batch_size=args.pre_batch_size,
                                       shuffle=True,
                                       num_workers=8,
                                       pin_memory=True)

        # Load pre-val set (episodic sampling, meta-style)
        self.valset = mDataset('val', self.args)
        self.val_sampler = CategoriesSampler(
            self.valset.labeln, self.args.num_batch, self.args.way,
            self.args.shot + self.args.val_query, self.args.shot)
        self.val_loader = DataLoader(dataset=self.valset,
                                     batch_sampler=self.val_sampler,
                                     num_workers=8,
                                     pin_memory=True)

        # Build pretrain model
        self.model = MtlLearner(self.args, mode='train')
        print(self.model)
        '''
        if self.args.pre_init_weights is not None:
            self.model_dict = self.model.state_dict()
            pretrained_dict = torch.load(self.args.pre_init_weights)['params']
            pretrained_dict = {'encoder.'+k: v for k, v in pretrained_dict.items()}
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in self.model_dict}
            print(pretrained_dict.keys())
            self.model_dict.update(pretrained_dict)
            self.model.load_state_dict(self.model_dict)   
        '''

        # Candidate segmentation losses; train() currently uses CE_DiceLoss.
        self.FL = FocalLoss()
        self.CD = CE_DiceLoss()
        self.LS = LovaszSoftmax()
        # Set optimizer
        self.optimizer = torch.optim.SGD([{'params': self.model.encoder.parameters(), 'lr': self.args.pre_lr}], \
                momentum=self.args.pre_custom_momentum, nesterov=True, weight_decay=self.args.pre_custom_weight_decay)

        # Set learning rate scheduler
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.args.pre_step_size, \
            gamma=self.args.pre_gamma)

        # Set model to GPU
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()

    def save_model(self, name):
        """The function to save checkpoints.
        Args:
          name: the name for saved checkpoint
        Only the encoder weights are saved (the meta phase loads them).
        """
        torch.save(dict(params=self.model.encoder.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def _reset_metrics(self):
        """Zero the running pixel/intersection/union counters."""
        self.total_inter, self.total_union = 0, 0
        self.total_correct, self.total_label = 0, 0

    def _update_seg_metrics(self, correct, labeled, inter, union):
        """Accumulate one batch's segmentation counts into the running totals."""
        self.total_correct += correct
        self.total_label += labeled
        self.total_inter += inter
        self.total_union += union

    def _get_seg_metrics(self, n_class):
        """Return pixel accuracy, mean IoU and per-class IoU from the counters.

        Args:
          n_class: number of classes used to key the per-class IoU dict.
        Returns:
          dict with "Pixel_Accuracy", "Mean_IoU" and "Class_IoU" entries,
          each rounded to 3 decimals.
        """
        self.n_class = n_class
        # np.spacing(1) is a tiny epsilon that guards against division by
        # zero when no pixels/regions have been accumulated yet.
        pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label)
        IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
        mIoU = IoU.mean()
        return {
            "Pixel_Accuracy": np.round(pixAcc, 3),
            "Mean_IoU": np.round(mIoU, 3),
            "Class_IoU": dict(zip(range(self.n_class), np.round(IoU, 3)))
        }

    def train(self):
        """The function for the pre-train phase."""

        # Set the pretrain log
        trlog = {}
        trlog['args'] = vars(self.args)
        trlog['train_loss'] = []
        trlog['val_loss'] = []
        trlog['train_acc'] = []
        trlog['val_acc'] = []
        trlog['train_iou'] = []
        trlog['val_iou'] = []
        trlog['max_iou'] = 0.0
        trlog['max_iou_epoch'] = 0

        # Set the timer
        timer = Timer()
        # Set global count to zero
        global_count = 0
        # Set tensorboardX
        writer = SummaryWriter(comment=self.args.save_path)

        # Start pretrain
        for epoch in range(1, self.args.pre_max_epoch + 1):
            # Update learning rate.
            # NOTE(review): since PyTorch 1.1 scheduler.step() is expected
            # after the epoch's optimizer steps; kept here (before them) to
            # preserve the original schedule.
            self.lr_scheduler.step()
            # Set the model to train mode
            self.model.train()
            self.model.mode = 'train'
            # Set averager classes to record training losses and accuracies
            train_loss_averager = Averager()
            train_acc_averager = Averager()
            train_iou_averager = Averager()

            # Using tqdm to read samples from train loader
            tqdm_gen = tqdm.tqdm(self.train_loader)

            for i, batch in enumerate(tqdm_gen, 1):
                # Update global count number
                global_count = global_count + 1
                if torch.cuda.is_available():
                    data, label = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                    label = batch[1]

                # Output logits for model
                logits = self.model(data)
                # Calculate train loss
                # CD loss is modified in the whole project to incorporate ony Cross Entropy loss. Modify as per requirement.
                #loss = self.FL(logits, label) + self.CD(logits,label) + self.LS(logits,label)
                loss = self.CD(logits, label)

                # Calculate train accuracy
                self._reset_metrics()
                seg_metrics = eval_metrics(logits, label,
                                           self.args.num_classes)
                self._update_seg_metrics(*seg_metrics)
                pixAcc, mIoU, _ = self._get_seg_metrics(
                    self.args.num_classes).values()

                # Add loss and accuracy for the averagers
                train_loss_averager.add(loss.item())
                train_acc_averager.add(pixAcc)
                train_iou_averager.add(mIoU)

                # Print loss and accuracy till this step
                tqdm_gen.set_description(
                    'Epoch {}, Loss={:.4f} Acc={:.4f} IOU={:.4f}'.format(
                        epoch, train_loss_averager.item(),
                        train_acc_averager.item() * 100.0,
                        train_iou_averager.item()))

                # Loss backwards and optimizer updates
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Update the averagers
            train_loss_averager = train_loss_averager.item()
            train_acc_averager = train_acc_averager.item()
            train_iou_averager = train_iou_averager.item()

            writer.add_scalar('data/train_loss(Pre)',
                              float(train_loss_averager), epoch)
            writer.add_scalar('data/train_acc(Pre)',
                              float(train_acc_averager) * 100.0, epoch)
            writer.add_scalar('data/train_iou (Pre)',
                              float(train_iou_averager), epoch)

            print(
                'Epoch {}, Train: Loss={:.4f}, Acc={:.4f}, IoU={:.4f}'.format(
                    epoch, train_loss_averager, train_acc_averager * 100.0,
                    train_iou_averager))

            # Start validation for this epoch, set model to eval mode
            self.model.eval()
            self.model.mode = 'val'

            # Set averager classes to record validation losses and accuracies
            val_loss_averager = Averager()
            val_acc_averager = Averager()
            val_iou_averager = Averager()

            # Print previous information
            if epoch % 1 == 0:
                print('Best Val Epoch {}, Best Val IoU={:.4f}'.format(
                    trlog['max_iou_epoch'], trlog['max_iou']))

            # Run validation
            for i, batch in enumerate(self.val_loader, 1):
                if torch.cuda.is_available():
                    data, labels, _ = [_.cuda() for _ in batch]
                else:
                    # BUG FIX: `labels` was never assigned on the CPU path —
                    # the old code read `labels[0]` before any definition,
                    # which raised NameError. Unpack it from the batch like
                    # `data`, so the split below works on both paths.
                    data = batch[0]
                    labels = batch[1]
                p = self.args.way * self.args.shot
                data_shot, data_query = data[:p], data[p:]
                label_shot, label = labels[:p], labels[p:]

                par = data_shot, label_shot, data_query
                logits = self.model(par)
                # Calculate preval loss

                #loss = self.FL(logits, label) + self.CD(logits,label) + self.LS(logits,label)
                loss = self.CD(logits, label)

                # Calculate val accuracy
                self._reset_metrics()
                seg_metrics = eval_metrics(logits, label, self.args.way)
                self._update_seg_metrics(*seg_metrics)
                pixAcc, mIoU, _ = self._get_seg_metrics(self.args.way).values()

                val_loss_averager.add(loss.item())
                val_acc_averager.add(pixAcc)
                val_iou_averager.add(mIoU)

            # Update validation averagers
            val_loss_averager = val_loss_averager.item()
            val_acc_averager = val_acc_averager.item()
            val_iou_averager = val_iou_averager.item()

            writer.add_scalar('data/val_loss(Pre)', float(val_loss_averager),
                              epoch)
            writer.add_scalar('data/val_acc(Pre)',
                              float(val_acc_averager) * 100.0, epoch)
            writer.add_scalar('data/val_iou (Pre)', float(val_iou_averager),
                              epoch)

            # Print loss and accuracy for this epoch
            print('Epoch {}, Val: Loss={:.4f} Acc={:.4f} IoU={:.4f}'.format(
                epoch, val_loss_averager, val_acc_averager * 100.0,
                val_iou_averager))

            # Update best saved model
            if val_iou_averager > trlog['max_iou']:
                trlog['max_iou'] = val_iou_averager
                trlog['max_iou_epoch'] = epoch
                print("model saved in max_iou")
                self.save_model('max_iou')

            # Save model every 10 epochs
            if epoch % 10 == 0:
                self.save_model('epoch' + str(epoch))

            # Update the logs
            trlog['train_loss'].append(train_loss_averager)
            trlog['train_acc'].append(train_acc_averager)
            trlog['val_loss'].append(val_loss_averager)
            trlog['val_acc'].append(val_acc_averager)
            trlog['train_iou'].append(train_iou_averager)
            trlog['val_iou'].append(val_iou_averager)

            # Save log
            torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

            if epoch % 1 == 0:
                print('Running Time: {}, Estimated Time: {}'.format(
                    timer.measure(),
                    # BUG FIX: estimate against the pre-train epoch budget
                    # (pre_max_epoch, which this loop runs to), not the
                    # meta-train max_epoch.
                    timer.measure(epoch / self.args.pre_max_epoch)))
        writer.close()
Beispiel #6
0
class PreTrainer(object):
    """Trainer for the pretrain phase.

    Pretrains the encoder with a plain supervised classification objective
    on the full train split, then meta-validates each epoch with episodic
    few-shot tasks sampled from the val split.
    """
    def __init__(self, args):
        """Build datasets, model and optimizer for pretraining.

        Args:
          args: parsed command-line arguments. ``args.save_path`` is set
            here to a directory name encoding the main hyperparameters.
        """
        # Set the folder to save the records and checkpoints
        log_base_dir = './logs/'
        if not osp.exists(log_base_dir):
            os.mkdir(log_base_dir)
        pre_base_dir = osp.join(log_base_dir, 'pre')
        if not osp.exists(pre_base_dir):
            os.mkdir(pre_base_dir)
        # Encode the main hyperparameters into the checkpoint directory name
        save_path1 = '_'.join([args.dataset, args.model_type])
        save_path2 = 'batchsize' + str(args.pre_batch_size) + '_lr' + str(args.pre_lr) + '_gamma' + str(args.pre_gamma) + '_step' + \
            str(args.pre_step_size) + '_maxepoch' + str(args.pre_max_epoch)
        args.save_path = pre_base_dir + '/' + save_path1 + '_' + save_path2
        ensure_path(args.save_path)

        # Set args to be shareable in the class
        self.args = args

        # Load pretrain set (plain shuffled mini-batches, no episode sampling)
        self.trainset = Dataset('train', self.args, train_aug=False)
        self.train_loader = DataLoader(dataset=self.trainset,
                                       batch_size=args.pre_batch_size,
                                       shuffle=True,
                                       num_workers=8,
                                       pin_memory=True)

        # Load meta-val set (episodic sampler: `way` classes, shot+query each)
        self.valset = Dataset('val', self.args)
        self.val_sampler = CategoriesSampler(
            self.valset.label, 20, self.args.way,
            self.args.shot + self.args.val_query)
        self.val_loader = DataLoader(dataset=self.valset,
                                     batch_sampler=self.val_sampler,
                                     num_workers=8,
                                     pin_memory=True)

        # Set pretrain class number
        num_class_pretrain = self.trainset.num_class

        # Build pretrain model
        self.model = MtlLearner(self.args,
                                mode='pre',
                                num_cls=num_class_pretrain)

        # Optimize the encoder together with the pretrain FC classifier
        params = list(self.model.encoder.parameters()) + list(
            self.model.pre_fc.parameters())
        self.optimizer = optim.Adam(params)

        # Set model to GPU
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()

    def save_model(self, name):
        """Save a checkpoint containing the encoder weights only.

        Args:
          name: the name for saved checkpoint
        """
        torch.save(dict(params=self.model.encoder.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def train(self):
        """Run the pre-train phase: classification training plus validation."""

        # Set the pretrain log
        trlog = {}
        trlog['args'] = vars(self.args)
        trlog['train_loss'] = []
        trlog['val_loss'] = []
        trlog['train_acc'] = []
        trlog['val_acc'] = []
        trlog['max_acc'] = 0.0
        trlog['max_acc_epoch'] = 0

        # Set the timer
        timer = Timer()
        # Set global count to zero
        global_count = 0
        # Set tensorboardX
        writer = SummaryWriter(comment=self.args.save_path)

        # Start pretrain
        for epoch in range(1, self.args.pre_max_epoch + 1):
            print('Epoch {}'.format(epoch))
            # Set the model to train mode
            self.model.train()
            self.model.mode = 'pre'
            # Set averager classes to record training losses and accuracies
            train_loss_averager = Averager()
            train_acc_averager = Averager()

            # Using tqdm to read samples from train loader
            tqdm_gen = tqdm.tqdm(self.train_loader)
            for i, batch in enumerate(tqdm_gen, 1):
                # Update global count number
                global_count = global_count + 1
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                label = batch[1]
                if torch.cuda.is_available():
                    label = label.type(torch.cuda.LongTensor)
                else:
                    label = label.type(torch.LongTensor)
                # Standard supervised classification over all pretrain classes
                logits = self.model(data)
                loss = F.cross_entropy(logits, label)
                # Calculate train accuracy
                acc = count_acc(logits, label)
                # Write the tensorboardX records
                writer.add_scalar('data/loss', float(loss), global_count)
                writer.add_scalar('data/acc', float(acc), global_count)
                # Accumulate loss and accuracy for this step
                train_loss_averager.add(loss.item())
                train_acc_averager.add(acc)
                # Loss backwards and optimizer updates
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Update the averagers
            train_loss_averager = train_loss_averager.item()
            train_acc_averager = train_acc_averager.item()

            # Plain (non-episodic) evaluation on the held-out numpy split
            self.model.eval()
            self.model.mode = 'origval'
            _, valid_results = self.val_orig(self.valset.X_val,
                                             self.valset.y_val)
            print('validation accuracy ', valid_results[0])

            # Start episodic validation for this epoch, set model to eval mode
            self.model.eval()
            self.model.mode = 'preval'

            # Set averager classes to record validation losses and accuracies
            val_loss_averager = Averager()
            val_acc_averager = Averager()

            # Generate the labels for the query (test) and support (shot) sets
            label = torch.arange(self.args.way).repeat(self.args.val_query)
            if torch.cuda.is_available():
                label = label.type(torch.cuda.LongTensor)
            else:
                label = label.type(torch.LongTensor)
            label_shot = torch.arange(self.args.way).repeat(self.args.shot)
            if torch.cuda.is_available():
                label_shot = label_shot.type(torch.cuda.LongTensor)
            else:
                label_shot = label_shot.type(torch.LongTensor)

            # Run meta-validation
            for i, batch in enumerate(self.val_loader, 1):
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                # First way*shot samples form the support set, rest are queries
                p = self.args.shot * self.args.way
                data_shot, data_query = data[:p], data[p:]
                logits = self.model((data_shot, label_shot, data_query))
                loss = F.cross_entropy(logits, label)
                acc = count_acc(logits, label)
                val_loss_averager.add(loss.item())
                val_acc_averager.add(acc)

            # Update validation averagers
            val_loss_averager = val_loss_averager.item()
            val_acc_averager = val_acc_averager.item()
            # Write the tensorboardX records
            writer.add_scalar('data/val_loss', float(val_loss_averager), epoch)
            writer.add_scalar('data/val_acc', float(val_acc_averager), epoch)

            # Update best saved model
            if val_acc_averager > trlog['max_acc']:
                trlog['max_acc'] = val_acc_averager
                trlog['max_acc_epoch'] = epoch
                self.save_model('max_acc')
            # Save model every 10 epochs
            if epoch % 10 == 0:
                self.save_model('epoch' + str(epoch))

            # Update the logs
            trlog['train_loss'].append(train_loss_averager)
            trlog['train_acc'].append(train_acc_averager)
            trlog['val_loss'].append(val_loss_averager)
            trlog['val_acc'].append(val_acc_averager)

            # Save log
            torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

            if epoch % 10 == 0:
                # Fix: estimate progress against pre_max_epoch (this loop's
                # bound), not max_epoch, so the printed ETA is correct.
                print('Running Time: {}, Estimated Time: {}'.format(
                    timer.measure(),
                    timer.measure(epoch / self.args.pre_max_epoch)))
        writer.close()

    def val_orig(self, X_val, y_val):
        """Evaluate the pretrain classifier on a held-out numpy split.

        Args:
          X_val: numpy array of validation inputs.
          y_val: numpy array of validation labels.

        Returns:
          (predicted, results): ``predicted`` holds the argmax class ids and
          ``results`` is ``[accuracy, recall, fmeasure]`` (in that order).
        """
        inputs = torch.from_numpy(X_val)
        labels = torch.FloatTensor(y_val * 1.0)
        inputs, labels = Variable(inputs), Variable(labels)

        results = []

        self.model.eval()
        self.model.mode = 'origval'

        if torch.cuda.is_available():
            inputs = inputs.type(torch.cuda.FloatTensor)
        else:
            inputs = inputs.type(torch.FloatTensor)

        # Forward the whole split in one batch and take argmax class ids
        predicted = self.model(inputs)
        predicted = predicted.data.cpu().numpy()

        Y = labels.data.numpy()
        predicted = np.argmax(predicted, axis=1)
        # NOTE(review): "auc" and "precision" intentionally fall through
        # without appending, so `results` has only three entries; kept as-is
        # because callers index results[0] for accuracy.
        for param in ["acc", "auc", "recall", "precision", "fmeasure"]:
            if param == 'acc':
                results.append(accuracy_score(Y, np.round(predicted)))
            if param == "recall":
                results.append(
                    recall_score(Y, np.round(predicted), average='micro'))
            if param == "fmeasure":
                precision = precision_score(Y,
                                            np.round(predicted),
                                            average='micro')
                recall = recall_score(Y, np.round(predicted), average='micro')
                results.append(2 * precision * recall / (precision + recall))

        return predicted, results
Beispiel #7
0
class MetaTrainer(object):
    """The class that contains the code for the meta-train phase and meta-eval phase."""
    def __init__(self, args):
        """Build datasets, model, optimizer and scheduler for meta-training.

        Args:
          args: parsed command-line arguments; several fields are overridden
            by the config object selected via ``args.config``.
        """
        # Pull run settings from the named config and mirror them onto args
        param = configs.__dict__[args.config]()
        args.shot = param.shot
        args.test = param.test
        args.debug = param.debug
        args.deconfound = param.deconfound
        args.meta_label = param.meta_label
        args.init_weights = param.init_weights
        self.test_iter = param.test_iter
        args.param = param
        pprint(vars(args))

        # Set the folder to save the records and checkpoints
        log_base_dir = '/data2/yuezhongqi/Model/mtl/logs/'
        if not osp.exists(log_base_dir):
            os.mkdir(log_base_dir)
        meta_base_dir = osp.join(log_base_dir, 'meta')
        if not osp.exists(meta_base_dir):
            os.mkdir(meta_base_dir)
        # Encode the main hyperparameters into the checkpoint directory name
        save_path1 = '_'.join([args.dataset, args.model_type, 'MTL'])
        save_path2 = 'shot' + str(args.shot) + '_way' + str(args.way) + '_query' + str(args.train_query) + \
            '_step' + str(args.step_size) + '_gamma' + str(args.gamma) + '_lr1' + str(args.meta_lr1) + '_lr2' + str(args.meta_lr2) + \
            '_batch' + str(args.num_batch) + '_maxepoch' + str(args.max_epoch) + \
            '_baselr' + str(args.base_lr) + '_updatestep' + str(args.update_step) + \
            '_stepsize' + str(args.step_size) + '_' + args.meta_label
        args.save_path = meta_base_dir + '/' + save_path1 + '_' + save_path2
        ensure_path(args.save_path)

        # Set args to be shareable in the class
        self.args = args

        # Load meta-train set
        self.trainset = Dataset('train',
                                self.args,
                                dataset=self.args.param.dataset,
                                train_aug=False)
        # Debug runs avoid worker subprocesses so breakpoints work
        num_workers = 8
        if args.debug:
            num_workers = 0
        self.train_sampler = CategoriesSampler(
            self.trainset.label, self.args.num_batch, self.args.way,
            self.args.shot + self.args.train_query)
        self.train_loader = DataLoader(dataset=self.trainset,
                                       batch_sampler=self.train_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)

        # Load meta-val set
        self.valset = Dataset('val',
                              self.args,
                              dataset=self.args.param.dataset,
                              train_aug=False)
        self.val_sampler = CategoriesSampler(
            self.valset.label, self.test_iter, self.args.way,
            self.args.shot + self.args.val_query)
        self.val_loader = DataLoader(dataset=self.valset,
                                     batch_sampler=self.val_sampler,
                                     num_workers=num_workers,
                                     pin_memory=True)

        # Build meta-transfer learning model
        self.model = MtlLearner(self.args)

        # Load pretrained encoder weights without the FC classifier
        self.model.load_pretrain_weight(self.args.init_weights)

        # Set model to GPU
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()
            if self.args.param.model == "wideres":
                print("Using Parallel")
                self.model.encoder = torch.nn.DataParallel(
                    self.model.encoder).cuda()

        # Set optimizer: encoder params use meta_lr1, base learner meta_lr2
        self.optimizer = torch.optim.Adam(
            [{
                'params':
                filter(lambda p: p.requires_grad,
                       self.model.encoder.parameters())
            }, {
                'params': self.model.base_learner.parameters(),
                'lr': self.args.meta_lr2
            }],
            lr=self.args.meta_lr1)
        # Set learning rate scheduler
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=self.args.step_size,
            gamma=self.args.gamma)

        # Deconfounded models emit log-probabilities, hence NLL loss
        if not self.args.deconfound:
            self.criterion = torch.nn.CrossEntropyLoss()
        else:
            self.criterion = torch.nn.NLLLoss()
        # Fix: only move the criterion to GPU when CUDA is available, so the
        # trainer also runs on CPU-only machines (matches the guard above).
        if torch.cuda.is_available():
            self.criterion = self.criterion.cuda()

        # Enable evaluation with Cross
        if args.cross:
            args.param.dataset = "cross"

    def write_output_message(self, message, file_name=None):
        """Append a one-line message to ``outputs/<file_name>.txt``.

        Args:
          message: the text line to append.
          file_name: base name of the output file; defaults to "results".
        """
        if file_name is None:
            file_name = "results"
        # Robustness: make sure the shared output directory exists
        os.makedirs("outputs", exist_ok=True)
        output_file = os.path.join("outputs", file_name + ".txt")
        with open(output_file, "a") as f:
            f.write(message + "\n")

    def save_model(self, name):
        """Save a checkpoint containing the full model state.

        Args:
          name: the name for saved checkpoint
        """
        torch.save(dict(params=self.model.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def train(self):
        """The function for the meta-train phase."""

        # Set the meta-train log
        trlog = {}
        trlog['args'] = vars(self.args)
        trlog['train_loss'] = []
        trlog['val_loss'] = []
        trlog['train_acc'] = []
        trlog['val_acc'] = []
        trlog['max_acc'] = 0.0
        trlog['max_acc_epoch'] = 0

        # Set the timer
        timer = Timer()
        # Set global count to zero
        global_count = 0
        # Set tensorboardX
        writer = SummaryWriter(comment=self.args.save_path)

        # Generate the labels for train set of the episodes
        label_shot = torch.arange(self.args.way).repeat(self.args.shot)
        if torch.cuda.is_available():
            label_shot = label_shot.type(torch.cuda.LongTensor)
        else:
            label_shot = label_shot.type(torch.LongTensor)

        # Start meta-train
        for epoch in range(1, self.args.max_epoch + 1):
            # Set the model to train mode
            self.model.train()
            # Set averager classes to record training losses and accuracies
            train_loss_averager = Averager()
            train_acc_averager = Averager()

            # Generate the labels for test set of the episodes during meta-train updates
            label = torch.arange(self.args.way).repeat(self.args.train_query)
            if torch.cuda.is_available():
                label = label.type(torch.cuda.LongTensor)
            else:
                label = label.type(torch.LongTensor)

            # Using tqdm to read samples from train loader
            tqdm_gen = tqdm.tqdm(self.train_loader)
            for i, batch in enumerate(tqdm_gen, 1):
                # Update global count number
                global_count = global_count + 1
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                # First way*shot samples form the support set, rest are queries
                p = self.args.shot * self.args.way
                data_shot, data_query = data[:p], data[p:]
                # Output logits for model
                logits = self.model((data_shot, label_shot, data_query, False))
                # Calculate meta-train loss
                loss = self.criterion(logits, label)
                # Calculate meta-train accuracy
                acc = count_acc(logits, label)
                # Write the tensorboardX records
                writer.add_scalar('data/loss', float(loss), global_count)
                writer.add_scalar('data/acc', float(acc), global_count)
                # Print loss and accuracy for this step
                tqdm_gen.set_description(
                    'Epoch {}, Loss={:.4f} Acc={:.4f}'.format(
                        epoch, loss.item(), acc))

                # Add loss and accuracy for the averagers
                train_loss_averager.add(loss.item())
                train_acc_averager.add(acc)

                # Loss backwards and optimizer updates
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Fix: step the LR scheduler AFTER the epoch's optimizer updates,
            # as required by PyTorch >= 1.1; the original stepped it at epoch
            # start, which skipped the first LR value and decayed early.
            self.lr_scheduler.step()

            # Update the averagers
            train_loss_averager = train_loss_averager.item()
            train_acc_averager = train_acc_averager.item()

            # Start validation for this epoch, set model to eval mode
            self.model.eval()

            # Set averager classes to record validation losses and accuracies
            val_loss_averager = Averager()
            val_acc_averager = Averager()

            # Generate the labels for test set of the episodes during meta-val for this epoch
            label = torch.arange(self.args.way).repeat(self.args.val_query)
            if torch.cuda.is_available():
                label = label.type(torch.cuda.LongTensor)
            else:
                label = label.type(torch.LongTensor)

            # Print previous information
            if epoch % 10 == 0:
                print('Best Epoch {}, Best Val Acc={:.4f}'.format(
                    trlog['max_acc_epoch'], trlog['max_acc']))
            # Run meta-validation
            # Fix: clamp to at least 1 so `i % print_freq` cannot divide by
            # zero when test_iter < 5.
            print_freq = max(1, int(self.test_iter / 5))

            if epoch > 0:
                for i, batch in enumerate(self.val_loader, 1):
                    if torch.cuda.is_available():
                        data, _ = [_.cuda() for _ in batch]
                    else:
                        data = batch[0]
                    p = self.args.shot * self.args.way
                    data_shot, data_query = data[:p], data[p:]
                    logits = self.model(
                        (data_shot, label_shot, data_query, True))
                    # Deconfounded models emit log-probabilities (NLL loss)
                    if not self.args.deconfound:
                        loss = F.cross_entropy(logits, label)
                    else:
                        loss = F.nll_loss(logits, label)

                    acc = count_acc(logits, label)

                    val_loss_averager.add(loss.item())
                    val_acc_averager.add(acc)

                    if i % print_freq == 0:
                        # Report running validation averages
                        val_loss_averager_item = val_loss_averager.item()
                        val_acc_averager_item = val_acc_averager.item()
                        # Write the tensorboardX records
                        writer.add_scalar('data/val_loss',
                                          float(val_loss_averager_item), epoch)
                        writer.add_scalar('data/val_acc',
                                          float(val_acc_averager_item), epoch)
                        # Print loss and accuracy for this epoch
                        print('Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(
                            epoch, val_loss_averager_item,
                            val_acc_averager_item))

            # Update validation averagers
            val_loss_averager = val_loss_averager.item()
            val_acc_averager = val_acc_averager.item()
            # Write the tensorboardX records
            writer.add_scalar('data/val_loss', float(val_loss_averager), epoch)
            writer.add_scalar('data/val_acc', float(val_acc_averager), epoch)
            # Print loss and accuracy for this epoch
            msg = 'Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(
                epoch, val_loss_averager, val_acc_averager)
            print(msg)
            self.write_output_message(msg)

            # Update best saved model
            if val_acc_averager > trlog['max_acc']:
                trlog['max_acc'] = val_acc_averager
                trlog['max_acc_epoch'] = epoch
                self.save_model('max_acc')
            # Save model every 10 epochs
            if epoch % 10 == 0:
                self.save_model('epoch' + str(epoch))

            # Update the logs
            trlog['train_loss'].append(train_loss_averager)
            trlog['train_acc'].append(train_acc_averager)
            trlog['val_loss'].append(val_loss_averager)
            trlog['val_acc'].append(val_acc_averager)

            # Save log
            torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

            if epoch % 10 == 0:
                print('Running Time: {}, Estimated Time: {}'.format(
                    timer.measure(),
                    timer.measure(epoch / self.args.max_epoch)))
        writer.close()

    def eval(self):
        """The function for the meta-eval phase."""
        num_workers = 8
        if self.args.debug:
            num_workers = 0

        self.test_iter = 2000
        # Load meta-test set
        test_set = Dataset('test',
                           self.args,
                           dataset=self.args.param.dataset,
                           train_aug=False)
        sampler = CategoriesSampler(test_set.label, self.test_iter,
                                    self.args.way,
                                    self.args.shot + self.args.val_query)
        loader = DataLoader(test_set,
                            batch_sampler=sampler,
                            num_workers=num_workers,
                            pin_memory=True)

        # Set test accuracy recorder
        test_acc_record = np.zeros((self.test_iter, ))

        # Load model for meta-test phase
        if self.args.eval_weights is not None:
            self.model.load_state_dict(
                torch.load(self.args.eval_weights)['params'])
        else:
            # Load according to config file; checkpoint names encode
            # dataset/backbone/shot (paths are machine-specific).
            args = self.args
            base_path = "/data2/yuezhongqi/Model/ifsl/mtl"
            if args.param.dataset == "tiered":
                add_path = "tiered_"
            else:
                add_path = ""
            if args.param.model == "ResNet10":
                add_path += "resnet_"
            elif args.param.model == "wideres":
                add_path += "wrn_"
            elif "baseline" in args.config:
                add_path += "baseline_"
            else:
                add_path += "edsplit_"
            add_path += str(args.param.shot)
            self.add_path = add_path
            self.model.load_state_dict(
                torch.load(osp.join(base_path, add_path + '.pth'))['params'])
        # Set model to eval mode
        self.model.eval()

        # Set accuracy averager
        ave_acc = Averager()
        # Generate query and support labels
        label = torch.arange(self.args.way).repeat(self.args.val_query)
        if torch.cuda.is_available():
            label = label.type(torch.cuda.LongTensor)
        else:
            label = label.type(torch.LongTensor)
        label_shot = torch.arange(self.args.way).repeat(self.args.shot)
        if torch.cuda.is_available():
            label_shot = label_shot.type(torch.cuda.LongTensor)
        else:
            label_shot = label_shot.type(torch.LongTensor)

        # Hardness-stratified accuracy tracker
        hacc = Hacc()
        # Start meta-test
        for i, batch in enumerate(loader, 1):
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
            k = self.args.way * self.args.shot
            data_shot, data_query = data[:k], data[k:]
            logits = self.model((data_shot, label_shot, data_query, True))
            acc = count_acc(logits, label)
            hardness, correct = get_hardness_correct(logits, label_shot, label,
                                                     data_shot, data_query,
                                                     self.model.pretrain)
            ave_acc.add(acc)
            hacc.add_data(hardness, correct)
            test_acc_record[i - 1] = acc
            if i % 100 == 0:
                print("Average acc:{:.4f}, Average hAcc:{:.4f}".format(
                    ave_acc.item(), hacc.get_topk_hard_acc()))

        # Modify add path to generate test case name:
        test_case_name = self.add_path
        if self.args.cross:
            test_case_name += "_cross"
        # Calculate the confidence interval, update the logs
        m, pm = compute_confidence_interval(test_acc_record)
        msg = test_case_name + ' Test Acc {:.4f} +- {:.4f}, hAcc {:.4f}'.format(
            ave_acc.item() * 100, pm * 100, hacc.get_topk_hard_acc())
        print(msg)
        self.write_output_message(msg, test_case_name)

        if self.args.save_hacc:
            print("Saving hacc!")
            pickle.dump(hacc, open("hacc/" + test_case_name, "wb"))
        print('Test Acc {:.4f} + {:.4f}'.format(m, pm))
Beispiel #8
0
class MetaTrainer(object):
    """Driver for the meta-train phase of an MtlLearner variant that exposes
    hyperprior "combination" and "basestep" parameter groups.

    Builds a unique save directory from the hyper-parameters, constructs the
    episodic train/val loaders, the model, a multi-group Adam optimizer and a
    StepLR scheduler, and optionally warm-starts the encoder from pre-trained
    weights.
    """
    def __init__(self, args):
        """Set up paths, data loaders, model, optimizer and scheduler.

        Args:
          args: parsed command-line namespace. ``args.save_path`` is set here
            as a side effect.
        """
        # Create ./logs/ and ./logs/meta/ on demand.
        log_base_dir = './logs/'
        if not osp.exists(log_base_dir):
            os.mkdir(log_base_dir)
        meta_base_dir = osp.join(log_base_dir, 'meta')
        if not osp.exists(meta_base_dir):
            os.mkdir(meta_base_dir)
        # Encode every hyper-parameter in the run directory name so runs with
        # different settings never collide.
        save_path1 = '_'.join([args.dataset, args.model_type, 'MTL'])
        save_path2 = 'shot' + str(args.shot) + '_way' + str(args.way) + '_query' + str(args.train_query) + '_step' + str(args.step_size) + '_gamma' + str(args.gamma) + '_lr' + str(args.lr) + '_lrbase' + str(args.lr_base) + '_lrc' + str(args.lr_combination) + '_lrch' + str(args.lr_combination_hyperprior) + '_lrbs' + str(args.lr_basestep) + '_lrbsh' + str(args.lr_basestep_hyperprior) + '_batch' + str(args.num_batch) + '_maxepoch' + str(args.max_epoch) + '_csw' + str(args.hyperprior_combination_softweight) + '_cbsw' + str(args.hyperprior_basestep_softweight) + '_baselr' + str(args.base_lr) + '_updatestep' + str(args.update_step) + '_stepsize' + str(args.step_size) + '_' + args.label
        args.save_path = meta_base_dir + '/' + save_path1 + '_' + save_path2
        ensure_path(args.save_path)

        # Keep args on the instance so every method can read the settings.
        self.args = args

        # Import the dataset class lazily, based on the chosen dataset name.
        if self.args.dataset == 'MiniImageNet':
            from dataloader.mini_imagenet import MiniImageNet as Dataset
        elif self.args.dataset == 'TieredImageNet':
            from dataloader.tiered_imagenet import TieredImageNet as Dataset
        elif self.args.dataset == 'FC100':
            from dataloader.fewshotcifar import FewshotCifar as Dataset
        else:
            raise ValueError('Non-supported Dataset.')

        # Meta-train episodes: num_batch episodes per epoch, each holding
        # way * (shot + train_query) samples.
        self.trainset = Dataset('train', self.args)
        self.train_sampler = CategoriesSampler(self.trainset.label, self.args.num_batch, self.args.way, self.args.shot + self.args.train_query)
        self.train_loader = DataLoader(dataset=self.trainset, batch_sampler=self.train_sampler, num_workers=8, pin_memory=True)

        # Meta-val episodes: a fixed 3000 episodes per validation pass.
        self.valset = Dataset('val', self.args)
        self.val_sampler = CategoriesSampler(self.valset.label, 3000, self.args.way, self.args.shot + self.args.val_query)
        self.val_loader = DataLoader(dataset=self.valset, batch_sampler=self.val_sampler, num_workers=8, pin_memory=True)

        self.model = MtlLearner(self.args)

        # One Adam parameter group per learning rate: the trainable part of
        # the encoder uses the top-level lr; the base learner and the four
        # hyperprior variable sets each get their own rate.
        new_para = filter(lambda p: p.requires_grad, self.model.encoder.parameters())
        self.optimizer = torch.optim.Adam([{'params': new_para}, {'params': self.model.base_learner.parameters(), 'lr': self.args.lr_base}, {'params': self.model.get_hyperprior_combination_initialization_vars(), 'lr': self.args.lr_combination}, {'params': self.model.get_hyperprior_combination_mapping_vars(), 'lr': self.args.lr_combination_hyperprior}, {'params': self.model.get_hyperprior_basestep_initialization_vars(), 'lr': self.args.lr_basestep}, {'params': self.model.get_hyperprior_stepsize_mapping_vars(), 'lr': self.args.lr_basestep_hyperprior}], lr=self.args.lr)
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.args.step_size, gamma=self.args.gamma)

        # Warm-start from pre-trained weights, if given: keys are re-prefixed
        # with 'encoder.' and filtered to those present in this model.
        self.model_dict = self.model.state_dict()
        if self.args.init_weights is not None:
            pretrained_dict = torch.load(self.args.init_weights)['params']
            pretrained_dict = {'encoder.'+k: v for k, v in pretrained_dict.items()}
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in self.model_dict}
            print(pretrained_dict.keys())
            self.model_dict.update(pretrained_dict)

        self.model.load_state_dict(self.model_dict)

        # Move the model to GPU when available.
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()

    def save_model(self, name):
        """Save the model parameters to <save_path>/<name>.pth.

        Args:
          name: checkpoint name without extension (e.g. 'max_acc').
        """
        torch.save(dict(params=self.model.state_dict()), osp.join(self.args.save_path, name + '.pth'))

    def train(self):
        """Run the meta-train loop with per-epoch validation.

        Saves 'max_acc' whenever validation accuracy improves, a numbered
        checkpoint every 10 epochs, 'epoch-last' each epoch, and the 'trlog'
        dict of loss/accuracy histories to the save path.
        """
        # Meta-train log: per-epoch histories plus the best-so-far record.
        trlog = {}
        trlog['args'] = vars(self.args)
        trlog['train_loss'] = []
        trlog['val_loss'] = []
        trlog['train_acc'] = []
        trlog['val_acc'] = []
        trlog['max_acc'] = 0.0
        trlog['max_acc_epoch'] = 0

        timer = Timer()
        global_count = 0
        writer = SummaryWriter(logdir=self.args.save_path)

        for epoch in range(1, self.args.max_epoch + 1):
            # NOTE(review): scheduler.step() is called before the optimizer
            # updates of the epoch; PyTorch >= 1.1 recommends calling it
            # after — confirm this ordering is intentional before changing.
            self.lr_scheduler.step()
            self.model.train()
            tl = Averager()
            ta = Averager()

            # Query-set labels for each episode: 0..way-1 repeated
            # train_query times (matches the sampler's batch layout).
            label = torch.arange(self.args.way).repeat(self.args.train_query)
            if torch.cuda.is_available():
                label = label.type(torch.cuda.LongTensor)
            else:
                label = label.type(torch.LongTensor)

            # Support-set labels: 0..way-1 repeated shot times.
            label_shot = torch.arange(self.args.way).repeat(self.args.shot)
            if torch.cuda.is_available():
                label_shot = label_shot.type(torch.cuda.LongTensor)
            else:
                label_shot = label_shot.type(torch.LongTensor)

            tqdm_gen = tqdm.tqdm(self.train_loader)
            for i, batch in enumerate(tqdm_gen, 1):
                global_count = global_count + 1
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                # First way*shot samples are the support set, rest the query.
                p = self.args.shot * self.args.way
                data_shot, data_query = data[:p], data[p:]
                logits, combination_list, basestep_list = self.model((data_shot, label_shot, data_query))
                loss = F.cross_entropy(logits, label)
                acc = count_acc(logits, label)
                writer.add_scalar('data/loss', float(loss), global_count)
                writer.add_scalar('data/acc', float(acc), global_count)
                # Log a few fixed sample points of the hyperprior values.
                # NOTE(review): assumes both lists have at least 100 entries
                # (i.e. update_step >= 100) — confirm against MtlLearner.
                writer.add_scalar('combination_value/0', float(combination_list[0][0]), global_count)
                writer.add_scalar('combination_value/24', float(combination_list[24][0]), global_count)
                writer.add_scalar('combination_value/49', float(combination_list[49][0]), global_count)
                writer.add_scalar('combination_value/74', float(combination_list[74][0]), global_count)
                writer.add_scalar('combination_value/99', float(combination_list[99][0]), global_count)

                writer.add_scalar('basestep_value/0', float(basestep_list[0][0]), global_count)
                writer.add_scalar('basestep_value/24', float(basestep_list[24][0]), global_count)
                writer.add_scalar('basestep_value/49', float(basestep_list[49][0]), global_count)
                writer.add_scalar('basestep_value/74', float(basestep_list[74][0]), global_count)
                writer.add_scalar('basestep_value/99', float(basestep_list[99][0]), global_count)
                tqdm_gen.set_description('Epoch {}, Loss={:.4f} Acc={:.4f}'.format(epoch, loss.item(), acc))

                tl.add(loss.item())
                ta.add(acc)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Collapse the epoch averagers to plain numbers.
            tl = tl.item()
            ta = ta.item()

            self.model.eval()

            vl = Averager()
            va = Averager()

            # Rebuild labels for validation (val_query may differ).
            label = torch.arange(self.args.way).repeat(self.args.val_query)
            if torch.cuda.is_available():
                label = label.type(torch.cuda.LongTensor)
            else:
                label = label.type(torch.LongTensor)
            label_shot = torch.arange(self.args.way).repeat(self.args.shot)
            if torch.cuda.is_available():
                label_shot = label_shot.type(torch.cuda.LongTensor)
            else:
                label_shot = label_shot.type(torch.LongTensor)

            print('Best Epoch {}, Best Val Acc={:.4f}'.format(trlog['max_acc_epoch'], trlog['max_acc']))
            tqdm_gen1 = tqdm.tqdm(self.val_loader)
            for i, batch in enumerate(tqdm_gen1, 1):
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                p = self.args.shot * self.args.way
                data_shot, data_query = data[:p], data[p:]
                # Hyperprior lists are ignored during validation.
                logits, _, _ = self.model((data_shot, label_shot, data_query))
                loss = F.cross_entropy(logits, label)
                acc = count_acc(logits, label)
                vl.add(loss.item())
                va.add(acc)

            vl = vl.item()
            va = va.item()
            writer.add_scalar('data/val_loss', float(vl), epoch)
            writer.add_scalar('data/val_acc', float(va), epoch)
            print('Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(epoch, vl, va))

            # Keep the best-on-validation checkpoint, plus periodic ones.
            if va > trlog['max_acc']:
                trlog['max_acc'] = va
                trlog['max_acc_epoch'] = epoch
                self.save_model('max_acc')
            if epoch % 10 == 0:
                self.save_model('epoch'+str(epoch))

            trlog['train_loss'].append(tl)
            trlog['train_acc'].append(ta)
            trlog['val_loss'].append(vl)
            trlog['val_acc'].append(va)

            # Persist the running log every epoch so crashes lose little.
            torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

            self.save_model('epoch-last')

        writer.close()
# ---- Beispiel #9 (score: 0) ----
class MetaTrainer(object):
    """The class that contains the code for the meta-train phase and meta-eval phase.

    Segmentation variant: episodes carry per-pixel label maps, losses are
    segmentation losses (CE+Dice by default) and quality is tracked as pixel
    accuracy and mean IoU. The eval phase also writes input / ground-truth /
    prediction image triplets to ``args.save_image_dir``.
    """
    def __init__(self, args):
        """Build paths, episodic loaders, model, losses and optimizer.

        Args:
          args: parsed command-line namespace; ``args.save_path`` and
            ``args.save_image_dir`` are set here as side effects.
        """
        # Set the folder to save the records and checkpoints
        save_image_dir='../results/'
        if not osp.exists(save_image_dir):
            os.mkdir(save_image_dir)

        log_base_dir = '../logs/'
        if not osp.exists(log_base_dir):
            os.mkdir(log_base_dir)
        meta_base_dir = osp.join(log_base_dir, 'meta')
        if not osp.exists(meta_base_dir):
            os.mkdir(meta_base_dir)
        # Encode the hyper-parameters in the run directory name.
        save_path1 = '_'.join([args.dataset, args.model_type, 'MTL'])
        save_path2 = 'shot' + str(args.shot) + '_way' + str(args.way) + '_query' + str(args.train_query) + \
            '_step' + str(args.step_size) + '_gamma' + str(args.gamma) + '_lr1' + str(args.meta_lr1) + '_lr2' + str(args.meta_lr2) + \
            '_batch' + str(args.num_batch) + '_maxepoch' + str(args.max_epoch) + \
            '_baselr' + str(args.base_lr) + '_updatestep' + str(args.update_step) + \
            '_stepsize' + str(args.step_size) + '_' + args.meta_label
        args.save_path = meta_base_dir + '/' + save_path1 + '_' + save_path2
        args.save_image_dir=save_image_dir
        ensure_path(args.save_path)

        # Set args to be shareable in the class
        self.args = args

        # Load meta-train set.  The extra trailing `shot` argument tells the
        # sampler how many support samples lead each episode.
        self.trainset = mDataset('meta', self.args)
        self.train_sampler = CategoriesSampler(self.trainset.labeln, self.args.num_batch, self.args.way, self.args.shot + self.args.train_query,self.args.shot)
        self.train_loader = DataLoader(dataset=self.trainset, batch_sampler=self.train_sampler, num_workers=8, pin_memory=True)

        # Load meta-val set
        self.valset = mDataset('val', self.args)
        self.val_sampler = CategoriesSampler(self.valset.labeln, self.args.num_batch, self.args.way, self.args.shot + self.args.val_query,self.args.shot)
        self.val_loader = DataLoader(dataset=self.valset, batch_sampler=self.val_sampler, num_workers=8, pin_memory=True)

        # Build meta-transfer learning model and the segmentation losses.
        # Only CE_DiceLoss is used below; FocalLoss and LovaszSoftmax are
        # kept available for the commented-out combined loss.
        self.model = MtlLearner(self.args)
        self.FL=FocalLoss()
        self.CD=CE_DiceLoss()
        self.LS=LovaszSoftmax()

        # Set optimizer: encoder params at meta_lr1, base learner at meta_lr2.
        self.optimizer = torch.optim.Adam([{'params': filter(lambda p: p.requires_grad, self.model.encoder.parameters())}, \
            {'params': self.model.base_learner.parameters(), 'lr': self.args.meta_lr2}], lr=self.args.meta_lr1)
        # Set learning rate scheduler
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.args.step_size, gamma=self.args.gamma)

        # load pretrained model: explicit init_weights if given, otherwise the
        # best ('max_iou') checkpoint from the matching pre-train run.
        self.model_dict = self.model.state_dict()
        if self.args.init_weights is not None:
            pretrained_dict = torch.load(self.args.init_weights)['params']
        else:
            pre_base_dir = osp.join(log_base_dir, 'pre')
            pre_save_path1 = '_'.join([args.dataset, args.model_type])
            pre_save_path2 = 'batchsize' + str(args.pre_batch_size) + '_lr' + str(args.pre_lr) + '_gamma' + str(args.pre_gamma) + '_step' + \
                str(args.pre_step_size) + '_maxepoch' + str(args.pre_max_epoch)
            pre_save_path = pre_base_dir + '/' + pre_save_path1 + '_' + pre_save_path2
            pretrained_dict = torch.load(osp.join(pre_save_path, 'max_iou.pth'))['params']
        # Re-prefix keys for the wrapped encoder and drop unmatched entries.
        pretrained_dict = {'encoder.'+k: v for k, v in pretrained_dict.items()}
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in self.model_dict}

        print(pretrained_dict.keys())
        self.model_dict.update(pretrained_dict)
        self.model.load_state_dict(self.model_dict)

        # Set model to GPU
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()

    def _reset_metrics(self):
        """Zero the running segmentation-metric accumulators."""
        #self.batch_time = AverageMeter()
        #self.data_time = AverageMeter()
        #self.total_loss = AverageMeter()
        self.total_inter, self.total_union = 0, 0
        self.total_correct, self.total_label = 0, 0

    def _update_seg_metrics(self, correct, labeled, inter, union):
        """Accumulate per-batch pixel counts into the running totals.

        Args:
          correct: count of correctly classified labeled pixels.
          labeled: count of labeled pixels.
          inter: per-class intersection counts.
          union: per-class union counts.
        """
        self.total_correct += correct
        self.total_label += labeled
        self.total_inter += inter
        self.total_union += union

    def _get_seg_metrics(self,n_class):
        """Return pixel accuracy, mean IoU and per-class IoU from the totals.

        Args:
          n_class: number of segmentation classes (used for the IoU dict).
        """
        self.n_class=n_class
        # np.spacing(1) guards against division by zero on empty totals.
        pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label)
        IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
        mIoU = IoU.mean()
        return {
            "Pixel_Accuracy": np.round(pixAcc, 3),
            "Mean_IoU": np.round(mIoU, 3),
            "Class_IoU": dict(zip(range(self.n_class), np.round(IoU, 3)))
        }

    def save_model(self, name):
        """The function to save checkpoints.
        Args:
          name: the name for saved checkpoint
        """
        torch.save(dict(params=self.model.state_dict()), osp.join(self.args.save_path, name + '.pth'))

    def train(self):
        """The function for the meta-train phase."""

        # Set the meta-train log
        trlog = {}
        trlog['args'] = vars(self.args)
        trlog['train_loss'] = []
        trlog['val_loss'] = []
        trlog['train_acc'] = []
        trlog['val_acc'] = []
        trlog['train_iou'] = []
        trlog['val_iou'] = []
        trlog['max_iou'] = 0.0
        trlog['max_iou_epoch'] = 0

        # Set the timer
        timer = Timer()
        # Set global count to zero
        global_count = 0
        # Set tensorboardX
        writer = SummaryWriter(comment=self.args.save_path)

        # Start meta-train
        for epoch in range(1, self.args.max_epoch + 1):
            # Update learning rate
            # NOTE(review): scheduler.step() precedes the optimizer updates;
            # PyTorch >= 1.1 recommends calling it after — confirm intent.
            self.lr_scheduler.step()
            # Set the model to train mode
            self.model.train()
            # Set averager classes to record training losses and accuracies
            train_loss_averager = Averager()
            train_acc_averager = Averager()
            train_iou_averager = Averager()

            # Using tqdm to read samples from train loader
            tqdm_gen = tqdm.tqdm(self.train_loader)
            self._reset_metrics()
            for i, batch in enumerate(tqdm_gen, 1):
                # Update global count number 
                global_count = global_count + 1
                if torch.cuda.is_available():
                    data, labels,_ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                    labels=batch[1]

                # Split each episode: first way*shot samples are support.
                p = self.args.way*self.args.shot
                data_shot, data_query = data[:p], data[p:]
                label_shot,label=labels[:p],labels[p:]
                # Output logits for model
                par=data_shot, label_shot, data_query
                logits = self.model(par)

                # Calculate meta-train loss
                #loss = self.FL(logits, label) + self.CD(logits,label) + self.LS(logits,label)
                loss = self.CD(logits,label)

                # Calculate meta-train accuracy.  Metrics are reset per batch,
                # so pixAcc/mIoU here are per-episode values that are then
                # averaged, not running totals.
                self._reset_metrics()
                seg_metrics = eval_metrics(logits, label, self.args.way)
                self._update_seg_metrics(*seg_metrics)
                pixAcc, mIoU, _ = self._get_seg_metrics(self.args.way).values()

                # Add loss and accuracy for the averagers
                train_loss_averager.add(loss.item())
                train_acc_averager.add(pixAcc)
                train_iou_averager.add(mIoU)

                # Print loss and accuracy till this step
                tqdm_gen.set_description('Epoch {}, Loss={:.4f} Acc={:.4f} IoU={:.4f}'.format(epoch, train_loss_averager.item(), train_acc_averager.item()*100.0,train_iou_averager.item()))

                # Loss backwards and optimizer updates
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Update the averagers
            train_loss_averager = train_loss_averager.item()
            train_acc_averager = train_acc_averager.item()
            train_iou_averager = train_iou_averager.item()

            writer.add_scalar('data/train_loss (Meta)', float(train_loss_averager), epoch)
            writer.add_scalar('data/train_acc (Meta)', float(train_acc_averager)*100.0, epoch)
            writer.add_scalar('data/train_iou (Meta)', float(train_iou_averager), epoch)

            # Start validation for this epoch, set model to eval mode
            self.model.eval()

            # Set averager classes to record validation losses and accuracies
            val_loss_averager = Averager()
            val_acc_averager = Averager()
            val_iou_averager = Averager()

            # Print previous information
            if epoch % 1 == 0:
                print('Best Val Epoch {}, Best Val IoU={:.4f}'.format(trlog['max_iou_epoch'], trlog['max_iou']))

            # Run meta
            for i, batch in enumerate(self.val_loader, 1):
                if torch.cuda.is_available():
                    data, labels,_ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                    labels=batch[1]
                p = self.args.way* self.args.shot
                data_shot, data_query = data[:p], data[p:]
                label_shot,label=labels[:p],labels[p:]

                par=data_shot, label_shot, data_query
                logits = self.model(par)

                # Calculate meta val loss
                #loss = self.FL(logits, label) + self.CD(logits,label) + self.LS(logits,label)
                loss = self.CD(logits,label)

                # Calculate meta-val accuracy (per-episode, see train loop)
                self._reset_metrics()
                seg_metrics = eval_metrics(logits, label, self.args.way)
                self._update_seg_metrics(*seg_metrics)
                pixAcc, mIoU, _ = self._get_seg_metrics(self.args.way).values()

                val_loss_averager.add(loss.item())
                val_acc_averager.add(pixAcc)
                val_iou_averager.add(mIoU)

            # Update validation averagers
            val_loss_averager = val_loss_averager.item()
            val_acc_averager = val_acc_averager.item()
            val_iou_averager = val_iou_averager.item()

            # Write the tensorboardX records
            writer.add_scalar('data/val_loss (Meta)', float(val_loss_averager), epoch)
            writer.add_scalar('data/val_acc (Meta)', float(val_acc_averager)*100.0, epoch)
            writer.add_scalar('data/val_iou (Meta)', float(val_iou_averager), epoch)

            # Print loss and accuracy for this epoch
            print('Epoch {}, Val, Loss={:.4f} Acc={:.4f} IoU={:.4f}'.format(epoch, val_loss_averager, val_acc_averager*100.0,val_iou_averager))

            # Update best saved model (selection criterion is mean IoU)
            if val_iou_averager > trlog['max_iou']:
                trlog['max_iou'] = val_iou_averager
                trlog['max_iou_epoch'] = epoch
                self.save_model('max_iou')
            # Save model every 10 epochs
            if epoch % 10 == 0:
                self.save_model('epoch'+str(epoch))

            # Update the logs
            trlog['train_loss'].append(train_loss_averager)
            trlog['train_acc'].append(train_acc_averager)
            trlog['val_loss'].append(val_loss_averager)
            trlog['val_acc'].append(val_acc_averager)
            trlog['train_iou'].append(train_iou_averager)
            trlog['val_iou'].append(val_iou_averager)

            # Save log
            torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

            if epoch % 1 == 0:
                print('Running Time: {}, Estimated Time: {}'.format(timer.measure(), timer.measure(epoch / self.args.max_epoch)))

        writer.close()

    def eval(self):
        """The function for the meta-evaluate (test) phase.

        Loads the best ('max_iou') or explicitly given weights, accumulates
        segmentation metrics over the test episodes, and writes each query
        image, its ground truth and the predicted mask to save_image_dir.
        """
        # Load the logs
        trlog = torch.load(osp.join(self.args.save_path, 'trlog'))

        # Load meta-test set (uses the test-time shot count `teshot`)
        self.test_set = mDataset('test', self.args)
        self.sampler = CategoriesSampler(self.test_set.labeln, self.args.num_batch, self.args.way, self.args.teshot + self.args.test_query, self.args.teshot)
        self.loader = DataLoader(dataset=self.test_set, batch_sampler=self.sampler, num_workers=8, pin_memory=True)
        #self.loader = DataLoader(dataset=self.test_set,batch_size=10, shuffle=False, num_workers=8, pin_memory=True)
        # Set test accuracy recorder
        #test_acc_record = np.zeros((600,))

        # Load model for meta-test phase
        if self.args.eval_weights is not None:
            self.model.load_state_dict(torch.load(self.args.eval_weights)['params'])
        else:
            self.model.load_state_dict(torch.load(osp.join(self.args.save_path, 'max_iou' + '.pth'))['params'])
        # Set model to eval mode
        self.model.eval()

        # Set accuracy averager
        ave_acc = Averager()

        # Start meta-test.  Metrics accumulate over ALL test episodes here
        # (reset only once, unlike the per-batch reset in train()).
        self._reset_metrics()
        count=1
        for i, batch in enumerate(self.loader, 1):
            if torch.cuda.is_available():
                data, labels,_ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
                labels=batch[1]
            p = self.args.teshot*self.args.way
            data_shot, data_query = data[:p], data[p:]
            label_shot,label=labels[:p],labels[p:]
            logits = self.model((data_shot, label_shot, data_query))
            seg_metrics = eval_metrics(logits, label, self.args.way)
            self._update_seg_metrics(*seg_metrics)
            pixAcc, mIoU, _ = self._get_seg_metrics(self.args.way).values()

            ave_acc.add(pixAcc)
            #test_acc_record[i-1] = acc
            #if i % 100 == 0:
                #print('batch {}: {Average Accuracy:.2f}({Pixel Accuracy:.2f} {IoU :.2f} )'.format(i, ave_acc.item() * 100.0, pixAcc * 100.0,mIoU))

            #Saving Test Image, Ground Truth Image and Predicted Image
            for j in range(len(data_query)):

                x1 = data_query[j].detach().cpu()
                y1 = label[j].detach().cpu()
                z1 = logits[j].detach().cpu()

                # Scale label/prediction indices to [0, 1] grayscale before
                # converting to PIL images.
                x = transforms.ToPILImage()(x1).convert("RGB")
                y = transforms.ToPILImage()(y1 /(1.0*(self.args.way-1))).convert("LA")
                im =  torch.tensor(np.argmax(np.array(z1),axis=0)/(1.0*(self.args.way-1)))
                im =  im.type(torch.FloatTensor)
                z =  transforms.ToPILImage()(im).convert("LA")

                # <count>a = input, <count>b = ground truth, <count>c = prediction
                px=self.args.save_image_dir+str(count)+'a.jpg'
                py=self.args.save_image_dir+str(count)+'b.png'
                pz=self.args.save_image_dir+str(count)+'c.png'
                x.save(px)
                y.save(py)
                z.save(pz)
                count=count+1
class MetaTrainer(object):
    """The class that contains the code for the meta-train phase and meta-eval phase."""
    def __init__(self, args):
        """Build save paths, episodic loaders, model (and optional distillation
        teacher), optimizer, scheduler and pretrained-weight warm start.

        Args:
          args: parsed command-line namespace; ``args.save_path`` is set here
            as a side effect.
        """
        # Set the folder to save the records and checkpoints
        log_base_dir = './logs/'
        if not osp.exists(log_base_dir):
            os.mkdir(log_base_dir)
        meta_base_dir = osp.join(log_base_dir, 'meta')
        if not osp.exists(meta_base_dir):
            os.mkdir(meta_base_dir)
        # Encode the hyper-parameters in the run directory name.
        save_path1 = '_'.join([args.dataset, args.model_type, 'MTL'])
        save_path2 = 'shot' + str(args.shot) + '_way' + str(args.way) + '_query' + str(args.train_query) + \
            '_step' + str(args.step_size) + '_gamma' + str(args.gamma) + '_lr1' + str(args.meta_lr1) + '_lr2' + str(args.meta_lr2) + \
            '_batch' + str(args.num_batch) + '_maxepoch' + str(args.max_epoch) + \
            '_baselr' + str(args.base_lr) + '_updatestep' + str(args.update_step) + \
            '_stepsize' + str(args.step_size) + '_' + args.meta_label
        args.save_path = meta_base_dir + '/' + save_path1 + '_' + save_path2
        ensure_path(args.save_path)

        # Set args to be shareable in the class
        self.args = args

        # Load meta-train set
        self.trainset = Dataset('train', self.args)
        self.train_sampler = CategoriesSampler(
            self.trainset.label, self.args.num_batch, self.args.way,
            self.args.shot + self.args.train_query)
        self.train_loader = DataLoader(dataset=self.trainset,
                                       batch_sampler=self.train_sampler,
                                       num_workers=args.num_workers,
                                       pin_memory=True)

        # Load meta-val set (fixed 600 episodes per validation pass)
        self.valset = Dataset('val', self.args)
        self.val_sampler = CategoriesSampler(
            self.valset.label, 600, self.args.way,
            self.args.shot + self.args.val_query)
        self.val_loader = DataLoader(dataset=self.valset,
                                     batch_sampler=self.val_sampler,
                                     num_workers=args.num_workers,
                                     pin_memory=True)

        # Build meta-transfer learning model.  High-resolution encoder is
        # used when distilling or when explicitly requested; multi-GPU is
        # inferred from the comma-separated GPU list.
        self.model = MtlLearner(self.args,res="high" if (self.args.distill_id or self.args.high_res) else "low",multi_gpu=len(args.gpu.split(","))>1,\
                                crossAtt=self.args.cross_att)

        if self.args.distill_id:
            #self.teacher = MtlLearner(self.args,res="low")
            #self.teacher.load_state_dict(torch.load(args.distill_id)["params"])

            # Low-res teacher for knowledge distillation, loaded from the
            # best checkpoint of the teacher's best trial (1-based arg).
            self.teacher = MtlLearner(self.args,
                                      res="low",
                                      repVecNb=self.args.nb_parts_teach,
                                      multi_gpu=len(args.gpu.split(",")) > 1)
            bestTeach = "../models/{}/meta_{}_trial{}_max_acc.pth".format(
                self.args.exp_id, self.args.distill_id,
                self.args.best_trial_teach - 1)
            self.teacher.load_state_dict(torch.load(bestTeach)["params"])

        # Set optimizer: encoder params at meta_lr1, base learner at meta_lr2.
        self.optimizer = torch.optim.Adam([{'params': filter(lambda p: p.requires_grad, self.model.encoder.parameters())}, \
            {'params': self.model.base_learner.parameters(), 'lr': self.args.meta_lr2}], lr=self.args.meta_lr1)
        # Set learning rate scheduler
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=self.args.step_size,
            gamma=self.args.gamma)

        # load pretrained model without FC classifier: keys are re-prefixed
        # with 'encoder.' and filtered to those present in this model.
        self.model_dict = self.model.state_dict()
        if self.args.init_weights is not None:
            pretrained_dict = torch.load(self.args.init_weights)['params']

            pretrained_dict = {
                'encoder.' + k: v
                for k, v in pretrained_dict.items()
            }
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items() if k in self.model_dict
            }

            self.model_dict.update(pretrained_dict)
            self.model.load_state_dict(self.model_dict)

        # Set model to GPU
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()

            if self.args.distill_id:
                self.teacher = self.teacher.cuda()

        if self.args.cross_att:
            self.criterion = crossAttModule.CrossEntropyLoss()

    def crossAttLoss(self, ytest, cls_scores, labels_test, pids):
        loss1 = self.criterion(ytest, pids.view(-1))
        loss2 = self.criterion(cls_scores, labels_test.view(-1))
        loss = loss1 + 0.5 * loss2
        return loss

    def one_hot(self, labels_train):
        """
        Turn the labels_train to one-hot encoding.
        Args:
            labels_train: [batch_size, num_train_examples]
        Return:
            labels_train_1hot: [batch_size, num_train_examples, K]
        """
        labels_train = labels_train.cpu()
        # K is inferred from the largest label value present.
        num_classes = labels_train.max() + 1
        target_shape = list(labels_train.size()) + [num_classes]
        # Scatter 1s along a new trailing class dimension.
        index = labels_train.unsqueeze(dim=labels_train.dim())
        encoded = torch.zeros(target_shape).scatter_(
            len(target_shape) - 1, index, 1)
        return encoded

    def save_model(self, name):
        """Save the full model state under ../models/<exp_id>/.

        Args:
          name: suffix identifying the checkpoint (e.g. 'max_acc').
        """
        checkpoint = dict(params=self.model.state_dict())
        target = "../models/{}/meta_{}_trial{}_{}.pth".format(
            self.args.exp_id, self.args.model_id, self.args.trial_number,
            name)
        torch.save(checkpoint, target)

    def train(self, trial):
        """The function for the meta-train phase.

        Meta-trains for ``args.max_epoch`` epochs. After every epoch the
        model is meta-validated; the best checkpoint (highest validation
        accuracy) is saved and the epoch's validation accuracy is reported
        to the optuna ``trial``.

        Args:
          trial: optuna trial object; ``trial.report`` is called once per
            epoch with the validation accuracy.
        """

        # Set the meta-train log
        trlog = {}
        trlog['args'] = vars(self.args)
        trlog['train_loss'] = []
        trlog['val_loss'] = []
        trlog['train_acc'] = []
        trlog['val_acc'] = []
        trlog['max_acc'] = 0.0
        trlog['max_acc_epoch'] = 0

        # Set the timer
        timer = Timer()
        # Set global count to zero
        global_count = 0
        # Set tensorboardX
        writer = SummaryWriter(comment=self.args.save_path)

        # Generate the labels for train set of the episodes
        label_shot = torch.arange(self.args.way).repeat(self.args.shot)
        if torch.cuda.is_available():
            label_shot = label_shot.type(torch.cuda.LongTensor)
        else:
            label_shot = label_shot.type(torch.LongTensor)

        # True class indices the model performed worst on; once `way` of
        # them are collected an extra "hard" episode is built from them
        # (only used when args.hard_tasks is set).
        worstClasses = []

        # Start meta-train
        for epoch in range(1, self.args.max_epoch + 1):
            # Set the model to train mode
            self.model.train()
            # Set averager classes to record training losses and accuracies
            train_loss_averager = Averager()
            train_acc_averager = Averager()

            # Generate the labels for test set of the episodes during meta-train updates
            label = torch.arange(self.args.way).repeat(self.args.train_query)
            if torch.cuda.is_available():
                label = label.type(torch.cuda.LongTensor)
            else:
                label = label.type(torch.LongTensor)

            # Using tqdm to read samples from train loader
            tqdm_gen = tqdm.tqdm(self.train_loader)
            for i, batch in enumerate(tqdm_gen, 1):
                # Update global count number
                global_count = global_count + 1
                if torch.cuda.is_available():
                    data, targ = [_.cuda() for _ in batch]
                else:
                    data, targ = batch
                p = self.args.shot * self.args.way
                data_shot, data_query = data[:p], data[p:]
                # Output logits for model
                if self.args.cross_att:
                    label_one_hot = self.one_hot(label).to(label.device)
                    ytest, cls_scores, logits = self.model(
                        (data_shot, label_shot, data_query),
                        ytest=label_one_hot)
                    pids = label_shot
                    loss = self.crossAttLoss(ytest, cls_scores, label, pids)
                    logits = logits[0]
                else:
                    logits = self.model((data_shot, label_shot, data_query))
                    # Calculate meta-train loss
                    loss = F.cross_entropy(logits, label)

                # Optional distillation: blend the hard-label loss with a
                # KL term against the teacher's temperature-softened
                # predictions. The T*T factor keeps gradient magnitudes
                # comparable across temperatures.
                if self.args.distill_id:
                    teachLogits = self.teacher(
                        (data_shot, label_shot, data_query))
                    kl = F.kl_div(F.log_softmax(logits / self.args.kl_temp,
                                                dim=1),
                                  F.softmax(teachLogits / self.args.kl_temp,
                                            dim=1),
                                  reduction="batchmean")
                    loss = (kl * self.args.kl_interp * self.args.kl_temp *
                            self.args.kl_temp + loss *
                            (1 - self.args.kl_interp))

                acc = count_acc(logits, label)
                # Write the tensorboardX records
                writer.add_scalar('data/loss', float(loss), global_count)
                writer.add_scalar('data/acc', float(acc), global_count)
                # Print loss and accuracy for this step
                tqdm_gen.set_description(
                    'Epoch {}, Loss={:.4f} Acc={:.4f}'.format(
                        epoch, loss.item(), acc))

                # Add loss and accuracy for the averagers
                train_loss_averager.add(loss.item())
                train_acc_averager.add(acc)

                # Loss backwards and optimizer updates
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                if self.args.hard_tasks:
                    if len(worstClasses) == self.args.way:
                        # Build an episode from the hardest classes and take
                        # one extra optimization step on it.
                        inds = self.train_sampler.hardBatch(worstClasses)
                        # BUGFIX: the hard batch used to be built but never
                        # used (the previous episode's `data` was sliced
                        # again); stack the sampled images instead.
                        data = torch.stack(
                            [self.trainset[j][0] for j in inds])
                        if torch.cuda.is_available():
                            data = data.cuda()
                        data_shot, data_query = data[:p], data[p:]
                        logits = self.model(
                            (data_shot, label_shot, data_query))
                        loss = F.cross_entropy(logits, label)
                        self.optimizer.zero_grad()
                        loss.backward()
                        self.optimizer.step()
                        worstClasses = []
                    else:
                        # Per-class accuracy over the query set; remember the
                        # true index of the worst-classified class.
                        error_mat = (logits.argmax(dim=1) == label).view(
                            self.args.train_query, self.args.way)
                        worst = error_mat.float().mean(dim=0).argmin()
                        worst_trueInd = targ[worst]
                        worstClasses.append(worst_trueInd)

            # Update the averagers
            train_loss_averager = train_loss_averager.item()
            train_acc_averager = train_acc_averager.item()

            # Update the learning rate once per epoch, AFTER the epoch's
            # optimizer updates (required ordering since PyTorch 1.1;
            # previously this ran before training, shifting the schedule).
            self.lr_scheduler.step()

            # Start validation for this epoch, set model to eval mode
            self.model.eval()

            # Set averager classes to record validation losses and accuracies
            val_loss_averager = Averager()
            val_acc_averager = Averager()

            # Generate the labels for test set of the episodes during meta-val for this epoch
            label = torch.arange(self.args.way).repeat(self.args.val_query)
            if torch.cuda.is_available():
                label = label.type(torch.cuda.LongTensor)
            else:
                label = label.type(torch.LongTensor)

            # Print previous information
            if epoch % 10 == 0:
                print('Best Epoch {}, Best Val Acc={:.4f}'.format(
                    trlog['max_acc_epoch'], trlog['max_acc']))
            # Run meta-validation
            for i, batch in enumerate(self.val_loader, 1):
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                p = self.args.shot * self.args.way
                data_shot, data_query = data[:p], data[p:]

                if self.args.cross_att:
                    label_one_hot = self.one_hot(label).to(label.device)
                    ytest, cls_scores, logits = self.model(
                        (data_shot, label_shot, data_query),
                        ytest=label_one_hot)
                    pids = label_shot
                    loss = self.crossAttLoss(ytest, cls_scores, label, pids)
                    logits = logits[0]
                else:
                    logits = self.model((data_shot, label_shot, data_query))
                    loss = F.cross_entropy(logits, label)
                acc = count_acc(logits, label)

                val_loss_averager.add(loss.item())
                val_acc_averager.add(acc)

            # Update validation averagers
            val_loss_averager = val_loss_averager.item()
            val_acc_averager = val_acc_averager.item()
            # Write the tensorboardX records
            writer.add_scalar('data/val_loss', float(val_loss_averager), epoch)
            writer.add_scalar('data/val_acc', float(val_acc_averager), epoch)
            # Print loss and accuracy for this epoch
            print('Epoch {}, Val, Loss={:.4f} Acc={:.4f}'.format(
                epoch, val_loss_averager, val_acc_averager))

            # Update best saved model
            if val_acc_averager > trlog['max_acc']:
                trlog['max_acc'] = val_acc_averager
                trlog['max_acc_epoch'] = epoch
                self.save_model('max_acc')

            # Update the logs
            trlog['train_loss'].append(train_loss_averager)
            trlog['train_acc'].append(train_acc_averager)
            trlog['val_loss'].append(val_loss_averager)
            trlog['val_acc'].append(val_acc_averager)

            # Save log
            torch.save(trlog, osp.join(self.args.save_path, 'trlog'))

            if epoch % 10 == 0:
                print('Running Time: {}, Estimated Time: {}'.format(
                    timer.measure(),
                    timer.measure(epoch / self.args.max_epoch)))

            trial.report(val_acc_averager, epoch)

        writer.close()

    def eval(self, gradcam: bool = False, rise: bool = False, test_on_val: bool = False):
        """The function for the meta-eval phase.

        Evaluates the best checkpoint on 600 sampled episodes and prints
        the mean accuracy with its confidence interval.

        Args:
          gradcam: if True, also save Grad-CAM / Grad-CAM++ / guided
            backprop maps for every 5th episode.
          rise: if True, also compute ScoreCam maps for every 5th episode.
          test_on_val: evaluate on the 'val' split instead of 'test'.

        Returns:
          Mean accuracy over the 600 episodes (float).
        """
        # Load the logs
        if os.path.exists(osp.join(self.args.save_path, 'trlog')):
            trlog = torch.load(osp.join(self.args.save_path, 'trlog'))
        else:
            trlog = None

        # Fix seeds so the sampled evaluation episodes are reproducible.
        torch.manual_seed(1)
        np.random.seed(1)
        # Load meta-test set
        test_set = Dataset('val' if test_on_val else 'test', self.args)
        sampler = CategoriesSampler(test_set.label, 600, self.args.way,
                                    self.args.shot + self.args.val_query)
        loader = DataLoader(test_set,
                            batch_sampler=sampler,
                            num_workers=8,
                            pin_memory=True)

        # Set test accuracy recorder, one slot per episode
        test_acc_record = np.zeros((600, ))

        # Load model for meta-test phase; addOrRemoveModule reconciles a
        # possible DataParallel 'module.' prefix mismatch between the
        # checkpoint keys and the current model.
        if self.args.eval_weights is not None:
            weights = self.addOrRemoveModule(
                self.model,
                torch.load(self.args.eval_weights)['params'])
            self.model.load_state_dict(weights)
        else:
            self.model.load_state_dict(
                torch.load(osp.join(self.args.save_path,
                                    'max_acc' + '.pth'))['params'])
        # Set model to eval mode
        self.model.eval()

        # Set accuracy averager
        ave_acc = Averager()

        # Generate labels: query labels first, then support (shot) labels
        label = torch.arange(self.args.way).repeat(self.args.val_query)
        if torch.cuda.is_available():
            label = label.type(torch.cuda.LongTensor)
        else:
            label = label.type(torch.LongTensor)
        label_shot = torch.arange(self.args.way).repeat(self.args.shot)
        if torch.cuda.is_available():
            label_shot = label_shot.type(torch.cuda.LongTensor)
        else:
            label_shot = label_shot.type(torch.LongTensor)

        if gradcam:
            # Expose encoder internals under the attribute names the CAM
            # helpers expect ('layer3' / 'features').
            self.model.layer3 = self.model.encoder.layer3
            model_dict = dict(type="resnet",
                              arch=self.model,
                              layer_name='layer3')
            grad_cam = GradCAM(model_dict, True)
            grad_cam_pp = GradCAMpp(model_dict, True)
            self.model.features = self.model.encoder
            guided = GuidedBackprop(self.model)
        if rise:
            self.model.layer3 = self.model.encoder.layer3
            score_mod = ScoreCam(self.model)

        # Start meta-test
        for i, batch in enumerate(loader, 1):
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
            k = self.args.way * self.args.shot
            data_shot, data_query = data[:k], data[k:]

            # Every 5th episode: run the model with extra outputs and dump
            # similarity maps / norms (and optionally CAM maps) to disk.
            if i % 5 == 0:
                suff = "_val" if test_on_val else ""

                if self.args.rep_vec or self.args.cross_att:
                    # NOTE(review): `acc` here comes from the previous
                    # iteration (first hit is i=5, so it is defined).
                    print('batch {}: {:.2f}({:.2f})'.format(
                        i,
                        ave_acc.item() * 100, acc * 100))

                    if self.args.cross_att:
                        # NOTE(review): this branch never assigns
                        # `fast_weights`; combining cross_att with
                        # gradcam/rise below would raise NameError — confirm.
                        label_one_hot = self.one_hot(label).to(label.device)
                        _, _, logits, simMapQuer, simMapShot, normQuer, normShot = self.model(
                            (data_shot, label_shot, data_query),
                            ytest=label_one_hot,
                            retSimMap=True)
                    else:
                        logits, simMapQuer, simMapShot, normQuer, normShot, fast_weights = self.model(
                            (data_shot, label_shot, data_query),
                            retSimMap=True)

                    torch.save(
                        simMapQuer,
                        "../results/{}/{}_simMapQuer{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))
                    torch.save(
                        simMapShot,
                        "../results/{}/{}_simMapShot{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))
                    torch.save(
                        data_query, "../results/{}/{}_dataQuer{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))
                    torch.save(
                        data_shot, "../results/{}/{}_dataShot{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))
                    torch.save(
                        normQuer, "../results/{}/{}_normQuer{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))
                    torch.save(
                        normShot, "../results/{}/{}_normShot{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))
                else:
                    logits, normQuer, normShot, fast_weights = self.model(
                        (data_shot, label_shot, data_query),
                        retFastW=True,
                        retNorm=True)
                    torch.save(
                        normQuer, "../results/{}/{}_normQuer{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))
                    torch.save(
                        normShot, "../results/{}/{}_normShot{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))

                if gradcam:
                    # One CAM/guided-backprop map per query image.
                    print("Saving gradmaps", i)
                    allMasks, allMasks_pp, allMaps = [], [], []
                    for l in range(len(data_query)):
                        allMasks.append(
                            grad_cam(data_query[l:l + 1], fast_weights, None))
                        allMasks_pp.append(
                            grad_cam_pp(data_query[l:l + 1], fast_weights,
                                        None))
                        allMaps.append(
                            guided.generate_gradients(data_query[l:l + 1],
                                                      fast_weights))
                    allMasks = torch.cat(allMasks, dim=0)
                    allMasks_pp = torch.cat(allMasks_pp, dim=0)
                    allMaps = torch.cat(allMaps, dim=0)

                    torch.save(
                        allMasks, "../results/{}/{}_gradcamQuer{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))
                    torch.save(
                        allMasks_pp,
                        "../results/{}/{}_gradcamppQuer{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))
                    torch.save(
                        allMaps, "../results/{}/{}_guidedQuer{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))

                if rise:
                    # NOTE(review): allScore is computed but never saved or
                    # returned — looks like a dead/incomplete branch; confirm.
                    print("Saving risemaps", i)
                    allScore = []
                    for l in range(len(data_query)):
                        allScore.append(
                            score_mod(data_query[l:l + 1], fast_weights))

            else:
                # Plain forward pass for ordinary episodes.
                if self.args.cross_att:
                    label_one_hot = self.one_hot(label).to(label.device)
                    _, _, logits = self.model(
                        (data_shot, label_shot, data_query),
                        ytest=label_one_hot)
                else:
                    logits = self.model((data_shot, label_shot, data_query))

            acc = count_acc(logits, label)
            ave_acc.add(acc)
            test_acc_record[i - 1] = acc

        # Calculate the confidence interval, update the logs
        m, pm = compute_confidence_interval(test_acc_record)
        if trlog is not None:
            print('Val Best Epoch {}, Acc {:.4f}, Test Acc {:.4f}'.format(
                trlog['max_acc_epoch'], trlog['max_acc'], ave_acc.item()))
        print('Test Acc {:.4f} + {:.4f}'.format(m, pm))

        return m

    def addOrRemoveModule(self, net, weights):
        """Reconcile a DataParallel 'module.' prefix mismatch.

        Compares the first 'encoder' key found in the checkpoint
        ``weights`` with the first 'encoder' key of ``net``'s state dict,
        and adds or strips the 'module.' prefix on the checkpoint keys so
        that they match the network. Returns the (possibly remapped)
        weights dict.
        """
        wei_key = None
        for name in weights:
            if name.find("encoder") != -1:
                wei_key = name
                break
            else:
                print(name)

        net_key = None
        for name in net.state_dict():
            if name.find("encoder") != -1:
                net_key = name
                break

        print(wei_key, net_key)

        wei_has_module = wei_key.find("module") != -1
        net_has_module = net_key.find("module") != -1

        if wei_has_module and not net_has_module:
            # Checkpoint was saved from a DataParallel model: strip prefix.
            weights = {
                name.replace("module.", ""): tensor
                for name, tensor in weights.items()
            }

        if not wei_has_module and net_has_module:
            # Current model wraps the encoder in DataParallel: insert the
            # prefix right after the first path component of encoder keys.
            remapped = {}
            for name, tensor in weights.items():
                if name.find("encoder") != -1:
                    head, _, tail = name.partition(".")
                    remapped[head + ".module." + tail] = tensor
                else:
                    remapped[name] = tensor
            weights = remapped

        return weights