Example #1
0
class Model(BaseModel):
    """Single-classifier training wrapper.

    Owns the ``Classifier`` network, its Adam optimizer and a running average
    of the training loss, and performs one optimisation step per ``update``
    call. The learning rate can be decayed linearly with
    ``update_learning_rate``.
    """

    def __init__(self, opt):
        """Build the network and optimizer, optionally loading weights.

        Args:
            opt: option namespace; this class reads ``lr``, ``load``,
                ``which_epoch``, ``checkpoint_dir``, ``tag``, ``niter_decay``
                and ``verbose`` from it.
        """
        super(Model, self).__init__()
        self.opt = opt
        self.classifier = Classifier()  # .cuda(device=opt.device)
        # Optional custom init: self.classifier.apply(weights_init)

        print_network(self.classifier)

        self.optimizer = optim.Adam(self.classifier.parameters(),
                                    lr=opt.lr,
                                    betas=(0.95, 0.999))
        # Current learning rate, decayed by update_learning_rate(). It must
        # exist before the first decay step, otherwise that call would raise
        # AttributeError.
        self.old_lr = opt.lr

        # Optionally warm-start the classifier from a pretrained checkpoint.
        if opt.load:
            pretrained_path = opt.load
            self.load_network(self.classifier, 'G', opt.which_epoch,
                              pretrained_path)

        # Smoothed loss values reported during training.
        self.avg_meters = ExponentialMovingAverage(0.95)
        self.save_dir = os.path.join(opt.checkpoint_dir, opt.tag)

    def update(self, input, label):
        """Run one optimisation step on a batch.

        Args:
            input: batch fed to the classifier.
            label: target labels for ``input``, passed to ``criterionCE``.

        Returns:
            dict with the raw classifier output under ``'predicted'``.
        """
        predicted = self.classifier(input)
        # criterionCE is expected to be provided at module level.
        loss = criterionCE(predicted, label)

        self.avg_meters.update({'Cross Entropy': loss.item()})

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {'predicted': predicted}

    def forward(self, x):
        """Return the classifier output for ``x``."""
        return self.classifier(x)

    def save(self, which_epoch):
        """Save the classifier weights under the tag ``'G'`` for ``which_epoch``."""
        self.save_network(self.classifier, 'G', which_epoch)

    def update_learning_rate(self):
        """Decay the learning rate linearly by ``opt.lr / opt.niter_decay``.

        After ``niter_decay`` calls the rate reaches zero. It is clamped there
        so further calls (or float rounding) never yield a negative rate,
        which would make the optimizer ascend the loss.
        """
        lrd = self.opt.lr / self.opt.niter_decay
        lr = max(self.old_lr - lrd, 0.0)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
Example #2
0
class Model(BaseModel):
    """Classifier trained with a label-smoothed cross-entropy loss.

    Checkpoints bundle the classifier, optimizer and scheduler state together
    with the epoch number (see ``save`` / ``load``).
    """

    def __init__(self, opt):
        """Build the classifier, optimizer, scheduler and loss.

        Args:
            opt: option namespace; ``model`` selects the classifier variant,
                and ``checkpoint_dir`` / ``tag`` form the checkpoint folder.
                ``device`` and ``resume`` are read by ``load`` / ``save``.
        """
        super(Model, self).__init__()
        self.opt = opt
        self.classifier = Classifier(opt.model)  # .cuda(device=opt.device)
        # Optional custom init: self.classifier.apply(weights_init)

        print_network(self.classifier)

        self.optimizer = get_optimizer(opt, self.classifier)
        self.scheduler = get_scheduler(opt, self.optimizer)

        # Smoothed loss values reported during training.
        self.avg_meters = ExponentialMovingAverage(0.95)
        self.save_dir = os.path.join(opt.checkpoint_dir, opt.tag)

        # A class-weighted variant used to be loaded from
        # 'datasets/class_weight.pkl' when opt.class_weight was set:
        #     self.criterionCE = nn.CrossEntropyLoss(weight=class_weight)
        # NOTE: update() currently trains with label_smooth_loss instead of
        # this criterion.
        self.criterionCE = nn.CrossEntropyLoss()

    def update(self, input, label):
        """Run one optimisation step on a batch.

        Args:
            input: batch fed to the classifier.
            label: target labels for ``input``.

        Returns:
            dict with the raw classifier output under ``'predicted'``.
        """
        predicted = self.classifier(input)
        loss_ce = label_smooth_loss(predicted, label)
        loss = loss_ce

        self.avg_meters.update({'CE loss(label smooth)': loss_ce.item()})

        # Optional range loss (currently disabled):
        # if opt.weight_range:
        #     _, _, range_loss = criterionRange(predicted, label)
        #     range_loss = range_loss * opt.weight_range
        #     loss += range_loss
        #     self.avg_meters.update({'Range': range_loss.item()})

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {'predicted': predicted}

    def forward(self, x):
        """Return the classifier output for ``x``."""
        return self.classifier(x)

    def load(self, ckpt_path):
        """Restore a checkpoint written by ``save``.

        The classifier weights are always restored. When ``opt.resume`` is
        set, the optimizer and scheduler state are restored as well.

        Args:
            ckpt_path: path of the checkpoint file.

        Returns:
            The epoch number stored in the checkpoint.
        """
        # Use the options this model was built with rather than a global.
        load_dict = torch.load(ckpt_path, map_location=self.opt.device)
        self.classifier.load_state_dict(load_dict['classifier'])
        epoch = load_dict['epoch']
        if self.opt.resume:
            self.optimizer.load_state_dict(load_dict['optimizer'])
            self.scheduler.load_state_dict(load_dict['scheduler'])
            utils.color_print(
                'Load checkpoint from %s, resume training.' % ckpt_path, 3)
        else:
            utils.color_print('Load checkpoint from %s.' % ckpt_path, 3)

        return epoch

    def save(self, which_epoch):
        """Write classifier, optimizer, scheduler state and the epoch.

        The file is written to ``<save_dir>/<which_epoch>_<opt.model>.pt``.
        """
        save_filename = f'{which_epoch}_{self.opt.model}.pt'
        # torch.save does not create missing parent directories.
        os.makedirs(self.save_dir, exist_ok=True)
        save_path = os.path.join(self.save_dir, save_filename)
        save_dict = OrderedDict()
        save_dict['classifier'] = self.classifier.state_dict()
        save_dict['optimizer'] = self.optimizer.state_dict()
        save_dict['scheduler'] = self.scheduler.state_dict()
        save_dict['epoch'] = which_epoch
        torch.save(save_dict, save_path)
        utils.color_print(f'Save checkpoint "{save_path}".', 3)
Example #3
0
class Model(BaseModel):
    """Feature extractor + meta-embedding head with per-class centroid memory.

    ``direct_feature`` maps inputs to feature vectors. ``meta_embedding``
    turns those features plus the class-centroid memory ``self.mem`` into
    logits; it contains the final classifier. ``self.mem`` must be filled by
    ``centroids_cal`` or ``load`` before ``forward`` / ``update`` are used.
    """

    # Number of output classes; the centroid memory has one row per class.
    NUM_CLASSES = 50030

    def __init__(self, opt):
        """Build both sub-networks, the optimizer, scheduler and loss.

        Args:
            opt: option namespace; ``model`` selects the feature extractor,
                and ``checkpoint_dir`` / ``tag`` form the checkpoint folder.
                ``device`` and ``resume`` are read by the other methods.
        """
        super(Model, self).__init__()
        self.opt = opt
        self.direct_feature = DirectFeature(opt.model)
        self.feature_nums = self.direct_feature.get_feature_num()
        self.meta_embedding = MetaEmbedding(self.feature_nums, self.NUM_CLASSES)

        print_network(self.direct_feature)
        print_network(self.meta_embedding)

        # TODO: consider separate learning rates, e.g. 0.01 for direct_feature
        # and 0.1 for meta_embedding. Previous alternative:
        # optim.SGD(chain(...), lr=0.01, momentum=0.9, weight_decay=0.0005)
        self.optimizer = optim.Adam(chain(self.direct_feature.parameters(),
                                          self.meta_embedding.parameters()),
                                    lr=0.01)

        self.scheduler = get_scheduler(opt, self.optimizer)

        # Smoothed loss values reported during training.
        self.avg_meters = ExponentialMovingAverage(0.95)
        self.save_dir = os.path.join(opt.checkpoint_dir, opt.tag)

        self.criterionCE = nn.CrossEntropyLoss()

        # Class-centroid memory of shape (NUM_CLASSES, feature_nums); set by
        # centroids_cal() or load().
        self.mem = None

    @staticmethod
    def class_count(dataset):
        """Return the sample count of every distinct label in ``dataset.labels``.

        Counts are ordered by ascending label value. Only labels that actually
        occur are included.
        """
        _, counts = np.unique(np.array(dataset.labels), return_counts=True)
        return counts.tolist()

    def centroids_cal(self, dataloader):
        """Compute per-class mean features over ``dataloader`` into ``self.mem``.

        Call this before training or evaluating in meta-embedding mode, e.g.
        in the training script after loading weights and before
        ``model.train()``.

        The mean is taken over exactly the samples the loader yields. A class
        that never appears keeps an all-zero centroid instead of NaN. The
        module's previous train/eval mode is restored afterwards.

        Args:
            dataloader: iterable of dicts with ``'input'`` (a batch for
                ``direct_feature``) and ``'label'`` (integer class ids)
                entries; ``len(dataloader)`` is used for the progress bar.
        """
        device = self.opt.device
        centroids = torch.zeros(self.NUM_CLASSES, self.feature_nums, device=device)
        counts = torch.zeros(self.NUM_CLASSES, device=device)

        was_training = self.training
        self.eval()

        with torch.no_grad():
            for step, data in enumerate(dataloader):
                utils.progress_bar(step, len(dataloader), 'Calculating centroids...')
                inputs = data['input'].to(device)
                labels = torch.as_tensor(data['label'], device=device).long()
                features = self.direct_feature(inputs).to(centroids.dtype)
                # Scatter-add each sample's feature (and a count of 1) into
                # the row of its class.
                centroids.index_add_(0, labels, features)
                counts.index_add_(0, labels,
                                  torch.ones_like(labels, dtype=counts.dtype))

        # Turn the sums into means; clamping avoids 0/0 for absent classes.
        centroids /= counts.clamp(min=1).unsqueeze(1)
        self.mem = centroids

        self.train(was_training)

    def update(self, input, label):
        """Run one optimisation step on a batch.

        Args:
            input: batch fed to ``forward``.
            label: target class indices for ``input``.

        Returns:
            dict with the logits under ``'predicted'``.
        """
        predicted = self.forward(input)
        # TODO: add DiscCentroidsLoss to the objective.
        loss = self.criterionCE(predicted, label)

        self.avg_meters.update({'Cross Entropy': loss.item()})

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {'predicted': predicted}

    def forward(self, x):
        """Return class logits for ``x``.

        Raises:
            RuntimeError: if the centroid memory has not been computed yet.
        """
        if self.mem is None:
            raise RuntimeError('Class centroids are not initialised; '
                               'call centroids_cal() or load() first.')
        direct_feature = self.direct_feature(x)
        # MetaEmbedding contains the classifier; its second output is unused.
        logits, _ = self.meta_embedding(direct_feature, self.mem)
        return logits

    def load(self, ckpt_path, dataloader=None):
        """Restore weights (and optionally optimizer state) from a checkpoint.

        Two formats are accepted:

        * old (stage-1) checkpoints holding a single ``'classifier'`` state
          dict. Its ``network.fc`` layer is moved into
          ``meta_embedding.fc_hallucinator``, the rest goes into
          ``direct_feature``. Centroids are then recomputed over
          ``dataloader``.
        * new checkpoints written by ``save``.

        Args:
            ckpt_path: path of the checkpoint file.
            dataloader: loader used to compute centroids for old checkpoints.
                Defaults to the module-level ``train_dataloader_plain``, which
                keeps the original behaviour.

        Returns:
            The epoch number stored in the checkpoint.
        """
        load_dict = torch.load(ckpt_path, map_location=self.opt.device)
        if 'direct_feature' not in load_dict:  # old checkpoint format
            direct_feature = load_dict['classifier']
            classifier_dict = OrderedDict()
            classifier_dict['weight'] = direct_feature.pop('network.fc.weight')
            classifier_dict['bias'] = direct_feature.pop('network.fc.bias')
            self.direct_feature.load_state_dict(direct_feature)
            self.meta_embedding.fc_hallucinator.load_state_dict(classifier_dict)
            # Stage-1 checkpoints carry no centroids, so compute them now.
            if dataloader is None:
                dataloader = train_dataloader_plain
            self.centroids_cal(dataloader)
        else:  # new checkpoint format
            self.direct_feature.load_state_dict(load_dict['direct_feature'])
            self.meta_embedding.load_state_dict(load_dict['meta_embedding'])
            self.mem = load_dict['centroids']

        epoch = load_dict['epoch']
        if self.opt.resume:
            # NOTE(review): for an old checkpoint the stored optimizer state
            # belongs to a different parameter set; resuming from one is
            # probably unintended.
            self.optimizer.load_state_dict(load_dict['optimizer'])
            self.scheduler.load_state_dict(load_dict['scheduler'])
            utils.color_print('Load checkpoint from %s, resume training.' % ckpt_path, 3)
        else:
            utils.color_print('Load checkpoint from %s.' % ckpt_path, 3)

        return epoch

    def save(self, which_epoch):
        """Write both networks, the centroids, optimizer/scheduler state and the epoch.

        The file is written to ``<save_dir>/<which_epoch>_<opt.model>.pt``.
        """
        save_filename = f'{which_epoch}_{self.opt.model}.pt'
        # torch.save does not create missing parent directories.
        os.makedirs(self.save_dir, exist_ok=True)
        save_path = os.path.join(self.save_dir, save_filename)
        save_dict = OrderedDict()
        save_dict['direct_feature'] = self.direct_feature.state_dict()
        save_dict['meta_embedding'] = self.meta_embedding.state_dict()
        save_dict['centroids'] = self.mem

        save_dict['optimizer'] = self.optimizer.state_dict()
        save_dict['scheduler'] = self.scheduler.state_dict()
        save_dict['epoch'] = which_epoch
        torch.save(save_dict, save_path)
        utils.color_print(f'Save checkpoint "{save_path}".', 3)
Example #4
0
class Model(BaseModel):
    """Single-classifier training wrapper with scheduler and full checkpoints.

    Checkpoints bundle the classifier, optimizer and scheduler state together
    with the epoch number (see ``save`` / ``load``).
    """

    def __init__(self, opt):
        """Build the classifier, Adam optimizer and LR scheduler.

        Args:
            opt: option namespace; ``lr`` sets the initial learning rate, and
                ``checkpoint_dir`` / ``tag`` form the checkpoint folder.
                ``device``, ``resume`` and ``model`` are read by
                ``load`` / ``save``.
        """
        super(Model, self).__init__()
        self.opt = opt
        self.classifier = Classifier()
        # Optional custom init: self.classifier.apply(weights_init)

        print_network(self.classifier)

        self.optimizer = optim.Adam(self.classifier.parameters(),
                                    lr=opt.lr,
                                    betas=(0.95, 0.999))
        self.scheduler = get_scheduler(opt, self.optimizer)

        # Smoothed loss values reported during training.
        self.avg_meters = ExponentialMovingAverage(0.95)
        self.save_dir = os.path.join(opt.checkpoint_dir, opt.tag)

    def update(self, input, label):
        """Run one optimisation step on a batch.

        Args:
            input: batch fed to the classifier.
            label: target labels for ``input``, passed to ``criterionCE``.

        Returns:
            dict with the raw classifier output under ``'predicted'``.
        """
        predicted = self.classifier(input)
        # criterionCE is expected to be provided at module level.
        loss = criterionCE(predicted, label)

        self.avg_meters.update({'Cross Entropy': loss.item()})

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {'predicted': predicted}

    def forward(self, x):
        """Return the classifier output for ``x``."""
        return self.classifier(x)

    def load(self, ckpt_path):
        """Restore a checkpoint written by ``save``.

        The classifier weights are always restored. When ``opt.resume`` is
        set, the optimizer and scheduler state are restored as well.

        Args:
            ckpt_path: path of the checkpoint file.

        Returns:
            The epoch number stored in the checkpoint.
        """
        # Use the options this model was built with rather than a global.
        load_dict = torch.load(ckpt_path, map_location=self.opt.device)
        self.classifier.load_state_dict(load_dict['classifier'])
        epoch = load_dict['epoch']
        if self.opt.resume:
            self.optimizer.load_state_dict(load_dict['optimizer'])
            self.scheduler.load_state_dict(load_dict['scheduler'])
            utils.color_print(
                'Load checkpoint from %s, resume training.' % ckpt_path, 3)
        else:
            utils.color_print('Load checkpoint from %s.' % ckpt_path, 3)

        return epoch

    def save(self, which_epoch):
        """Write classifier, optimizer, scheduler state and the epoch.

        The file is written to ``<save_dir>/<which_epoch>_<opt.model>.pt``.
        """
        save_filename = f'{which_epoch}_{self.opt.model}.pt'
        # torch.save does not create missing parent directories.
        os.makedirs(self.save_dir, exist_ok=True)
        save_path = os.path.join(self.save_dir, save_filename)
        save_dict = OrderedDict()
        save_dict['classifier'] = self.classifier.state_dict()
        save_dict['optimizer'] = self.optimizer.state_dict()
        save_dict['scheduler'] = self.scheduler.state_dict()
        save_dict['epoch'] = which_epoch
        torch.save(save_dict, save_path)
        utils.color_print(f'Save checkpoint "{save_path}".', 3)