Example #1
class Predict():
    def __init__(self, config):
        self.layer = config.layer
        self.method = config.method
        self.model = Classify_model(self.layer, self.method, training=True)
        # if torch.cuda.is_available():
        #     self.model = torch.nn.DataParallel(self.model)
        #     self.model = self.model.cuda()

        self.lr = config.lr
        self.weight_decay = config.weight_decay
        self.epoch = config.epoch
        self.max_accuracy_valid = 0
        self.solver = Solver(self.model, self.method)
        self.criterion = torch.nn.MSELoss()
        self.weight_path = config.weight_path

    def validation(self, test_loader):
        self.model.eval()
        checkpoint = torch.load(self.weight_path, map_location=torch.device('cpu'))
        self.model.load_state_dict(checkpoint['state_dict'])
        meter = Meter()
        tbar = tqdm.tqdm(test_loader, ncols=80)
        loss_sum = 0

        with torch.no_grad():
            for i, (x, labels) in enumerate(tbar):
                labels_predict = self.solver.forward(x)
                labels_predict = torch.sigmoid(labels_predict)
                loss = self.solver.cal_loss(labels, labels_predict, self.criterion)
                loss_sum += loss.item()

                meter.update(labels, labels_predict.cpu())

                descript = "Val Loss: {:.7f}".format(loss.item())
                tbar.set_description(desc=descript)
        loss_mean = loss_sum / len(tbar)

        class_neg_accuracy, class_pos_accuracy, class_accuracy, neg_accuracy, pos_accuracy, accuracy = meter.get_metrics()
        print(
            "Class_0_accuracy: %0.4f | Class_1_accuracy: %0.4f | Negative accuracy: %0.4f | Positive accuracy: %0.4f | Accuracy: %0.4f" %
            (class_accuracy[0], class_accuracy[1], neg_accuracy, pos_accuracy, accuracy))
        return class_neg_accuracy, class_pos_accuracy, class_accuracy, neg_accuracy, pos_accuracy, accuracy, loss_mean

    def predict(self, dataset):
        self.model.eval()
        checkpoint = torch.load(self.weight_path, map_location=torch.device('cpu'))
        self.model.load_state_dict(checkpoint['state_dict'])

        with torch.no_grad():
            labels_predict = self.solver.forward(dataset)
            labels_predict = torch.sigmoid(labels_predict)

        return labels_predict
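
A quick usage sketch for Predict. The config below is a stand-in built with argparse.Namespace; every attribute value is an illustrative placeholder, not the project's real settings:

import argparse

# Hypothetical config; the layer/method/weight_path values are assumptions.
config = argparse.Namespace(
    layer=4,
    method='fc',
    lr=1e-3,
    weight_decay=0.0,
    epoch=30,
    weight_path='checkpoints/classify_best.pth',
)
predictor = Predict(config)
# validation() expects a DataLoader yielding (x, labels) batches:
# metrics = predictor.validation(test_loader)
# predict() runs one forward pass under torch.no_grad():
# probabilities = predictor.predict(batch)
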
class TrainVal():
    def __init__(self, config, fold):
        '''
        Args:
            config: configuration parameters
            fold: index of the current fold
        '''
        # Load the network model
        self.model_name = config.model_name
        self.model = Model(self.model_name).create_model()

        # Load hyperparameters
        self.lr = config.lr
        self.weight_decay = config.weight_decay
        self.epoch = config.epoch
        self.fold = fold

        # Create the directory for saving weights
        self.model_path = os.path.join(config.save_path, config.model_name)
        if not os.path.exists(self.model_path):
            os.makedirs(self.model_path)
        
        if config.resume:
            weight_path = os.path.join(self.model_path, config.resume)
            self.load_weight(weight_path)

        # Instantiate the Solver class that implements the various helper routines
        self.solver = Solver(self.model)

        # Load the loss function
        self.criterion = torch.nn.BCEWithLogitsLoss()

        # Save the config as JSON and initialize TensorBoard
        TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S}-fold{1}".format(datetime.datetime.now(), fold)
        self.writer = SummaryWriter(log_dir=os.path.join(self.model_path, TIMESTAMP))
        with codecs.open(self.model_path + '/'+ TIMESTAMP + '.json', 'w', "utf-8") as json_file:
            json.dump({k: v for k, v in config._get_kwargs()}, json_file, ensure_ascii=False)

        self.max_dice_valid = 0

        # Set the random seed; keep it fixed so the cross-validation train/valid split stays reproducible
        self.seed = 1570421136
        seed_torch(self.seed)
        with open(self.model_path + '/'+ TIMESTAMP + '.pkl','wb') as f:
            pickle.dump({'seed': self.seed}, f, -1)

    def train(self, train_loader, valid_loader):
        ''' Train the model, saving checkpoints and logs.
        Args:
            train_loader: DataLoader for the training data
            valid_loader: DataLoader for the validation data
        '''
        optimizer = optim.Adam(self.model.module.parameters(), self.lr, weight_decay=self.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, self.epoch+10)
        global_step = 0

        for epoch in range(self.epoch):
            epoch += 1
            epoch_loss = 0
            self.model.train(True)

            tbar = tqdm.tqdm(train_loader)
            for i, samples in enumerate(tbar):
                # Skip empty samples
                if len(samples) == 0:
                    continue
                images, masks = samples[0], samples[1]
                # Forward and backward pass; the loss function applies the sigmoid internally
                masks_predict = self.solver.forward(images)
                loss = self.solver.cal_loss(masks, masks_predict, self.criterion)
                epoch_loss += loss.item()
                self.solver.backword(optimizer, loss)

                # Log to TensorBoard, one entry per step
                self.writer.add_scalar('train_loss', loss.item(), global_step+i)
                params_groups_lr = str()
                for group_ind, param_group in enumerate(optimizer.param_groups):
                    params_groups_lr = params_groups_lr + 'params_group_%d' % (group_ind) + ': %.12f, ' % (param_group['lr'])
                descript = "Fold: %d, Train Loss: %.7f, lr: %s" % (self.fold, loss.item(), params_groups_lr)
                tbar.set_description(desc=descript)

            # Decay the learning rate after each epoch
            lr_scheduler.step()
            global_step += len(train_loader)

            # Print the log info
            print('Finish Epoch [%d/%d], Average Loss: %.7f' % (epoch, self.epoch, epoch_loss/len(tbar)))

            # Validate the model
            loss_valid, dice_valid, iou_valid = self.validation(valid_loader)
            if dice_valid > self.max_dice_valid:
                is_best = True
                self.max_dice_valid = dice_valid
            else:
                is_best = False
            
            state = {
                'epoch': epoch,
                'state_dict': self.model.module.state_dict(),
                'max_dice_valid': self.max_dice_valid,
            }

            self.solver.save_checkpoint(os.path.join(self.model_path, '%s_fold%d.pth' % (self.model_name, self.fold)), state, is_best)
            self.writer.add_scalar('valid_loss', loss_valid, epoch)
            self.writer.add_scalar('valid_dice', dice_valid, epoch)

    def validation(self, valid_loader):
        ''' Run the validation pass.

        Args:
            valid_loader: DataLoader for the validation data
        '''
        self.model.eval()
        meter = Meter()
        tbar = tqdm.tqdm(valid_loader)
        loss_sum = 0
        
        with torch.no_grad(): 
            for i, samples in enumerate(tbar):
                if len(samples) == 0:
                    continue
                images, masks = samples[0], samples[1]
                # Forward pass
                masks_predict = self.solver.forward(images)
                loss = self.solver.cal_loss(masks, masks_predict, self.criterion)
                loss_sum += loss.item()
                
                # Note: the loss function applies sigmoid internally, and meter.update also applies sigmoid
                # masks_predict_binary = torch.sigmoid(masks_predict) > 0.5
                meter.update(masks, masks_predict.detach().cpu())

                descript = "Val Loss: {:.7f}".format(loss.item())
                tbar.set_description(desc=descript)
        loss_mean = loss_sum/len(tbar)

        dices, iou = meter.get_metrics()
        dice, dice_neg, dice_pos = dices
        print("IoU: %0.4f | dice: %0.4f | dice_neg: %0.4f | dice_pos: %0.4f" % (iou, dice, dice_neg, dice_pos))
        return loss_mean, dice, iou
    
    def load_weight(self, weight_path):
        """加载权重
        """
        pretrained_state_dict = torch.load(weight_path)['state_dict']
        model_state_dict = self.model.module.state_dict()
        pretrained_state_dict = {k : v for k, v in pretrained_state_dict.items() if k in model_state_dict}
        model_state_dict.update(pretrained_state_dict)
        print('Loading weight from %s' % weight_path)
        self.model.module.load_state_dict(model_state_dict)
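
seed_torch is called throughout these examples but never defined in them. A minimal sketch of what such a helper typically does, offered as an assumption rather than the original implementation:

import os
import random

import numpy as np
import torch


def seed_torch(seed):
    # Fix the common sources of randomness so the cross-validation
    # train/valid split stays reproducible (assumed behavior).
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
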
Example #3
class TrainVal:
    def __init__(self, config, fold):
        """
        Args:
            config: 配置参数
            fold: 当前为第几折
        """
        self.config = config
        self.fold = fold
        self.epoch = config.epoch
        self.num_classes = config.num_classes
        self.lr_scheduler = config.lr_scheduler
        print('USE LOSS: {}'.format(config.loss_name))

        # Load the model
        prepare_model = PrepareModel()
        self.model = prepare_model.create_local_attention_model(
            model_type=config.model_type,
            classes_num=self.num_classes,
            last_stride=2,
            droprate=0)

        # Find the most recently produced weight file
        weight_path = os.path.join('checkpoints', config.model_type)
        lists = os.listdir(weight_path)  # all files in the directory
        lists.sort(
            key=lambda fn: os.path.getmtime(weight_path + '/' + fn))  # sort by modification time
        weight_path = os.path.join(weight_path, lists[-1], 'model_best.pth')

        # Load previously trained weights
        pretrained_dict = torch.load(weight_path)['state_dict']
        model_dict = self.model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }  # filter out unnecessary keys
        model_dict.update(pretrained_dict)
        self.model.load_state_dict(model_dict)
        print('Successfully Loaded from %s' % weight_path)

        if torch.cuda.is_available():
            self.model = torch.nn.DataParallel(self.model)
            self.model = self.model.cuda()

        # Load the optimizer
        self.optimizer = prepare_model.create_optimizer(
            config.model_type, self.model, config)

        # Load the learning-rate decay schedule
        self.exp_lr_scheduler = prepare_model.create_lr_scheduler(
            self.lr_scheduler,
            self.optimizer,
            step_size=config.lr_step_size,
            restart_step=config.restart_step,
            multi_step=config.multi_step)

        # Load the loss function
        self.criterion = Loss(config.model_type, config.loss_name,
                              self.num_classes)

        # Instantiate the Solver class that implements the various helper routines
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.solver = Solver(self.model, self.device)

        # Initialize logging
        self.writer, self.time_stamp = self.init_log()
        self.model_path = os.path.join(self.config.save_path,
                                       self.config.model_type, self.time_stamp)

        # Initialize the classification metric class
        with open("online-service/model/label_id_name.json",
                  'r',
                  encoding='utf-8') as json_file:
            self.class_names = list(json.load(json_file).values())
        self.classification_metric = ClassificationMetric(
            self.class_names, self.model_path)

        self.max_accuracy_valid = 0

    def train(self, train_loader, valid_loader):
        """ 完成模型的训练,保存模型与日志
        Args:
            train_loader: 训练数据的DataLoader
            valid_loader: 验证数据的Dataloader
        """
        global_step = 0
        for epoch in range(self.epoch):
            self.model.train()
            epoch += 1
            images_number, epoch_corrects = 0, 0

            tbar = tqdm.tqdm(train_loader)
            for i, (images, labels) in enumerate(tbar):
                # Forward and backward pass
                labels_predict = self.solver.forward(images)
                loss = self.solver.cal_loss(labels_predict, labels,
                                            self.criterion)
                self.solver.backword(self.optimizer, loss)

                images_number += images.size(0)
                epoch_corrects += self.model.module.get_classify_result(
                    labels_predict, labels, self.device).sum()
                train_acc_iteration = self.model.module.get_classify_result(
                    labels_predict, labels, self.device).mean()

                # Log to TensorBoard, one entry per step
                descript = self.criterion.record_loss_iteration(
                    self.writer.add_scalar, global_step + i)
                self.writer.add_scalar('TrainAccIteration',
                                       train_acc_iteration, global_step + i)

                params_groups_lr = str()
                for group_ind, param_group in enumerate(
                        self.optimizer.param_groups):
                    params_groups_lr = params_groups_lr + 'params_group_%d' % group_ind + ': %.12f, ' % param_group[
                        'lr']

                descript = '[Train Fold {}][epoch: {}/{}][Lr :{}][Acc: {:.4f}]'.format(
                    self.fold, epoch, self.epoch, params_groups_lr,
                    train_acc_iteration) + descript

                tbar.set_description(desc=descript)

            # Write to TensorBoard
            epoch_acc = epoch_corrects / images_number
            self.writer.add_scalar('TrainAccEpoch', epoch_acc, epoch)
            self.writer.add_scalar('Lr', self.optimizer.param_groups[0]['lr'],
                                   epoch)
            descript = self.criterion.record_loss_epoch(
                len(train_loader), self.writer.add_scalar, epoch)

            # Print the log info
            print('[Finish epoch: {}/{}][Average Acc: {:.4}]'.format(
                epoch, self.epoch, epoch_acc) + descript)

            # Validate the model
            val_accuracy, val_loss, is_best = self.validation(valid_loader)

            # Save the parameters
            state = {
                'epoch': epoch,
                'state_dict': self.model.module.state_dict(),
                'max_score': self.max_accuracy_valid
            }
            self.solver.save_checkpoint(
                os.path.join(
                    self.model_path,
                    '%s_fold%d.pth' % (self.config.model_type, self.fold)),
                state, is_best)

            # Write to TensorBoard
            self.writer.add_scalar('ValidLoss', val_loss, epoch)
            self.writer.add_scalar('ValidAccuracy', val_accuracy, epoch)

            # Decay the learning rate after each epoch
            if self.lr_scheduler == 'ReduceLR':
                self.exp_lr_scheduler.step(val_loss)
            else:
                self.exp_lr_scheduler.step()
            global_step += len(train_loader)

    def validation(self, valid_loader):
        tbar = tqdm.tqdm(valid_loader)
        self.model.eval()
        labels_predict_all, labels_all = np.empty(shape=(0, )), np.empty(
            shape=(0, ))
        epoch_loss = 0
        with torch.no_grad():
            for i, (_, images, labels) in enumerate(tbar):
                # Forward pass
                labels_predict = self.solver.forward(images)
                loss = self.solver.cal_loss(labels_predict, labels,
                                            self.criterion)

                epoch_loss += loss.item()

                # Apply softmax, then argmax
                labels_predict = F.softmax(labels_predict, dim=1)
                labels_predict = torch.argmax(labels_predict,
                                              dim=1).detach().cpu().numpy()

                labels_predict_all = np.concatenate(
                    (labels_predict_all, labels_predict))
                labels_all = np.concatenate((labels_all, labels))

                descript = '[Valid][Loss: {:.4f}]'.format(loss.item())
                tbar.set_description(desc=descript)

            classify_report, my_confusion_matrix, acc_for_each_class, oa, average_accuracy, kappa = \
                self.classification_metric.get_metric(
                    labels_all,
                    labels_predict_all
                )

            if oa > self.max_accuracy_valid:
                is_best = True
                self.max_accuracy_valid = oa
                self.classification_metric.draw_cm_and_save_result(
                    classify_report, my_confusion_matrix, acc_for_each_class,
                    oa, average_accuracy, kappa)
            else:
                is_best = False

            print('OA:{}, AA:{}, Kappa:{}'.format(oa, average_accuracy, kappa))

            return oa, epoch_loss / len(tbar), is_best

    def init_log(self):
        # Save the config and initialize TensorBoard
        TIMESTAMP = "log-{0:%Y-%m-%dT%H-%M-%S}-localAtt".format(
            datetime.datetime.now())
        log_dir = os.path.join(self.config.save_path, self.config.model_type,
                               TIMESTAMP)
        writer = SummaryWriter(log_dir=log_dir)
        with codecs.open(os.path.join(log_dir, 'config.json'), 'w',
                         "utf-8") as json_file:
            json.dump({k: v
                       for k, v in config._get_kwargs()},
                      json_file,
                      ensure_ascii=False)

        seed = int(time.time())
        seed_torch(seed)
        with open(os.path.join(log_dir, 'seed.pkl'), 'wb') as f:
            pickle.dump({'seed': seed}, f, -1)

        return writer, TIMESTAMP
class TrainBaseline(object):
    def __init__(self, config, num_classes, train_triplet=False):
        """

        :param config: 配置参数
        :param num_classes: 训练集的类别数;类型为int
        :param train_triplet: 是否只训练triplet损失;类型为bool
        """
        self.num_classes = num_classes

        self.model_name = config.model_name
        self.last_stride = config.last_stride
        self.num_gpus = torch.cuda.device_count()
        print('Using {} GPUS'.format(self.num_gpus))
        print('NUM_CLASS: {}'.format(self.num_classes))
        print('USE LOSS: {}'.format(config.selected_loss))

        # Load the model; wrap it in DataParallel whenever a GPU is available, and convert to sync BN when there are multiple GPUs
        self.model = get_model(self.model_name, self.num_classes, self.last_stride)
        if torch.cuda.is_available():
            self.model = torch.nn.DataParallel(self.model)
            if self.num_gpus > 1:
                self.model = convert_model(self.model)
            self.model = self.model.cuda()

        # Load hyperparameters
        self.epoch = config.epoch

        # Instantiate the Solver class that implements the various helper routines
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.solver = Solver(self.model, self.device)

        # Load the loss function
        self.criterion = Loss(self.model_name, config.selected_loss, config.margin, self.num_classes)

        # Load the optimizer
        self.optim = get_optimizer(config, self.model)

        # Load the learning-rate decay schedule
        self.scheduler = get_scheduler(config, self.optim)

        # Create the directory for saving weights
        self.model_path = os.path.join(config.save_path, config.model_name)
        if not os.path.exists(self.model_path):
            os.makedirs(self.model_path)

        # If training only the triplet loss, resume from the existing checkpoint
        if train_triplet:
            self.solver.load_checkpoint(os.path.join(self.model_path, '{}.pth'.format(self.model_name)))

        # Save the config as JSON and initialize TensorBoard
        TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.datetime.now())
        self.writer = SummaryWriter(log_dir=os.path.join(self.model_path, TIMESTAMP))
        with codecs.open(self.model_path + '/' + TIMESTAMP + '.json', 'w', "utf-8") as json_file:
            json.dump({k: v for k, v in config._get_kwargs()}, json_file, ensure_ascii=False)

        # Set the random seed; keep it fixed so the cross-validation train/valid split stays reproducible
        self.seed = int(time.time())
        seed_torch(self.seed)
        with open(self.model_path + '/' + TIMESTAMP + '.pkl', 'wb') as f:
            pickle.dump({'seed': self.seed}, f, -1)

    def train(self, train_loader):
        """ 完成模型的训练,保存模型与日志

        :param train_loader: 训练集的Dataloader
        :return: None
        """

        global_step = 0

        for epoch in range(self.epoch):
            epoch += 1
            self.model.train()
            images_number, epoch_corrects, index = 0, 0, 0

            tbar = tqdm.tqdm(train_loader)
            for index, (images, labels) in enumerate(tbar):
                # Forward and backward pass
                outputs = self.solver.forward((images, labels))
                loss = self.solver.cal_loss(outputs, labels, self.criterion)
                self.solver.backword(self.optim, loss)

                images_number += images.size(0)
                epoch_corrects += self.model.module.get_classify_result(outputs, labels, self.device).sum()
                train_acc_iteration = self.model.module.get_classify_result(outputs, labels, self.device).mean() * 100

                # Log to TensorBoard, one entry per step
                global_step += 1
                descript = self.criterion.record_loss_iteration(self.writer.add_scalar, global_step)
                self.writer.add_scalar('TrainAccIteration', train_acc_iteration, global_step)

                descript = '[Train][epoch: {}/{}][Lr :{:.7f}][Acc: {:.2f}]'.format(epoch, self.epoch,
                                                                               self.scheduler.get_lr()[1],
                                                                               train_acc_iteration) + descript
                tbar.set_description(desc=descript)

            # Decay the learning rate after each epoch
            self.scheduler.step()

            # Write to TensorBoard
            epoch_acc = epoch_corrects / images_number * 100
            self.writer.add_scalar('TrainAccEpoch', epoch_acc, epoch)
            self.writer.add_scalar('Lr', self.scheduler.get_lr()[1], epoch)
            descript = self.criterion.record_loss_epoch(index, self.writer.add_scalar, epoch)

            # Print the log info
            print('[Finish epoch: {}/{}][Average Acc: {:.2}]'.format(epoch, self.epoch, epoch_acc) + descript)

            state = {
                'epoch': epoch,
                'state_dict': self.model.module.state_dict(),
            }

            self.solver.save_checkpoint(
                os.path.join(self.model_path, '{}.pth'.format(self.model_name)), state, False)
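
Every example delegates the forward pass, loss computation, backward step, and checkpointing to a Solver class that is never shown. Below is a minimal sketch of the interface the call sites imply; the bodies are assumptions, and note that the constructor's second argument and the cal_loss argument order both vary between examples:

import os
import shutil

import torch


class Solver:
    # Sketch of the Solver interface inferred from the call sites above;
    # not the original implementation.

    def __init__(self, model, device=None):
        self.model = model
        self.device = device or torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

    def forward(self, images):
        # Move the batch to the target device and run the model.
        return self.model(images.to(self.device))

    def cal_loss(self, predicts, labels, criterion):
        return criterion(predicts, labels.to(self.device))

    def backword(self, optimizer, loss):
        # The examples really do spell this "backword"; the name is kept.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    def save_checkpoint(self, save_path, state, is_best):
        torch.save(state, save_path)
        if is_best:
            shutil.copyfile(
                save_path,
                os.path.join(os.path.dirname(save_path), 'model_best.pth'))

    def load_checkpoint(self, load_path):
        checkpoint = torch.load(load_path, map_location=self.device)
        self.model.module.load_state_dict(checkpoint['state_dict'])
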
Example #5
class TrainVal:
    def __init__(self, config, fold):
        """
        Args:
            config: 配置参数
            fold: 当前为第几折
        """
        self.config = config
        self.fold = fold
        self.epoch = config.epoch
        self.num_classes = config.num_classes
        self.lr_scheduler = config.lr_scheduler
        self.save_interval = 10
        self.cut_mix = config.cut_mix
        self.beta = config.beta
        self.cutmix_prob = config.cutmix_prob
        self.auto_aug = config.auto_aug

        # Multi-scale settings
        self.image_size = config.image_size
        self.multi_scale = config.multi_scale
        self.val_multi_scale = config.val_multi_scale
        self.multi_scale_size = config.multi_scale_size
        self.multi_scale_interval = config.multi_scale_interval
        # Sparsity training
        self.sparsity = config.sparsity
        self.sparsity_scale = config.sparsity_scale
        self.penalty_type = config.penalty_type
        self.selected_labels = config.selected_labels
        if self.auto_aug:
            print('@ Using AutoAugment.')
        if self.cut_mix:
            print('@ Using cut mix.')
        if self.multi_scale:
            print('@ Using multi scale training.')
        print('@ Using LOSS: {}'.format(config.loss_name))

        # Load the model
        prepare_model = PrepareModel()
        self.model = prepare_model.create_model(model_type=config.model_type,
                                                classes_num=self.num_classes,
                                                drop_rate=config.drop_rate,
                                                pretrained=True,
                                                bn_to_gn=config.bn_to_gn)
        if config.weight_path:
            self.model = prepare_model.load_chekpoint(self.model,
                                                      config.weight_path)

        # Sparsity training
        self.sparsity_train = None
        if config.sparsity:
            print('@ Using sparsity training.')
            self.sparsity_train = Sparsity(self.model,
                                           sparsity_scale=self.sparsity_scale,
                                           penalty_type=self.penalty_type)

        # L1 regularization
        self.l1_regular = config.l1_regular
        self.l1_decay = config.l1_decay
        if self.l1_regular:
            print('@ Using l1_regular')
            self.l1_reg_loss = Regularization(self.model,
                                              weight_decay=self.l1_decay,
                                              p=1)

        if torch.cuda.is_available():
            self.model = torch.nn.DataParallel(self.model)
            self.model = self.model.cuda()

        # Load the optimizer
        self.optimizer = prepare_model.create_optimizer(
            config.model_type, self.model, config)

        # Load the learning-rate decay schedule
        self.exp_lr_scheduler = prepare_model.create_lr_scheduler(
            self.lr_scheduler,
            self.optimizer,
            step_size=config.lr_step_size,
            restart_step=config.restart_step,
            multi_step=config.multi_step,
            warmup=config.warmup,
            multiplier=config.multiplier,
            warmup_epoch=config.warmup_epoch,
            delay_epoch=config.delay_epoch)

        # Load the loss function
        self.criterion = Loss(config.model_type, config.loss_name,
                              self.num_classes)

        # Instantiate the Solver class that implements the various helper routines
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.solver = Solver(self.model, self.device)

        # Initialize logging
        self.writer, self.time_stamp = self.init_log()
        self.model_path = os.path.join(self.config.save_path,
                                       self.config.model_type, self.time_stamp)

        # Initialize the classification metric class
        with open("online-service/model/label_id_name.json",
                  'r',
                  encoding='utf-8') as json_file:
            self.class_names = list(json.load(json_file).values())
        self.classification_metric = ClassificationMetric(
            self.class_names, self.model_path)

        self.max_accuracy_valid = 0

    def train(self, train_loader, valid_loader):
        """ 完成模型的训练,保存模型与日志
        Args:
            train_loader: 训练数据的DataLoader
            valid_loader: 验证数据的Dataloader
        """
        global_step = 0
        for epoch in range(self.epoch):
            self.model.train()
            epoch += 1
            images_number, epoch_corrects = 0, 0

            tbar = tqdm.tqdm(train_loader)
            image_size = self.image_size
            l1_regular_loss = 0
            loss_with_l1_regular = 0
            for i, (images, labels) in enumerate(tbar):
                if self.multi_scale:
                    if i % self.multi_scale_interval == 0:
                        image_size = random.choice(self.multi_scale_size)
                    images = multi_scale_transforms(image_size,
                                                    images,
                                                    auto_aug=self.auto_aug)
                if self.cut_mix:
                    # Apply CutMix
                    r = np.random.rand(1)
                    if self.beta > 0 and r < self.cutmix_prob:
                        images, labels_a, labels_b, lam = generate_mixed_sample(
                            self.beta, images, labels)
                        labels_predict = self.solver.forward(images)
                        loss = self.solver.cal_loss_cutmix(
                            labels_predict, labels_a, labels_b, lam,
                            self.criterion)
                    else:
                        # Forward pass
                        labels_predict = self.solver.forward(images)
                        loss = self.solver.cal_loss(labels_predict, labels,
                                                    self.criterion)
                else:
                    # Forward pass
                    labels_predict = self.solver.forward(images)
                    loss = self.solver.cal_loss(labels_predict, labels,
                                                self.criterion)

                if self.l1_regular:
                    current_l1_regular_loss = self.l1_reg_loss(self.model)
                    loss += current_l1_regular_loss
                    l1_regular_loss += current_l1_regular_loss.item()
                    loss_with_l1_regular += loss.item()
                self.solver.backword(self.optimizer,
                                     loss,
                                     sparsity=self.sparsity_train)

                images_number += images.size(0)
                epoch_corrects += self.model.module.get_classify_result(
                    labels_predict, labels, self.device).sum()
                train_acc_iteration = self.model.module.get_classify_result(
                    labels_predict, labels, self.device).mean()

                # Log to TensorBoard, one entry per step
                descript = self.criterion.record_loss_iteration(
                    self.writer.add_scalar, global_step + i)
                self.writer.add_scalar('TrainAccIteration',
                                       train_acc_iteration, global_step + i)

                params_groups_lr = str()
                for group_ind, param_group in enumerate(
                        self.optimizer.param_groups):
                    params_groups_lr = params_groups_lr + 'pg_%d' % group_ind + ': %.8f, ' % param_group[
                        'lr']

                descript = '[Train Fold {}][epoch: {}/{}][image_size: {}][Lr :{}][Acc: {:.4f}]'.format(
                    self.fold, epoch, self.epoch, image_size, params_groups_lr,
                    train_acc_iteration) + descript
                if self.l1_regular:
                    descript += '[L1RegularLoss: {:.4f}][Loss: {:.4f}]'.format(
                        current_l1_regular_loss.item(), loss.item())
                tbar.set_description(desc=descript)

            # Write to TensorBoard
            epoch_acc = epoch_corrects / images_number
            self.writer.add_scalar('TrainAccEpoch', epoch_acc, epoch)
            self.writer.add_scalar('Lr', self.optimizer.param_groups[0]['lr'],
                                   epoch)
            if self.l1_regular:
                l1_regular_loss_epoch = l1_regular_loss / len(train_loader)
                loss_with_l1_regular_epoch = loss_with_l1_regular / len(
                    train_loader)
                self.writer.add_scalar('TrainL1RegularLoss',
                                       l1_regular_loss_epoch, epoch)
                self.writer.add_scalar('TrainLossWithL1Regular',
                                       loss_with_l1_regular_epoch, epoch)
            descript = self.criterion.record_loss_epoch(
                len(train_loader), self.writer.add_scalar, epoch)

            # Print the log info
            print('[Finish epoch: {}/{}][Average Acc: {:.4}]'.format(
                epoch, self.epoch, epoch_acc) + descript)

            # Validate the model
            val_accuracy, val_loss, is_best = self.validation(
                valid_loader, self.val_multi_scale)

            # Save the parameters
            state = {
                'epoch': epoch,
                'state_dict': self.model.module.state_dict(),
                'max_score': self.max_accuracy_valid
            }
            self.solver.save_checkpoint(
                os.path.join(
                    self.model_path,
                    '%s_fold%d.pth' % (self.config.model_type, self.fold)),
                state, is_best)

            if epoch % self.save_interval == 0:
                self.solver.save_checkpoint(
                    os.path.join(
                        self.model_path, '%s_epoch%d_fold%d.pth' %
                        (self.config.model_type, epoch, self.fold)), state,
                    False)

            # Write to TensorBoard
            self.writer.add_scalar('ValidLoss', val_loss, epoch)
            self.writer.add_scalar('ValidAccuracy', val_accuracy, epoch)

            # Decay the learning rate after each epoch
            if self.lr_scheduler == 'ReduceLR':
                self.exp_lr_scheduler.step(metrics=val_accuracy)
            else:
                self.exp_lr_scheduler.step()
            global_step += len(train_loader)
        print('BEST ACC:{}'.format(self.max_accuracy_valid))
        source_path = os.path.join(self.model_path, 'model_best.pth')
        target_path = os.path.join(self.config.save_path,
                                   self.config.model_type, 'backup',
                                   'model_best.pth')
        print('Copy %s to %s' % (source_path, target_path))
        shutil.copy(source_path, target_path)

    def validation(self, valid_loader, multi_scale=False):
        self.model.eval()
        labels_predict_all, labels_all = np.empty(shape=(0, )), np.empty(
            shape=(0, ))
        epoch_loss = 0
        with torch.no_grad():
            if multi_scale:
                multi_oa = []
                for image_size in self.multi_scale_size:
                    tbar = tqdm.tqdm(valid_loader)
                    # Compute accuracy for each scale separately, so reset the
                    # accumulated predictions before every pass
                    labels_predict_all = np.empty(shape=(0, ))
                    labels_all = np.empty(shape=(0, ))
                    for i, (_, images, labels) in enumerate(tbar):
                        images = multi_scale_transforms(image_size,
                                                        images,
                                                        auto_aug=False)
                        # Forward pass
                        labels_predict = self.solver.forward(images)
                        loss = self.solver.cal_loss(labels_predict, labels,
                                                    self.criterion)

                        epoch_loss += loss.item()

                        # Apply softmax, then argmax
                        labels_predict = F.softmax(labels_predict, dim=1)
                        labels_predict = torch.argmax(
                            labels_predict, dim=1).detach().cpu().numpy()

                        labels_predict_all = np.concatenate(
                            (labels_predict_all, labels_predict))
                        labels_all = np.concatenate((labels_all, labels))

                        descript = '[Valid][Loss: {:.4f}]'.format(loss.item())
                        tbar.set_description(desc=descript)

                    classify_report, my_confusion_matrix, acc_for_each_class, oa, average_accuracy, kappa = \
                        self.classification_metric.get_metric(
                            labels_all,
                            labels_predict_all
                        )
                    multi_oa.append(oa)
                oa = np.asarray(multi_oa).mean()
            else:
                tbar = tqdm.tqdm(valid_loader)
                for i, (_, images, labels) in enumerate(tbar):
                    # Forward pass
                    labels_predict = self.solver.forward(images)
                    loss = self.solver.cal_loss(labels_predict, labels,
                                                self.criterion)

                    epoch_loss += loss.item()

                    # Apply softmax, then argmax
                    labels_predict = F.softmax(labels_predict, dim=1)
                    labels_predict = torch.argmax(
                        labels_predict, dim=1).detach().cpu().numpy()

                    labels_predict_all = np.concatenate(
                        (labels_predict_all, labels_predict))
                    labels_all = np.concatenate((labels_all, labels))

                    descript = '[Valid][Loss: {:.4f}]'.format(loss.item())
                    tbar.set_description(desc=descript)

                classify_report, my_confusion_matrix, acc_for_each_class, oa, average_accuracy, kappa = \
                    self.classification_metric.get_metric(
                        labels_all,
                        labels_predict_all
                    )

            if oa > self.max_accuracy_valid:
                is_best = True
                self.max_accuracy_valid = oa
                if not self.selected_labels:
                    # Only draw the confusion matrix when no subset of training
                    # classes was specified; otherwise it errors out
                    self.classification_metric.draw_cm_and_save_result(
                        classify_report, my_confusion_matrix,
                        acc_for_each_class, oa, average_accuracy, kappa)
            else:
                is_best = False

            print('OA:{}, AA:{}, Kappa:{}'.format(oa, average_accuracy, kappa))

            return oa, epoch_loss / len(tbar), is_best

    def init_log(self):
        # Save the config and initialize TensorBoard
        TIMESTAMP = "log-{0:%Y-%m-%dT%H-%M-%S}".format(datetime.datetime.now())
        log_dir = os.path.join(self.config.save_path, self.config.model_type,
                               TIMESTAMP)
        writer = SummaryWriter(log_dir=log_dir)
        with codecs.open(os.path.join(log_dir, 'config.json'), 'w',
                         "utf-8") as json_file:
            json.dump({k: v
                       for k, v in config._get_kwargs()},
                      json_file,
                      ensure_ascii=False)

        seed = int(time.time())
        seed_torch(seed)
        with open(os.path.join(log_dir, 'seed.pkl'), 'wb') as f:
            pickle.dump({'seed': seed}, f, -1)

        return writer, TIMESTAMP
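
Examples #5 and #7 call generate_mixed_sample and cal_loss_cutmix without defining them. Here is a sketch of the standard CutMix recipe the call sites suggest; it follows the published CutMix formulation and is an assumption about, not a copy of, the original helpers:

import numpy as np
import torch


def generate_mixed_sample(beta, images, labels):
    # Standard CutMix: paste a random crop from a shuffled copy of the batch
    # into the original images, then mix the two label sets by patch area.
    lam = np.random.beta(beta, beta)
    index = torch.randperm(images.size(0))
    labels_a, labels_b = labels, labels[index]

    h, w = images.size(2), images.size(3)
    cut_ratio = np.sqrt(1.0 - lam)
    cut_h, cut_w = int(h * cut_ratio), int(w * cut_ratio)
    cy, cx = np.random.randint(h), np.random.randint(w)
    y1, y2 = np.clip(cy - cut_h // 2, 0, h), np.clip(cy + cut_h // 2, 0, h)
    x1, x2 = np.clip(cx - cut_w // 2, 0, w), np.clip(cx + cut_w // 2, 0, w)

    images[:, :, y1:y2, x1:x2] = images[index, :, y1:y2, x1:x2]
    # Adjust lambda to the exact area ratio of the pasted patch.
    lam = 1 - (y2 - y1) * (x2 - x1) / (h * w)
    return images, labels_a, labels_b, lam


# cal_loss_cutmix would then mix the two losses, e.g.:
# loss = lam * criterion(predict, labels_a) + (1 - lam) * criterion(predict, labels_b)
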
Example #6
class TrainVal():
    def __init__(self, config):
        # self.model = ClassifyResNet(config.model_name, config.class_num, training=True)
        self.model = models.alexnet(pretrained=True)
        # self.model = VGG_19(config.class_num)
        # self.model = models.vgg19_bn(pretrained=True)
        # self.model = models.resnet50(pretrained=True)

        # freeze model parameters; the new classifier head is unfrozen below
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.classifier[6] = nn.Sequential(
            nn.Linear(4096, config.class_num))
        # # for param in self.model.feature.parameters():
        # #     param.requires_grad = True
        # # for param in self.model.logit.parameters():
        # #     param.requires_grad = True
        for param in self.model.classifier.parameters():
            param.requires_grad = True

        if torch.cuda.is_available():
            self.device = torch.device("cuda:%i" % config.device[0])
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=config.device)
            self.model = self.model.to(self.device)
        else:
            self.device = torch.device("cpu")
            self.model = self.model.to(self.device)

        # Load hyperparameters
        self.lr = config.lr
        self.weight_decay = config.weight_decay
        self.epoch = config.epoch

        # Instantiate the Solver class that implements the various helper routines
        self.solver = Solver(self.model, self.device)

        # Load the loss function
        self.criterion = ClassifyLoss(weight=[0.8, 0.2])

        # Create the directory for saving weights
        self.TIME = "{0:%Y-%m-%dT%H-%M-%S}-classify".format(
            datetime.datetime.now())
        self.model_path = os.path.join(config.root, config.save_path,
                                       config.model_name, self.TIME)
        if not os.path.exists(self.model_path):
            os.makedirs(self.model_path)

        self.max_accuracy_valid = 0
        # Set the random seed; keep it fixed so the cross-validation train/valid split stays reproducible
        self.seed = int(time.time())
        # self.seed = 1570421136
        seed_torch(self.seed)

    def train(self, dataloaders):
        # optimizer = optim.Adam(self.model.module.parameters(), self.lr, weight_decay=self.weight_decay)
        optimizer = optim.SGD(self.model.module.parameters(),
                              self.lr,
                              momentum=0.9,
                              weight_decay=self.weight_decay,
                              nesterov=True)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, self.epoch + 50)
        global_step = 0

        ohe = OneHotEncoder()
        ohe.fit([[0], [1]])

        for fold_index, [train_loader, valid_loader] in enumerate(dataloaders):
            if fold_index != 0:
                continue
            # Save the config as JSON and initialize TensorBoard
            TIMESTAMP = '-fold'.join([self.TIME, str(fold_index)])
            self.writer = SummaryWriter(
                log_dir=os.path.join(self.model_path, TIMESTAMP))
            with codecs.open(
                    os.path.join(self.model_path, TIMESTAMP, TIMESTAMP) +
                    '.json', 'w', "utf-8") as json_file:
                json.dump({k: v
                           for k, v in config._get_kwargs()},
                          json_file,
                          ensure_ascii=False)

            with open(
                    os.path.join(self.model_path, TIMESTAMP, TIMESTAMP) +
                    '.pkl', 'wb') as f:
                pickle.dump({'seed': self.seed}, f, -1)

            for epoch in range(self.epoch):
                epoch += 1
                epoch_loss = 0
                self.model.train(True)

                tbar = tqdm(train_loader)
                for i, (images, labels) in enumerate(tbar):
                    labels = torch.from_numpy(
                        ohe.transform(labels.reshape(-1, 1)).toarray())
                    # Forward and backward pass
                    labels_predict = self.solver.forward(images)
                    labels_predict = labels_predict.unsqueeze(dim=2).unsqueeze(
                        dim=3)
                    loss = self.solver.cal_loss(labels, labels_predict,
                                                self.criterion)
                    # loss = F.cross_entropy(labels_predict[0], labels)
                    epoch_loss += loss.item()
                    self.solver.backword(optimizer, loss)

                    # Log to TensorBoard, one entry per step
                    self.writer.add_scalar('train_loss', loss.item(),
                                           global_step + i)
                    self.writer.add_images('my_image_batch',
                                           images[:10].cpu().detach().numpy(),
                                           global_step + i)
                    params_groups_lr = str()
                    for group_ind, param_group in enumerate(
                            optimizer.param_groups):
                        params_groups_lr = params_groups_lr + 'params_group_%d' % (
                            group_ind) + ': %.12f, ' % (param_group['lr'])
                    descript = "Fold: %d, Train Loss: %.7f, lr: %s" % (
                        fold_index, loss.item(), params_groups_lr)
                    tbar.set_description(desc=descript)

                # Decay the learning rate after each epoch
                lr_scheduler.step()
                global_step += len(train_loader)

                # Validate the model
                class_neg_accuracy, class_pos_accuracy, class_accuracy, neg_accuracy, pos_accuracy, accuracy, loss_valid = \
                    self.validation(valid_loader)

                # Print the log info
                print(
                    'Finish Epoch [%d/%d] | Average training Loss: %.7f | Average validation Loss: %.7f | Total accuracy: %0.4f |'
                    % (epoch, self.epoch, epoch_loss / len(tbar), loss_valid,
                       accuracy))

                if accuracy > self.max_accuracy_valid:
                    is_best = True
                    self.max_accuracy_valid = accuracy
                else:
                    is_best = False

                state = {
                    'epoch': epoch,
                    'state_dict': self.model.module.state_dict(),
                    'max_accuracy_valid': self.max_accuracy_valid,
                }

                self.solver.save_checkpoint(
                    os.path.join(self.model_path, TIMESTAMP,
                                 TIMESTAMP + '.pth'), state, is_best,
                    self.max_accuracy_valid)
                self.writer.add_scalar('valid_loss', loss_valid, epoch)
                self.writer.add_scalar('valid_accuracy', accuracy, epoch)
                self.writer.add_scalar('valid_class_0_accuracy',
                                       class_accuracy[0], epoch)
                self.writer.add_scalar('valid_class_1_accuracy',
                                       class_accuracy[1], epoch)
                # self.writer.add_scalar('valid_class_2_accuracy', class_accuracy[2], epoch)

    def validation(self, valid_loader):
        self.model.eval()
        meter = Meter()
        tbar = tqdm(valid_loader)
        loss_sum = 0

        ohe = OneHotEncoder()
        ohe.fit([[0], [1]])

        with torch.no_grad():
            for i, (images, labels) in enumerate(tbar):
                labels = torch.from_numpy(
                    ohe.transform(labels.reshape(-1, 1)).toarray())
                # Forward pass
                labels_predict = self.solver.forward(images)
                labels_predict = labels_predict.unsqueeze(dim=2).unsqueeze(
                    dim=3)
                loss = self.solver.cal_loss(labels, labels_predict,
                                            self.criterion)
                # loss = F.cross_entropy(labels_predict[0], labels)
                loss_sum += loss.item()

                meter.update(labels, labels_predict.cpu())

                descript = "Val Loss: {:.7f}".format(loss.item())
                tbar.set_description(desc=descript)
        loss_mean = loss_sum / len(tbar)

        class_neg_accuracy, class_pos_accuracy, class_accuracy, neg_accuracy, pos_accuracy, accuracy = meter.get_metrics(
        )
        print(
            "Class_0_accuracy: %0.4f | Positive accuracy: %0.4f | Negative accuracy: %0.4f | \n"
            "Class_1_accuracy: %0.4f | Positive accuracy: %0.4f | Negative accuracy: %0.4f |"
            %
            (class_accuracy[0], class_pos_accuracy[0], class_neg_accuracy[0],
             class_accuracy[1], class_pos_accuracy[1], class_neg_accuracy[1]))
        # print("Class_0_accuracy: %0.4f | Positive accuracy: %0.4f | Negative accuracy: %0.4f | \n"
        #       "Class_1_accuracy: %0.4f | Positive accuracy: %0.4f | Negative accuracy: %0.4f | \n"
        #       "Class_2_accuracy: %0.4f | Positive accuracy: %0.4f | Negative accuracy: %0.4f |" %
        #       (class_accuracy[0], class_pos_accuracy[0], class_neg_accuracy[0],
        #        class_accuracy[1], class_pos_accuracy[1], class_neg_accuracy[1],
        #        class_accuracy[2], class_pos_accuracy[2], class_neg_accuracy[2]))
        return class_neg_accuracy, class_pos_accuracy, class_accuracy, neg_accuracy, pos_accuracy, accuracy, loss_mean
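
Example #6 one-hot encodes the integer labels with sklearn before computing the loss, because ClassifyLoss apparently expects one label column per class. A standalone check of the transform used above:

import numpy as np
import torch
from sklearn.preprocessing import OneHotEncoder

ohe = OneHotEncoder()
ohe.fit([[0], [1]])

labels = np.array([0, 1, 1, 0]).reshape(-1, 1)
one_hot = torch.from_numpy(ohe.transform(labels).toarray())
# one_hot rows are [1., 0.] for class 0 and [0., 1.] for class 1 (float64).
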
Example #7
class TrainVal:
    def __init__(self, config, fold, train_labels_number):
        """
        Args:
            config: 配置参数
            fold: int, 当前为第几折
            train_labels_number: list, 某一折的[number_class0, number__class1, ...]
        """
        self.config = config
        self.fold = fold
        self.epoch = config.epoch
        self.num_classes = config.num_classes
        self.lr_scheduler = config.lr_scheduler
        self.save_interval = 100
        self.cut_mix = config.cut_mix
        self.beta = config.beta
        self.cutmix_prob = config.cutmix_prob

        self.image_size = config.image_size
        self.multi_scale = config.multi_scale
        self.multi_scale_size = config.multi_scale_size
        self.multi_scale_interval = config.multi_scale_interval
        if self.cut_mix:
            print('Using cut mix.')
        if self.multi_scale:
            print('Using multi scale training.')
        print('USE LOSS: {}'.format(config.loss_name))

        # Load the model
        prepare_model = PrepareModel()
        self.model = prepare_model.create_model(model_type=config.model_type,
                                                classes_num=self.num_classes,
                                                drop_rate=config.drop_rate,
                                                pretrained=True,
                                                bn_to_gn=config.bn_to_gn)
        if torch.cuda.is_available():
            self.model = torch.nn.DataParallel(self.model)
            self.model = self.model.cuda()

        # Load the optimizer
        self.optimizer = prepare_model.create_optimizer(
            config.model_type, self.model, config)

        # Load the learning-rate decay schedule
        self.exp_lr_scheduler = prepare_model.create_lr_scheduler(
            self.lr_scheduler,
            self.optimizer,
            step_size=config.lr_step_size,
            restart_step=config.restart_step,
            multi_step=config.multi_step)

        # Load the loss function
        self.criterion = Loss(config.model_type, config.loss_name,
                              self.num_classes, train_labels_number,
                              config.beta_CB, config.gamma)

        # Instantiate the Solver class that implements the various helper routines
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.solver = Solver(self.model, self.device)
        if config.restore:
            weight_path = os.path.join('checkpoints', config.model_type)
            if config.restore == 'last':
                lists = os.listdir(weight_path)  # all files in the directory
                lists.sort(key=lambda fn: os.path.getmtime(weight_path + '/' +
                                                           fn))  # sort by most recent modification time
                weight_path = os.path.join(weight_path, lists[-1],
                                           'model_best.pth')
            else:
                weight_path = os.path.join(weight_path, config.restore,
                                           'model_best.pth')
            self.solver.load_checkpoint(weight_path)

        # Initialize logging
        self.writer, self.time_stamp = self.init_log()
        self.model_path = os.path.join(self.config.train_url,
                                       self.config.model_type, self.time_stamp)

        # Initialize the classification metric class
        with open("online-service/model/label_id_name.json",
                  'r',
                  encoding='utf-8') as json_file:
            self.class_names = list(json.load(json_file).values())
        self.classification_metric = ClassificationMetric(self.class_names,
                                                          self.model_path,
                                                          text_flag=0)

        self.max_accuracy_valid = 0

    def train(self, train_loader, valid_loader):
        """ 完成模型的训练,保存模型与日志
        Args:
            train_loader: 训练数据的DataLoader
            valid_loader: 验证数据的Dataloader
        """
        global_step = 0
        for epoch in range(self.epoch):
            self.model.train()
            epoch += 1
            images_number, epoch_corrects = 0, 0

            tbar = tqdm.tqdm(train_loader)
            image_size = self.image_size
            for i, (_, images, labels) in enumerate(tbar):
                if self.multi_scale:
                    if i % self.multi_scale_interval == 0:
                        image_size = random.choice(self.multi_scale_size)
                    images = multi_scale_transforms(image_size, images)
                if self.cut_mix:
                    # Apply CutMix
                    r = np.random.rand(1)
                    if self.beta > 0 and r < self.cutmix_prob:
                        images, labels_a, labels_b, lam = generate_mixed_sample(
                            self.beta, images, labels)
                        labels_predict = self.solver.forward(images)
                        loss = self.solver.cal_loss_cutmix(
                            labels_predict, labels_a, labels_b, lam,
                            self.criterion)
                    else:
                        # Forward pass
                        labels_predict = self.solver.forward(images)
                        loss = self.solver.cal_loss(labels_predict, labels,
                                                    self.criterion)
                else:
                    # Forward pass
                    labels_predict = self.solver.forward(images)
                    loss = self.solver.cal_loss(labels_predict, labels,
                                                self.criterion)
                self.solver.backword(self.optimizer, loss)

                images_number += images.size(0)
                epoch_corrects += self.model.module.get_classify_result(
                    labels_predict, labels, self.device).sum()
                train_acc_iteration = self.model.module.get_classify_result(
                    labels_predict, labels, self.device).mean()

                # Log to TensorBoard, one entry per step
                descript = self.criterion.record_loss_iteration(
                    self.writer.add_scalar, global_step + i)
                self.writer.add_scalar('TrainAccIteration',
                                       train_acc_iteration, global_step + i)

                params_groups_lr = str()
                for group_ind, param_group in enumerate(
                        self.optimizer.param_groups):
                    params_groups_lr = params_groups_lr + 'pg_%d' % group_ind + ': %.8f, ' % param_group[
                        'lr']

                descript = '[Train Fold {}][epoch: {}/{}][image_size: {}][Lr :{}][Acc: {:.4f}]'.format(
                    self.fold, epoch, self.epoch, image_size, params_groups_lr,
                    train_acc_iteration) + descript

                # For CyclicLR, step the learning-rate scheduler every iteration
                if self.lr_scheduler == 'CyclicLR':
                    self.exp_lr_scheduler.step()
                    self.writer.add_scalar(
                        'Lr', self.optimizer.param_groups[1]['lr'],
                        global_step + i)

                tbar.set_description(desc=descript)

            # Write to TensorBoard
            epoch_acc = epoch_corrects / images_number
            self.writer.add_scalar('TrainAccEpoch', epoch_acc, epoch)
            if self.lr_scheduler != 'CyclicLR':
                self.writer.add_scalar('Lr',
                                       self.optimizer.param_groups[1]['lr'],
                                       epoch)
            descript = self.criterion.record_loss_epoch(
                len(train_loader), self.writer.add_scalar, epoch)

            # Print the log info
            print('[Finish epoch: {}/{}][Average Acc: {:.4}]'.format(
                epoch, self.epoch, epoch_acc) + descript)

            # Validate the model
            val_accuracy, val_loss, is_best = self.validation(valid_loader)

            # Save the parameters
            state = {
                'epoch': epoch,
                'state_dict': self.model.module.state_dict(),
                'max_score': self.max_accuracy_valid
            }
            self.solver.save_checkpoint(
                os.path.join(
                    self.model_path,
                    '%s_fold%d.pth' % (self.config.model_type, self.fold)),
                state, is_best)

            if epoch % self.save_interval == 0:
                self.solver.save_checkpoint(
                    os.path.join(
                        self.model_path, '%s_epoch%d_fold%d.pth' %
                        (self.config.model_type, epoch, self.fold)), state,
                    False)

            # Write to TensorBoard
            self.writer.add_scalar('ValidLoss', val_loss, epoch)
            self.writer.add_scalar('ValidAccuracy', val_accuracy, epoch)

            # Decay the learning rate after each epoch
            if self.lr_scheduler == 'ReduceLR':
                self.exp_lr_scheduler.step(val_loss)
            elif self.lr_scheduler != 'CyclicLR':
                self.exp_lr_scheduler.step()
            global_step += len(train_loader)
        print('BEST ACC:{}'.format(self.max_accuracy_valid))

    def validation(self, valid_loader):
        tbar = tqdm.tqdm(valid_loader)
        self.model.eval()
        labels_predict_all, labels_all = np.empty(shape=(0, )), np.empty(
            shape=(0, ))
        epoch_loss = 0
        with torch.no_grad():
            for i, (_, images, labels) in enumerate(tbar):
                # Forward pass
                labels_predict = self.solver.forward(images)
                loss = self.solver.cal_loss(labels_predict, labels,
                                            self.criterion)

                epoch_loss += loss.item()

                # Apply softmax, then argmax
                labels_predict = F.softmax(labels_predict, dim=1)
                labels_predict = torch.argmax(labels_predict,
                                              dim=1).detach().cpu().numpy()

                labels_predict_all = np.concatenate(
                    (labels_predict_all, labels_predict))
                labels_all = np.concatenate((labels_all, labels))

                descript = '[Valid][Loss: {:.4f}]'.format(loss)
                tbar.set_description(desc=descript)

            classify_report, my_confusion_matrix, acc_for_each_class, oa, average_accuracy, kappa = \
                self.classification_metric.get_metric(
                    labels_all,
                    labels_predict_all
                )

            if oa > self.max_accuracy_valid:
                is_best = True
                self.max_accuracy_valid = oa
                self.classification_metric.draw_cm_and_save_result(
                    classify_report, my_confusion_matrix, acc_for_each_class,
                    oa, average_accuracy, kappa)
            else:
                is_best = False

            print('OA:{}, AA:{}, Kappa:{}'.format(oa, average_accuracy, kappa))

            return oa, epoch_loss / len(tbar), is_best

    def init_log(self):
        # Save the configuration and initialize TensorBoard
        TIMESTAMP = "log-{0:%Y-%m-%dT%H-%M-%S}".format(datetime.datetime.now())
        log_dir = os.path.join(self.config.train_url, self.config.model_type,
                               TIMESTAMP)
        writer = SummaryWriter(log_dir=log_dir)
        with codecs.open(os.path.join(log_dir, 'param.json'), 'w',
                         "utf-8") as json_file:
            json.dump({k: v
                       for k, v in config._get_kwargs()},
                      json_file,
                      ensure_ascii=False)

        seed = int(time.time())
        seed_torch(seed)
        with open(os.path.join(log_dir, 'seed.pkl'), 'wb') as f:
            pickle.dump({'seed': seed}, f, -1)

        return writer, TIMESTAMP
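
The example above steps CyclicLR inside the batch loop while every other scheduler is stepped once per epoch. A minimal, self-contained sketch of that per-iteration pattern (the toy model and loop sizes are made up for illustration, not taken from the example):

import torch
import torch.optim as optim

model = torch.nn.Linear(10, 2)  # hypothetical toy model
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = optim.lr_scheduler.CyclicLR(
    optimizer, base_lr=1e-4, max_lr=1e-2, step_size_up=100)

for epoch in range(2):
    for batch in range(200):  # stands in for iterating the DataLoader
        optimizer.zero_grad()
        loss = model(torch.randn(4, 10)).sum()
        loss.backward()
        optimizer.step()
        scheduler.step()  # stepped every iteration, unlike epoch-level schedulers
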
Example #8
class TrainVal():
    def __init__(self, config):
        # self.model = ClassifyResNet(config.model_name, config.class_num, training=True)
        self.model = models.alexnet(pretrained=False)
        # self.model = models.vgg11(pretrained=True)
        # self.model = models.resnet50(pretrained=True)

        # freeze model parameters
        for param in self.model.parameters():
            param.requires_grad = False

        self.model.classifier[6] = nn.Sequential(
            nn.Linear(4096, config.class_num))
        # # for param in self.model.feature.parameters():
        # #     param.requires_grad = True
        # # for param in self.model.logit.parameters():
        # #     param.requires_grad = True
        for param in self.model.classifier.parameters():
            param.requires_grad = True

        # model check
        print(self.model)
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                print("requires_grad: True ", name)
            else:
                print("requires_grad: False ", name)

        self.device = torch.device("cpu")
        if torch.cuda.is_available():
            self.device = torch.device("cuda:%i" % config.device[0])
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=config.device)
        self.model = self.model.to(self.device)

        self.lr = config.lr
        self.weight_decay = config.weight_decay
        self.epoch = config.epoch

        self.solver = Solver(self.model, self.device)

        self.criterion = ClassifyLoss()

        self.TIME = "{0:%Y-%m-%dT%H-%M-%S}-classify".format(
            datetime.datetime.now())
        self.model_path = os.path.join(config.root, config.save_path,
                                       config.model_name, self.TIME)
        if not os.path.exists(self.model_path):
            os.makedirs(self.model_path)

        self.max_accuracy_valid = 0
        self.seed = int(time.time())
        # self.seed = 1570421136
        seed_torch(self.seed)

    def train(self, dataloaders):
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      self.model.module.parameters()),
                               self.lr,
                               weight_decay=self.weight_decay)
        # optimizer = optim.SGD(filter(lambda p: p.requires_grad, self.model.module.parameters()), self.lr,
        #                       momentum=0.9, weight_decay=self.weight_decay, nesterov=True)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, self.epoch + 50)
        global_step = 1

        for fold_index, [train_loader, valid_loader] in enumerate(dataloaders):

            TIMESTAMP = '-fold'.join([self.TIME, str(fold_index)])
            self.writer = SummaryWriter(
                log_dir=os.path.join(self.model_path, TIMESTAMP))
            with codecs.open(
                    os.path.join(self.model_path, TIMESTAMP, TIMESTAMP) +
                    '.json', 'w', "utf-8") as json_file:
                json.dump({k: v
                           for k, v in config._get_kwargs()},
                          json_file,
                          ensure_ascii=False)

            with open(
                    os.path.join(self.model_path, TIMESTAMP, TIMESTAMP) +
                    '.pkl', 'wb') as f:
                pickle.dump({'seed': self.seed}, f, -1)

            for epoch in range(1, self.epoch + 1):
                epoch += self.epoch * fold_index
                epoch_loss = 0
                self.model.train(True)

                tbar = tqdm(train_loader)
                for i, (images, labels) in enumerate(tbar):
                    labels_predict = self.solver.forward(images)
                    # labels_predict = labels_predict.unsqueeze(dim=2).unsqueeze(dim=3)
                    loss = self.solver.cal_loss(labels, labels_predict,
                                                self.criterion)
                    # loss = F.cross_entropy(labels_predict[0], labels)
                    epoch_loss += loss.item()
                    self.solver.backword(optimizer, loss)

                    self.writer.add_scalar('train_loss', loss.item(),
                                           global_step + i)
                    # self.writer.add_images('my_image_batch', images.cpu().detach().numpy(), global_step + i)
                    params_groups_lr = str()
                    for group_ind, param_group in enumerate(
                            optimizer.param_groups):
                        params_groups_lr = params_groups_lr + 'params_group_%d' % (
                            group_ind) + ': %.12f, ' % (param_group['lr'])
                    descript = "Fold: %d, Train Loss: %.7f, lr: %s" % (
                        fold_index, loss.item(), params_groups_lr)
                    tbar.set_description(desc=descript)

                if epoch % 5 == 0:
                    lr_scheduler.step()
                global_step += len(train_loader)

                class_neg_accuracy, class_pos_accuracy, class_accuracy, neg_accuracy, pos_accuracy, accuracy, loss_valid = \
                    self.validation(valid_loader)

                print(
                    'Finish Epoch [%d/%d] | Average training Loss: %.7f | Average validation Loss: %.7f | Total accuracy: %0.4f |'
                    % (epoch, self.epoch * config.n_splits,
                       epoch_loss / len(tbar), loss_valid, accuracy))

                if accuracy > self.max_accuracy_valid:
                    is_best = True
                    self.max_accuracy_valid = accuracy
                else:
                    is_best = False

                state = {
                    'epoch': epoch,
                    'state_dict': self.model.module.state_dict(),
                    'max_accuracy_valid': self.max_accuracy_valid,
                }

                self.solver.save_checkpoint(
                    os.path.join(self.model_path, TIMESTAMP,
                                 TIMESTAMP + '.pth'), state, is_best,
                    self.max_accuracy_valid)
                self.writer.add_scalar('valid_loss', loss_valid, epoch)
                self.writer.add_scalar('valid_accuracy', accuracy, epoch)
                self.writer.add_scalar('valid_class_0_accuracy',
                                       class_accuracy[0], epoch)
                self.writer.add_scalar('valid_class_1_accuracy',
                                       class_accuracy[1], epoch)
                # self.writer.add_scalar('valid_class_2_accuracy', class_accuracy[2], epoch)

    def validation(self, valid_loader):
        self.model.eval()
        meter = Meter()
        tbar = tqdm(valid_loader)
        loss_sum = 0

        with torch.no_grad():
            for i, (images, labels) in enumerate(tbar):
                labels_predict = self.solver.forward(images)
                # print(labels + (labels_predict.cpu() > 0.5).int())
                # labels_predict = labels_predict.unsqueeze(dim=2).unsqueeze(dim=3)
                loss = self.solver.cal_loss(labels, labels_predict,
                                            self.criterion)
                # loss = F.cross_entropy(labels_predict[0], labels)
                loss_sum += loss.item()

                meter.update(labels, labels_predict.cpu())

                descript = "Val Loss: {:.7f}".format(loss.item())
                tbar.set_description(desc=descript)
        loss_mean = loss_sum / len(tbar)

        class_neg_accuracy, class_pos_accuracy, class_accuracy, neg_accuracy, pos_accuracy, accuracy = meter.get_metrics()
        print(
            "Class_0_accuracy: %0.4f | Positive accuracy: %0.4f | Negative accuracy: %0.4f | \n"
            "Class_1_accuracy: %0.4f | Positive accuracy: %0.4f | Negative accuracy: %0.4f |"
            %
            (class_accuracy[0], class_pos_accuracy[0], class_neg_accuracy[0],
             class_accuracy[1], class_pos_accuracy[1], class_neg_accuracy[1]))
        # print("Class_0_accuracy: %0.4f | Positive accuracy: %0.4f | Negative accuracy: %0.4f | \n"
        #       "Class_1_accuracy: %0.4f | Positive accuracy: %0.4f | Negative accuracy: %0.4f | \n"
        #       "Class_2_accuracy: %0.4f | Positive accuracy: %0.4f | Negative accuracy: %0.4f |" %
        #       (class_accuracy[0], class_pos_accuracy[0], class_neg_accuracy[0],
        #        class_accuracy[1], class_pos_accuracy[1], class_neg_accuracy[1],
        #        class_accuracy[2], class_pos_accuracy[2], class_neg_accuracy[2]))
        return class_neg_accuracy, class_pos_accuracy, class_accuracy, neg_accuracy, pos_accuracy, accuracy, loss_mean
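
Example #8 above uses the standard transfer-learning recipe: freeze every backbone parameter, swap in a new final layer, and hand only the still-trainable parameters to the optimizer. A minimal sketch of that pattern in isolation (the class count of 3 is made up):

import torch.nn as nn
from torchvision import models

model = models.alexnet(pretrained=False)
for param in model.parameters():
    param.requires_grad = False           # backbone stays frozen
model.classifier[6] = nn.Linear(4096, 3)  # fresh head; new params default to requires_grad=True
trainable = [p for p in model.parameters() if p.requires_grad]
assert len(trainable) == 2                # only the new head's weight and bias will train
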
Example #9
class TrainVal(object):
    def __init__(self, config, num_query, num_classes, num_valid_classes, fold, train_triplet=False):
        """

        :param config: 配置参数
        :param num_query: 该fold查询集的数量;类型为int
        :param num_classes: 该fold训练集的类别数;类型为int
        :param num_valid_classes: 该fold验证集的类别数;类型为int
        :param fold: 训练的哪一折;类型为int
        :param train_triplet: 是否只训练triplet损失;类型为bool
        """
        self.num_query = num_query
        self.num_classes = num_classes
        self.fold = fold

        self.model_name = config.model_name
        self.last_stride = config.last_stride
        self.dist = config.dist
        self.cython = config.cython
        self.num_gpus = torch.cuda.device_count()
        print('Using {} GPUS'.format(self.num_gpus))
        print('TRAIN_VALID_RATIO: {}'.format(self.num_classes/num_valid_classes))
        print('NUM_CLASS: {}'.format(self.num_classes))
        if self.cython:
            print('USE CYTHON TO EVAL!')
        print('USE LOSS: {}'.format(config.selected_loss))

        # Load the model; if any GPU is available, wrap it in DataParallel, and with multiple GPUs convert to synchronized BatchNorm
        self.model = get_model(self.model_name, self.num_classes, self.last_stride)
        if torch.cuda.is_available():
            self.model = torch.nn.DataParallel(self.model)
            if self.num_gpus > 1:
                self.model = convert_model(self.model)
            self.model = self.model.cuda()

        # Load hyperparameters
        self.epoch = config.epoch

        # Instantiate the Solver class that implements the training sub-routines
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.solver = Solver(self.model, self.device)

        # Load the loss function
        self.criterion = Loss(self.model_name, config.selected_loss, config.margin, self.num_classes)

        # Load the optimizer
        self.optim = get_optimizer(config, self.model)

        # Load the learning-rate decay strategy
        self.scheduler = get_scheduler(config, self.optim)

        # Create the directory for saving weights
        self.model_path = os.path.join(config.save_path, config.model_name)
        if not os.path.exists(self.model_path):
            os.makedirs(self.model_path)

        # If training only the triplet loss, resume from the best checkpoint
        if train_triplet:
            self.solver.load_checkpoint(os.path.join(self.model_path, '{}_fold{}_best.pth'.format(self.model_name,
                                                                                                  self.fold)))

        # Save the JSON config and initialize TensorBoard
        TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.datetime.now())
        self.writer = SummaryWriter(log_dir=os.path.join(self.model_path, TIMESTAMP))
        with codecs.open(self.model_path + '/' + TIMESTAMP + '.json', 'w', "utf-8") as json_file:
            json.dump({k: v for k, v in config._get_kwargs()}, json_file, ensure_ascii=False)

        # Set the random seed; when cross-validation splits train and valid sets, the seed must stay fixed
        self.seed = int(time.time())
        seed_torch(self.seed)
        with open(self.model_path + '/' + TIMESTAMP + '.pkl', 'wb') as f:
            pickle.dump({'seed': self.seed}, f, -1)

        # Set remaining parameters
        self.max_score = 0

    def train(self, train_loader, valid_loader):
        """ 完成模型的训练,保存模型与日志

        :param train_loader: 训练集的Dataloader
        :param valid_loader: 验证集的Dataloader
        :return: None
        """

        global_step = 0

        for epoch in range(1, self.epoch + 1):
            self.model.train()
            images_number, epoch_corrects, index = 0, 0, 0

            tbar = tqdm.tqdm(train_loader)
            for index, (images, labels) in enumerate(tbar):
                # Forward and backward pass
                outputs = self.solver.forward((images, labels))
                loss = self.solver.cal_loss(outputs, labels, self.criterion)
                self.solver.backword(self.optim, loss)

                images_number += images.size(0)
                epoch_corrects += self.model.module.get_classify_result(outputs, labels, self.device).sum()
                train_acc_iteration = self.model.module.get_classify_result(outputs, labels, self.device).mean() * 100

                # Log to TensorBoard, one entry per step
                global_step += 1
                descript = self.criterion.record_loss_iteration(self.writer.add_scalar, global_step)
                self.writer.add_scalar('TrainAccIteration', train_acc_iteration, global_step)

                descript = '[Train][epoch: {}/{}][Lr: {:.7f}][Acc: {:.2f}]'.format(
                    epoch, self.epoch, self.scheduler.get_lr()[1],
                    train_acc_iteration) + descript
                tbar.set_description(desc=descript)

            # After each epoch finishes, step the learning-rate schedule
            self.scheduler.step()

            # Write to TensorBoard
            epoch_acc = epoch_corrects / images_number * 100
            self.writer.add_scalar('TrainAccEpoch', epoch_acc, epoch)
            self.writer.add_scalar('Lr', self.scheduler.get_lr()[1], epoch)
            descript = self.criterion.record_loss_epoch(index, self.writer.add_scalar, epoch)

            # Print the log info
            print('[Finish epoch: {}/{}][Average Acc: {:.2}]'.format(epoch, self.epoch, epoch_acc) + descript)

            # Validate the model
            rank1, mAP, average_score = self.validation(valid_loader)

            if average_score > self.max_score:
                is_best = True
                self.max_score = average_score
            else:
                is_best = False

            state = {
                'epoch': epoch,
                'state_dict': self.model.module.state_dict(),
                'max_score': self.max_score
            }

            self.solver.save_checkpoint(
                os.path.join(self.model_path, '{}_fold{}.pth'.format(self.model_name, self.fold)), state, is_best)
            self.writer.add_scalar('Rank1', rank1, epoch)
            self.writer.add_scalar('MAP', mAP, epoch)
            self.writer.add_scalar('AverageScore', average_score, epoch)

    def validation(self, valid_loader):
        """ 完成模型的验证过程

        :param valid_loader: 验证集的Dataloader
        :return rank1: rank1得分;类型为float
        :return mAP: 平均检索精度;类型为float
        :return average_score: 平均得分;类型为float
        """
        self.model.eval()
        tbar = tqdm.tqdm(valid_loader)
        features_all, labels_all = [], []
        with torch.no_grad():
            for i, (images, labels, paths) in enumerate(tbar):
                # Forward pass
                # features = self.solver.forward((images, labels))[-1]
                features = self.solver.tta((images, labels))
                features_all.append(features.detach().cpu())
                labels_all.append(labels)

        features_all = torch.cat(features_all, dim=0)
        labels_all = torch.cat(labels_all, dim=0)

        query_features = features_all[:self.num_query]
        query_labels = labels_all[:self.num_query]

        gallery_features = features_all[self.num_query:]
        gallery_labels = labels_all[self.num_query:]

        if self.dist == 're_rank':
            distmat = re_rank(query_features, gallery_features)
        elif self.dist == 'cos_dist':
            distmat = cos_dist(query_features, gallery_features)
        elif self.dist == 'euclidean_dist':
            distmat = euclidean_dist(query_features, gallery_features)
        else:
            assert "Not implemented :{}".format(self.dist)

        all_rank_precison, mAP, _ = eval_func(distmat, query_labels.numpy(), gallery_labels.numpy(),
                                              use_cython=self.cython)

        rank1 = all_rank_precison[0]
        average_score = 0.5 * rank1 + 0.5 * mAP
        print('Rank1: {:.2%}, mAP {:.2%}, average score {:.2%}'.format(rank1, mAP, average_score))
        return rank1, mAP, average_score
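
The validation step above ranks each query feature against the gallery through a distance matrix (euclidean_dist, cos_dist, and re_rank are project helpers not shown here). A hedged sketch of what a plain Euclidean version typically computes, using the identity ||q - g||^2 = ||q||^2 + ||g||^2 - 2 q.g:

import torch

def pairwise_euclidean(query, gallery):
    # one row per query, one column per gallery image
    q2 = query.pow(2).sum(dim=1, keepdim=True)        # (m, 1)
    g2 = gallery.pow(2).sum(dim=1, keepdim=True).t()  # (1, n)
    dist2 = q2 + g2 - 2 * query @ gallery.t()
    return dist2.clamp(min=0).sqrt()                  # clamp guards tiny negatives

query = torch.randn(4, 128)                   # 4 query features, made-up sizes
gallery = torch.randn(10, 128)                # 10 gallery features
distmat = pairwise_euclidean(query, gallery)  # (4, 10), the shape eval_func ranks on
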
class TrainVal():
    def __init__(self, config, fold):
        # Load the network model
        self.model_name = config.model_name
        self.model = ClassifyResNet(self.model_name, 4, training=True)
        if torch.cuda.is_available():
            self.model = torch.nn.DataParallel(self.model)
            self.model = self.model.cuda()

        # Load hyperparameters
        self.lr = config.lr
        self.weight_decay = config.weight_decay
        self.epoch = config.epoch
        self.fold = fold

        # Instantiate the Solver class that implements the training sub-routines
        self.solver = Solver(self.model)

        # Load the loss function
        self.criterion = ClassifyLoss()

        # Create the directory for saving weights
        self.model_path = os.path.join(config.save_path, config.model_name)
        if not os.path.exists(self.model_path):
            os.makedirs(self.model_path)

        # Save the JSON config and initialize TensorBoard
        TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S-%d}-classify".format(
            datetime.datetime.now(), fold)
        self.writer = SummaryWriter(
            log_dir=os.path.join(self.model_path, TIMESTAMP))
        with codecs.open(self.model_path + '/' + TIMESTAMP + '.json', 'w',
                         "utf-8") as json_file:
            json.dump({k: v
                       for k, v in config._get_kwargs()},
                      json_file,
                      ensure_ascii=False)

        self.max_accuracy_valid = 0
        # Set the random seed; when cross-validation splits train and valid sets, the seed must stay fixed
        self.seed = int(time.time())
        # self.seed = 1570421136
        seed_torch(self.seed)
        with open(self.model_path + '/' + TIMESTAMP + '.pkl', 'wb') as f:
            pickle.dump({'seed': self.seed}, f, -1)

    def train(self, train_loader, valid_loader):
        ''' Train the model, saving checkpoints and logs.
        Args:
            train_loader: DataLoader of the training data
            valid_loader: DataLoader of the validation data
        '''
        optimizer = optim.Adam(self.model.module.parameters(),
                               self.lr,
                               weight_decay=self.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, self.epoch + 10)
        global_step = 0

        for epoch in range(1, self.epoch + 1):
            epoch_loss = 0
            self.model.train(True)

            tbar = tqdm.tqdm(train_loader)
            for i, (images, labels) in enumerate(tbar):
                # Forward and backward pass
                labels_predict = self.solver.forward(images)
                loss = self.solver.cal_loss(labels, labels_predict,
                                            self.criterion)
                epoch_loss += loss.item()
                self.solver.backword(optimizer, loss)

                # Log to TensorBoard, one entry per step
                self.writer.add_scalar('train_loss', loss.item(),
                                       global_step + i)
                params_groups_lr = str()
                for group_ind, param_group in enumerate(
                        optimizer.param_groups):
                    params_groups_lr = params_groups_lr + 'params_group_%d' % (
                        group_ind) + ': %.12f, ' % (param_group['lr'])
                descript = "Fold: %d, Train Loss: %.7f, lr: %s" % (
                    self.fold, loss.item(), params_groups_lr)
                tbar.set_description(desc=descript)

            # After each epoch finishes, step the learning-rate schedule
            lr_scheduler.step()
            global_step += len(train_loader)

            # Print the log info
            print('Finish Epoch [%d/%d], Average Loss: %.7f' %
                  (epoch, self.epoch, epoch_loss / len(tbar)))

            # Validate the model
            class_neg_accuracy, class_pos_accuracy, class_accuracy, neg_accuracy, pos_accuracy, accuracy, loss_valid = \
                self.validation(valid_loader)

            if accuracy > self.max_accuracy_valid:
                is_best = True
                self.max_accuracy_valid = accuracy
            else:
                is_best = False

            state = {
                'epoch': epoch,
                'state_dict': self.model.module.state_dict(),
                'max_accuracy_valid': self.max_accuracy_valid,
            }

            self.solver.save_checkpoint(
                os.path.join(
                    self.model_path,
                    '%s_classify_fold%d.pth' % (self.model_name, self.fold)),
                state, is_best)
            self.writer.add_scalar('valid_loss', loss_valid, epoch)
            self.writer.add_scalar('valid_accuracy', accuracy, epoch)
            self.writer.add_scalar('valid_class_0_accuracy', class_accuracy[0],
                                   epoch)
            self.writer.add_scalar('valid_class_1_accuracy', class_accuracy[1],
                                   epoch)
            self.writer.add_scalar('valid_class_2_accuracy', class_accuracy[2],
                                   epoch)
            self.writer.add_scalar('valid_class_3_accuracy', class_accuracy[3],
                                   epoch)

    def validation(self, valid_loader):
        ''' Run the validation pass.

        Args:
            valid_loader: DataLoader of the validation data
        '''
        self.model.eval()
        meter = Meter()
        tbar = tqdm.tqdm(valid_loader)
        loss_sum = 0

        with torch.no_grad():
            for i, (images, labels) in enumerate(tbar):
                # Forward pass
                labels_predict = self.solver.forward(images)
                loss = self.solver.cal_loss(labels, labels_predict,
                                            self.criterion)
                loss_sum += loss.item()

                meter.update(labels, labels_predict.cpu())

                descript = "Val Loss: {:.7f}".format(loss.item())
                tbar.set_description(desc=descript)
        loss_mean = loss_sum / len(tbar)

        class_neg_accuracy, class_pos_accuracy, class_accuracy, neg_accuracy, pos_accuracy, accuracy = meter.get_metrics()
        print(
            "Class_0_accuracy: %0.4f | Class_1_accuracy: %0.4f | Class_2_accuracy: %0.4f | Class_3_accuracy: %0.4f | "
            "Negative accuracy: %0.4f | positive accuracy: %0.4f | accuracy: %0.4f"
            % (class_accuracy[0], class_accuracy[1], class_accuracy[2],
               class_accuracy[3], neg_accuracy, pos_accuracy, accuracy))
        return class_neg_accuracy, class_pos_accuracy, class_accuracy, neg_accuracy, pos_accuracy, accuracy, loss_mean
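
Every example here saves checkpoints with the same state-dict layout: model.module.state_dict() strips the DataParallel wrapper, so the weights load back into an unwrapped model without a 'module.' prefix. A minimal sketch of that round trip (the toy model, scores, and file name are made up):

import torch

model = torch.nn.Linear(8, 2)                   # stands in for the real network
wrapped = torch.nn.DataParallel(model)
state = {
    'epoch': 3,
    'state_dict': wrapped.module.state_dict(),  # unwrapped weights
    'max_accuracy_valid': 0.91,
}
torch.save(state, 'checkpoint_demo.pth')        # hypothetical path

checkpoint = torch.load('checkpoint_demo.pth', map_location='cpu')
restored = torch.nn.Linear(8, 2)
restored.load_state_dict(checkpoint['state_dict'])  # no DataParallel needed to reload
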
Example #11
class TrainVal():
    def __init__(self, config):
        self.model = models.shufflenet_v2_x1_0(pretrained=True)

        # # freeze model parameters
        # for param in self.model.parameters():
        #     param.requires_grad = False

        self.model.fc = nn.Sequential(nn.Linear(1024, config.class_num),
                                      nn.Sigmoid())
        # for param in self.model.fc.parameters():
        #     param.requires_grad = True

        # # model check
        # print(self.model)
        # for name, param in self.model.named_parameters():
        #     if param.requires_grad:
        #         print("requires_grad: True ", name)
        #     else:
        #         print("requires_grad: False ", name)

        self.device = torch.device("cpu")
        if torch.cuda.is_available():
            self.device = torch.device("cuda:%i" % config.device[0])
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=config.device)
        self.model = self.model.to(self.device)

        self.lr = config.lr
        self.weight_decay = config.weight_decay
        self.epoch = config.epoch
        self.splits = config.n_splits
        self.root = config.root

        self.solver = Solver(self.model, self.device)

        self.criterion = nn.CrossEntropyLoss()

        self.TIME = "{0:%Y-%m-%dT%H-%M-%S}-classify".format(
            datetime.datetime.now())
        self.model_path = os.path.join(config.root, config.save_path,
                                       config.model_name, self.TIME)
        if not os.path.exists(self.model_path):
            os.makedirs(self.model_path)

        self.max_accuracy_valid = 0
        self.seed = int(time.time())
        # self.seed = 1570421136
        seed_torch(self.seed)

        self.train_transform = transforms.Compose([
            transforms.Resize([256, 256]),
            transforms.RandomCrop(224),
            transforms.RandomRotation(degrees=(-40, 40)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])
        self.test_transform = transforms.Compose(
            [transforms.Resize(224),
             transforms.ToTensor()])

    def train(self, create_data=False):
        if create_data:
            df = pd.read_csv(os.path.join(self.root,
                                          'train info_cwt_coarse.csv'),
                             header=None)
            labels_1dim = np.argmax(np.array(df), axis=1)
            print('<' * 20 + ' Start creating datasets ' + '>' * 20)
            skf = StratifiedKFold(n_splits=self.splits,
                                  shuffle=True,
                                  random_state=55)
            for idx, [train_df_index, val_df_index
                      ] in tqdm(enumerate(skf.split(df, labels_1dim), 1)):
                for i in train_df_index:
                    # Create the class directory if needed, then copy the image into it
                    dst_dir = os.path.join(
                        '/home/Yuanbincheng/project/dislocation_cls/2/4cls',
                        'train_%d/%d' % (idx, labels_1dim[i]))
                    os.makedirs(dst_dir, exist_ok=True)
                    shutil.copy(
                        os.path.join(config.root,
                                     'Ni-coarse-cwt/%s.jpg' % (i + 1)),
                        os.path.join(dst_dir, '%d.jpg' % (i + 1)))
                for i in val_df_index:
                    # Same layout for the held-out fold
                    dst_dir = os.path.join(
                        '/home/Yuanbincheng/project/dislocation_cls/2/4cls',
                        'test_%d/%d' % (idx, labels_1dim[i]))
                    os.makedirs(dst_dir, exist_ok=True)
                    shutil.copy(
                        os.path.join(config.root,
                                     'Ni-coarse-cwt/%s.jpg' % (i + 1)),
                        os.path.join(dst_dir, '%d.jpg' % (i + 1)))
            print('<' * 20 + ' Finish creating datasets ' + '>' * 20)

        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      self.model.module.parameters()),
                               self.lr,
                               weight_decay=self.weight_decay)
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 25, gamma=0.99)
        global_step, global_threshold, global_threshold_pop1, global_threshold_pop2, global_threshold_pop3 = 1, 1, 1, 1, 1

        for fold_index in range(self.splits):
            train_dataset = torchvision.datasets.ImageFolder(
                root='4cls/train_%d/' % (fold_index + 1),
                transform=self.train_transform)
            train_loader = DataLoader(train_dataset,
                                      batch_size=config.batch_size,
                                      shuffle=True,
                                      num_workers=config.num_workers)
            val_dataset = torchvision.datasets.ImageFolder(
                root='4cls/test_%d/' % (fold_index + 1),
                transform=self.test_transform)
            val_loader = DataLoader(val_dataset,
                                    batch_size=config.batch_size,
                                    shuffle=False,
                                    num_workers=config.num_workers)
            self.model.train()

            TIMESTAMP = '-fold'.join([self.TIME, str(fold_index)])
            self.writer = SummaryWriter(
                log_dir=os.path.join(self.model_path, TIMESTAMP))
            with codecs.open(
                    os.path.join(self.model_path, TIMESTAMP, TIMESTAMP) +
                    '.json', 'w', "utf-8") as json_file:
                json.dump({k: v
                           for k, v in config._get_kwargs()},
                          json_file,
                          ensure_ascii=False)

            with open(
                    os.path.join(self.model_path, TIMESTAMP, TIMESTAMP) +
                    '.pkl', 'wb') as f:
                pickle.dump({'seed': self.seed}, f, -1)

            for epoch in range(1, self.epoch + 1):
                epoch += self.epoch * fold_index
                epoch_loss, num_correct, num_pred = 0, 0, 0

                tbar = tqdm(train_loader)
                for i, (images, labels) in enumerate(tbar):
                    labels_predict = self.solver.forward(images)
                    loss = self.solver.cal_loss(labels, labels_predict,
                                                self.criterion)
                    epoch_loss += loss.item()
                    self.solver.backword(optimizer, loss)

                    # tmp = (labels_predict > 0.2).float()

                    labels_predictMax, labels_predictIdx = torch.max(
                        labels_predict, 1)
                    labels_predictIdx = labels_predictIdx.cpu()
                    labels_predictMax = labels_predictMax.cpu()
                    correct_idx = labels_predictIdx == labels
                    num_correct += correct_idx.sum().item()
                    num_pred += labels_predictIdx.size(0)
                    # for p, t in zip(labels_predictMax[correct_idx], labels[correct_idx]):
                    for p in labels_predict.cpu()[correct_idx]:
                        self.writer.add_scalar('threshold_pop1', p[0].item(),
                                               global_threshold)
                        self.writer.add_scalar('threshold_pop2', p[1].item(),
                                               global_threshold)
                        self.writer.add_scalar('threshold_pop3', p[2].item(),
                                               global_threshold)
                        self.writer.add_scalar('threshold_pop4', p[3].item(),
                                               global_threshold)
                        global_threshold += 1
                        # if t == 0:
                        #     self.writer.add_scalar('threshold_pop1', p.item(), global_threshold_pop1)
                        #     global_threshold_pop1 += 1
                        # elif t == 1:
                        #     self.writer.add_scalar('threshold_pop2', p.item(), global_threshold_pop2)
                        #     global_threshold_pop2 += 1
                        # elif t == 2:
                        #     self.writer.add_scalar('threshold_pop3', p.item(), global_threshold_pop3)
                        #     global_threshold_pop3 += 1

                    self.writer.add_scalar('train_loss', loss.item(),
                                           global_step + i)
                    params_groups_lr = str()
                    for group_ind, param_group in enumerate(
                            optimizer.param_groups):
                        params_groups_lr = params_groups_lr + 'params_group_%d' % (
                            group_ind) + ': %.12f, ' % (param_group['lr'])
                    descript = "Fold: %d, Train Loss: %.7f, lr: %s" % (
                        fold_index, loss.item(), params_groups_lr)
                    tbar.set_description(desc=descript)

                lr_scheduler.step()
                global_step += len(train_loader)
                precision, recall, f1, val_loss, val_accuracy = self.validation(
                    val_loader)
                print(
                    'Finish Epoch [%d/%d] | Average training Loss: %.7f | Training accuracy: %.4f | Average validation Loss: %.7f | Validation accuracy: %.4f |'
                    % (epoch, self.epoch * config.n_splits,
                       epoch_loss / len(tbar), num_correct / num_pred,
                       val_loss, val_accuracy))

                if val_accuracy > self.max_accuracy_valid:
                    is_best = True
                    self.max_accuracy_valid = val_accuracy
                else:
                    is_best = False

                state = {
                    'epoch': epoch,
                    'state_dict': self.model.module.state_dict(),
                    'max_accuracy_valid': self.max_accuracy_valid,
                }

                self.solver.save_checkpoint(
                    os.path.join(self.model_path, TIMESTAMP,
                                 TIMESTAMP + '.pth'), state, is_best,
                    self.max_accuracy_valid)
                self.writer.add_scalar('valid_loss', val_loss, epoch)
                self.writer.add_scalar('valid_accuracy', val_accuracy, epoch)
                self.writer.add_scalar('valid_class_1_f1', f1[0], epoch)
                self.writer.add_scalar('valid_class_2_f1', f1[1], epoch)
                self.writer.add_scalar('valid_class_3_f1', f1[2], epoch)
                self.writer.add_scalar('valid_class_4_f1', f1[3], epoch)

    def validation(self, valid_loader):
        self.model.eval()
        tbar = tqdm(valid_loader)
        loss_sum, num_correct, num_pred = 0, 0, 0
        y_true, y_pre = [], []

        with torch.no_grad():
            for i, (images, labels) in enumerate(tbar):
                labels_predict = self.solver.forward(images)
                loss = self.solver.cal_loss(labels, labels_predict,
                                            self.criterion)
                loss_sum += loss.item()

                # tmp = (labels_predict > 0.2).float()
                labels_predictIdx = torch.max(labels_predict, 1)[1].cpu()
                num_correct += (labels_predictIdx == labels).sum().item()
                num_pred += labels_predictIdx.size(0)
                y_true.extend(labels.numpy().tolist())
                y_pre.extend(labels_predictIdx.numpy().tolist())

                descript = "Val Loss: {:.7f}".format(loss.item())
                tbar.set_description(desc=descript)
        loss_mean = loss_sum / len(tbar)
        res = confusion_matrix(y_true, y_pre)
        precision = np.array([
            res[i][i] / np.sum(res, axis=0)[i] for i in range(config.class_num)
        ])
        recall = np.array([
            res[i][i] / np.sum(res, axis=1)[i] for i in range(config.class_num)
        ])
        f1 = 2 * precision * recall / (precision + recall)
        for idx, [p, r, f] in enumerate(zip(precision, recall, f1)):
            print(
                "Class_%d_precision: %0.4f | Recall: %0.4f | F1-score: %0.4f |"
                % (idx, p, r, f))
        return precision, recall, f1, loss_mean, num_correct / num_pred
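
The metric block above derives per-class precision from the confusion matrix's column sums (predicted counts) and recall from its row sums (true counts). A small worked example with made-up labels for three classes:

import numpy as np
from sklearn.metrics import confusion_matrix

y_true = [0, 0, 1, 1, 2, 2, 2]
y_pre = [0, 1, 1, 1, 2, 2, 0]
res = confusion_matrix(y_true, y_pre)       # rows = true class, columns = predicted class
precision = np.diag(res) / res.sum(axis=0)  # approx [0.50, 0.67, 1.00]
recall = np.diag(res) / res.sum(axis=1)     # approx [0.50, 1.00, 0.67]
f1 = 2 * precision * recall / (precision + recall)
# f1 is NaN for any class with zero precision and recall, which the
# per-class print loop above would report as nan
print(precision, recall, f1)
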