class Get_Classify_Results():
    def __init__(self,
                 model_name,
                 fold,
                 model_path,
                 class_num=4,
                 tta_flag=False):
        ''' Load the trained classification model for the current fold.

        :param model_name: name of the current model
        :param fold: index of the current fold
        :param model_path: directory that stores all model checkpoints
        :param class_num: total number of classes
        :param tta_flag: whether to use test-time augmentation at inference
        '''
        self.model_name = model_name
        self.fold = fold
        self.model_path = model_path
        self.class_num = class_num
        self.tta_flag = tta_flag

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        # Load the model and its weights
        self.classify_model = ClassifyResNet(model_name, encoder_weights=None)
        if torch.cuda.is_available():
            self.classify_model = torch.nn.DataParallel(self.classify_model)

        self.classify_model.to(self.device)

        self.classify_model_path = os.path.join(
            self.model_path,
            '%s_classify_fold%d_best.pth' % (self.model_name, self.fold))
        self.solver = Solver(self.classify_model)
        self.classify_model = self.solver.load_checkpoint(
            self.classify_model_path)
        self.classify_model.eval()

    def get_classify_results(self, images, threshold=0.5):
        ''' Classify one batch of data for the current fold.

        :param images: a batch of images with shape [batch, channels, height, width]
        :param threshold: probability threshold of the classification model
        :return: predict_classes: binary predictions for the batch, shape [batch, class_num]
        '''
        if self.tta_flag:
            predict_classes = self.solver.tta(images, seg=False)
        else:
            predict_classes = self.solver.forward(images)
        predict_classes = predict_classes > threshold
        return predict_classes
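
# A minimal usage sketch (not part of the original repository): run the fold-1
# classification model on one dummy batch. The checkpoint directory, model name
# and input size are illustrative assumptions; Get_Classify_Results itself builds
# the checkpoint path '<model_path>/<model_name>_classify_fold<fold>_best.pth'.
def _classify_demo():
    get_results = Get_Classify_Results('unet_resnet34',
                                       fold=1,
                                       model_path='checkpoints/unet_resnet34',
                                       class_num=4,
                                       tta_flag=False)
    # Dummy batch with shape [batch, channels, height, width]
    images = torch.rand(2, 3, 256, 1600)
    # Boolean tensor of shape [batch, class_num]: True means the image is
    # predicted to contain a defect of that class
    predict_classes = get_results.get_classify_results(images, 0.5)
    print(predict_classes)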
class TrainVal():
    def __init__(self, config, fold):
        # Load the network model
        self.model_name = config.model_name
        self.model = ClassifyResNet(self.model_name, 4, training=True)
        if torch.cuda.is_available():
            self.model = torch.nn.DataParallel(self.model)
            self.model = self.model.cuda()

        # Load the hyper-parameters
        self.lr = config.lr
        self.weight_decay = config.weight_decay
        self.epoch = config.epoch
        self.fold = fold

        # Instantiate the Solver class that implements the various helper routines
        self.solver = Solver(self.model)

        # Classification loss function
        self.criterion = ClassifyLoss()

        # Create the directory where checkpoints will be saved
        self.model_path = os.path.join(config.save_path, config.model_name)
        if not os.path.exists(self.model_path):
            os.makedirs(self.model_path)

        # Save the config as JSON and initialise TensorBoard; the run name carries a timestamp and the fold index
        TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S}-{1}-classify".format(
            datetime.datetime.now(), fold)
        self.writer = SummaryWriter(
            log_dir=os.path.join(self.model_path, TIMESTAMP))
        with codecs.open(self.model_path + '/' + TIMESTAMP + '.json', 'w',
                         "utf-8") as json_file:
            json.dump({k: v
                       for k, v in config._get_kwargs()},
                      json_file,
                      ensure_ascii=False)

        self.max_accuracy_valid = 0
        # Set the random seed; keep it fixed so the train/validation split of each cross-validation fold is reproducible
        self.seed = int(time.time())
        # self.seed = 1570421136
        seed_torch(self.seed)
        with open(self.model_path + '/' + TIMESTAMP + '.pkl', 'wb') as f:
            pickle.dump({'seed': self.seed}, f, -1)

    def train(self, train_loader, valid_loader):

        optimizer = optim.Adam(self.model.module.parameters(),
                               self.lr,
                               weight_decay=self.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, self.epoch + 10)
        global_step = 0

        for epoch in range(1, self.epoch + 1):
            epoch_loss = 0
            self.model.train(True)

            tbar = tqdm.tqdm(train_loader)
            for i, (images, labels) in enumerate(tbar):

                labels_predict = self.solver.forward(images)
                loss = self.solver.cal_loss(labels, labels_predict,
                                            self.criterion)
                epoch_loss += loss.item()
                self.solver.backword(optimizer, loss)

                self.writer.add_scalar('train_loss', loss.item(),
                                       global_step + i)
                params_groups_lr = str()
                for group_ind, param_group in enumerate(
                        optimizer.param_groups):
                    params_groups_lr = params_groups_lr + 'params_group_%d' % (
                        group_ind) + ': %.12f, ' % (param_group['lr'])
                descript = "Fold: %d, Train Loss: %.7f, lr: %s" % (
                    self.fold, loss.item(), params_groups_lr)
                tbar.set_description(desc=descript)

            lr_scheduler.step()
            global_step += len(train_loader)

            print('Finish Epoch [%d/%d], Average Loss: %.7f' %
                  (epoch, self.epoch, epoch_loss / len(tbar)))

            class_neg_accuracy, class_pos_accuracy, class_accuracy, neg_accuracy, pos_accuracy, accuracy, loss_valid = \
                self.validation(valid_loader)

            if accuracy > self.max_accuracy_valid:
                is_best = True
                self.max_accuracy_valid = accuracy
            else:
                is_best = False

            state = {
                'epoch': epoch,
                'state_dict': self.model.module.state_dict(),
                'max_accuracy_valid': self.max_accuracy_valid,
            }

            self.solver.save_checkpoint(
                os.path.join(
                    self.model_path,
                    '%s_classify_fold%d.pth' % (self.model_name, self.fold)),
                state, is_best)
            self.writer.add_scalar('valid_loss', loss_valid, epoch)
            self.writer.add_scalar('valid_accuracy', accuracy, epoch)
            self.writer.add_scalar('valid_class_0_accuracy', class_accuracy[0],
                                   epoch)
            self.writer.add_scalar('valid_class_1_accuracy', class_accuracy[1],
                                   epoch)
            self.writer.add_scalar('valid_class_2_accuracy', class_accuracy[2],
                                   epoch)
            self.writer.add_scalar('valid_class_3_accuracy', class_accuracy[3],
                                   epoch)

    def validation(self, valid_loader):

        self.model.eval()
        meter = Meter()
        tbar = tqdm.tqdm(valid_loader)
        loss_sum = 0

        with torch.no_grad():
            for i, (images, labels) in enumerate(tbar):

                labels_predict = self.solver.forward(images)
                loss = self.solver.cal_loss(labels, labels_predict,
                                            self.criterion)
                loss_sum += loss.item()

                meter.update(labels, labels_predict.cpu())

                descript = "Val Loss: {:.7f}".format(loss.item())
                tbar.set_description(desc=descript)
        loss_mean = loss_sum / len(tbar)

        class_neg_accuracy, class_pos_accuracy, class_accuracy, \
            neg_accuracy, pos_accuracy, accuracy = meter.get_metrics()
        print(
            "Class_0_accuracy: %0.4f | Class_1_accuracy: %0.4f | Class_2_accuracy: %0.4f | Class_3_accuracy: %0.4f | "
            "Negative accuracy: %0.4f | positive accuracy: %0.4f | accuracy: %0.4f"
            % (class_accuracy[0], class_accuracy[1], class_accuracy[2],
               class_accuracy[3], neg_accuracy, pos_accuracy, accuracy))
        return class_neg_accuracy, class_pos_accuracy, class_accuracy, \
               neg_accuracy, pos_accuracy, accuracy, loss_mean
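
# A hedged sketch (not shown in the snippets above) of how TrainVal is typically
# driven: one instance per cross-validation fold, fed by classify_provider. Here
# 'config' is assumed to expose model_name, lr, weight_decay, epoch and
# save_path; the classify_provider arguments mirror the call in the __main__
# block below.
def _train_all_folds(config, data_folder, df_path, mean, std,
                     batch_size=20, num_workers=8, n_splits=5):
    dataloaders = classify_provider(data_folder, df_path, mean, std,
                                    batch_size, num_workers, n_splits)
    for fold_index, [train_loader, valid_loader] in enumerate(dataloaders):
        # Each fold gets its own model, optimizer, TensorBoard writer and seed
        train_val = TrainVal(config, fold_index)
        train_val.train(train_loader, valid_loader)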
if __name__ == "__main__":
    data_folder = "/home/apple/program/MXQ/Competition/Kaggle/Steal-Defect/Kaggle-Steel-Defect-Detection/datasets/Steel_data"
    df_path = "/home/apple/program/MXQ/Competition/Kaggle/Steal-Defect/Kaggle-Steel-Defect-Detection/datasets/Steel_data/train.csv"
    test_df = pd.read_csv('./datasets/Steel_data/sample_submission.csv')
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    test_dataset = TestDataset('./datasets/Steel_data/test_images', test_df,
                               mean, std)
    dataloader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=20,
                                             shuffle=True,
                                             num_workers=8,
                                             pin_memory=True)

    model = ClassifyResNet('unet_resnet34', 4, training=False)
    model = torch.nn.DataParallel(model)
    model = model.cuda()
    pth_path = "checkpoints/unet_resnet34/unet_resnet34_classify_fold1.pth"
    checkpoint = torch.load(pth_path)
    model.module.load_state_dict(checkpoint['state_dict'])

    class_test = ClassifyTest(model, [0.5, 0.5, 0.5, 0.5], True)
    # Predict over the whole test set in a single call
    # image_id, predict_label = class_test.predict(dataloader)
    # Predict mini-batch by mini-batch instead
    class_dataloader = classify_provider(data_folder, df_path, mean, std, 20,
                                         8, 5)
    for fold_index, [train_dataloader,
                     val_dataloader] in enumerate(class_dataloader):
        train_bar = tqdm(val_dataloader)