Example #1
def generate(**kwargs):
    
    # step: configure
    opt._parse(**kwargs)
    device = t.device('cuda') if opt.use_gpu else t.device('cpu')
    vis = Visualizer(env=opt.env)
    
    # step: data
    # Data is loaded here slightly differently from train.
    # (Something to keep in mind since chapter 7: gen and val differ mainly in their data;
    # val's data looks much like train's, but gen usually handles just a single sample.)
    # train builds a DataLoader directly, whereas here there are two inputs: the words and the image.
    # word

    # TODO: could the data handling of train and generate be factored into one shared place?
    data = t.load(opt.caption_data_path, map_location = lambda s, _:s)
    word2ix, ix2word, end = data['word2ix'], data['ix2word'], data['end']
     
    # picture
    # This code closely mirrors the code in feature_extract, but since it sits one layer
    # away from main, it is rewritten here.
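    # Note: `normalize` is assumed to be a torchvision T.Normalize defined elsewhere in this module,
    # e.g. T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) for ImageNet statistics.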
    transforms = T.Compose([
        T.Resize(opt.scale_size),
        T.CenterCrop(opt.img_size),
        T.ToTensor(),
        normalize
    ])
    
    img = Image.open(opt.test_img).convert('RGB')
    img = transforms(img).unsqueeze(0)
    img = img.to(device)
    
    # step: model: FeatureExtractModel:resnet50 CaptionModel:caption_model
    resnet50 = FeatureExtractModel()
    resnet50.to(device)
    resnet50.eval()

    
    caption_model = CaptionModel(opt, len(ix2word))
    caption_model.load(opt.model_ckpt_g)
    caption_model.to(device)
    caption_model.eval()
    
    # step: generate
    img_feats = resnet50(img) # 1*2048
    img_feats = img_feats.data[0] # 2048
    eos_id = word2ix[end]
    
    cap_sentences, cap_scores = caption_model.generate(img=img_feats, eos_id=eos_id)
    cap_sentences = [ ''.join([ix2word[ix.item()] for ix in sentence]) for sentence in cap_sentences]
    
    # vis.img('image', img)
    info = '<br>'.join(cap_sentences)
    vis.log(u'generate caption', info, False)
    
    return cap_sentences, cap_scores
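
For reference, a minimal usage sketch of calling generate(); it assumes opt._parse accepts the fields used above (test_img, model_ckpt_g) as keyword arguments, and the path values are purely illustrative:

sentences, scores = generate(test_img='imgs/example.jpg', model_ckpt_g='checkpoints/caption_best.pth')
print(sentences[0], scores[0])

Example #2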
def train(**kwargs):
    config.parse(kwargs)
    vis = Visualizer(port=2333, env=config.env)

    # prepare data
    train_data = Vertebrae_Dataset(config.data_root,
                                   config.train_paths,
                                   phase='train',
                                   balance=config.data_balance)
    val_data = Vertebrae_Dataset(config.data_root,
                                 config.test_paths,
                                 phase='val',
                                 balance=config.data_balance)
    # train_data = FrameDiff_Dataset(config.data_root, config.train_paths, phase='train', balance=config.data_balance)
    # val_data = FrameDiff_Dataset(config.data_root, config.test_paths, phase='val', balance=config.data_balance)
    print('Training Images:', train_data.__len__(), 'Validation Images:',
          val_data.__len__())

    train_dataloader = DataLoader(train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=config.num_workers)
    val_dataloader = DataLoader(val_data,
                                batch_size=config.batch_size,
                                shuffle=False,
                                num_workers=config.num_workers)

    # prepare model
    # model = ResNet34(num_classes=config.num_classes)
    # model = DenseNet121(num_classes=config.num_classes)
    # model = CheXPre_DenseNet121(num_classes=config.num_classes)
    # model = MultiResDenseNet121(num_classes=config.num_classes)
    # model = Vgg19(num_classes=config.num_classes)
    model = MultiResVgg19(num_classes=config.num_classes)

    if config.load_model_path:
        model.load(config.load_model_path)
    if config.use_gpu:
        model.cuda()
    if config.parallel:
        model = torch.nn.DataParallel(
            model, device_ids=[x for x in range(config.num_of_gpu)])

    model.train()

    # criterion and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    lr = config.lr
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=config.weight_decay)

    # metric
    softmax = functional.softmax
    loss_meter = meter.AverageValueMeter()
    train_cm = meter.ConfusionMeter(config.num_classes)
    previous_loss = 100
    previous_acc = 0

    # train
    if config.parallel:
        if not os.path.exists(
                os.path.join('checkpoints', model.module.model_name)):
            os.mkdir(os.path.join('checkpoints', model.module.model_name))
    else:
        if not os.path.exists(os.path.join('checkpoints', model.model_name)):
            os.mkdir(os.path.join('checkpoints', model.model_name))

    for epoch in range(config.max_epoch):
        loss_meter.reset()
        train_cm.reset()

        # train
        for i, (image, label, image_path) in tqdm(enumerate(train_dataloader)):
            # prepare input
            img = Variable(image)
            target = Variable(label)
            if config.use_gpu:
                img = img.cuda()
                target = target.cuda()

            # go through the model
            score = model(img)

            # backpropagate
            optimizer.zero_grad()
            loss = criterion(score, target)
            loss.backward()
            optimizer.step()

            loss_meter.add(loss.data[0])
            train_cm.add(softmax(score, dim=1).data, target.data)

            if i % config.print_freq == config.print_freq - 1:
                vis.plot('loss', loss_meter.value()[0])
                print('loss', loss_meter.value()[0])

        # print result
        train_accuracy = 100. * sum(
            [train_cm.value()[c][c]
             for c in range(config.num_classes)]) / train_cm.value().sum()
        val_cm, val_accuracy, val_loss = val(model, val_dataloader)

        if val_accuracy > previous_acc:
            if config.parallel:
                if config.save_model_name:
                    model.save(
                        os.path.join('checkpoints', model.module.model_name,
                                     config.save_model_name))
                else:
                    model.save(
                        os.path.join(
                            'checkpoints', model.module.model_name,
                            model.module.model_name + '_best_model.pth'))
            else:
                if config.save_model_name:
                    model.save(
                        os.path.join('checkpoints', model.model_name,
                                     config.save_model_name))
                else:
                    model.save(
                        os.path.join('checkpoints', model.model_name,
                                     model.model_name + '_best_model.pth'))
            previous_acc = val_accuracy

        vis.plot_many({
            'train_accuracy': train_accuracy,
            'val_accuracy': val_accuracy
        })
        vis.log(
            "epoch: [{epoch}/{total_epoch}], lr: {lr}, loss: {loss}".format(
                epoch=epoch + 1,
                total_epoch=config.max_epoch,
                lr=lr,
                loss=loss_meter.value()[0]))
        vis.log('train_cm:')
        vis.log(train_cm.value())
        vis.log('val_cm')
        vis.log(val_cm.value())
        print('train_accuracy:', train_accuracy, 'val_accuracy:', val_accuracy)
        print("epoch: [{epoch}/{total_epoch}], lr: {lr}, loss: {loss}".format(
            epoch=epoch + 1,
            total_epoch=config.max_epoch,
            lr=lr,
            loss=loss_meter.value()[0]))
        print('train_cm:')
        print(train_cm.value())
        print('val_cm:')
        print(val_cm.value())

        # update learning rate
        if loss_meter.value()[0] > previous_loss:
            lr = lr * config.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        previous_loss = loss_meter.value()[0]
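
A side note on the decay rule above: a roughly equivalent effect could be obtained with PyTorch's built-in scheduler. This is a hedged alternative sketch, not what the code above actually does:

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=config.lr_decay, patience=0)
# then, once per epoch, instead of the manual comparison with previous_loss:
scheduler.step(loss_meter.value()[0])  # lowers the lr when the epoch loss stops improving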
Example #3
def train(**kwargs):
    """
    训练
    """
    # 根据命令行参数更新配置
    opt._parse(kwargs)
    vis = Visualizer(opt.env, port=opt.vis_port)

    # step1: configure model
    model = getattr(models, opt.model)()  # do not forget the trailing ()
    if opt.load_model_path:
        model.load(opt.load_model_path)
    model.to(opt.device)  # this line differs from the version in the book

    # step2: data
    train_dataset = BatteryCap(opt.train_data_root, train=True)  # training set
    train_dataloader = DataLoader(train_dataset,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers)

    val_dataset = BatteryCap(opt.train_data_root, train=False)  # validation set
    val_dataloader = DataLoader(val_dataset,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers)

    # step3: criterion and optimizer
    criterion = t.nn.CrossEntropyLoss()
    lr = opt.lr
    optimizer = model.get_optimizer(lr, opt.weight_decay)
    # step4: meters: the smoothed loss and a confusion matrix
    loss_meter = meter.AverageValueMeter()
    confusion_matrix = meter.ConfusionMeter(2)
    previous_loss = 1e10

    # train
    for epoch in range(opt.max_epoch):

        loss_meter.reset()
        confusion_matrix.reset()

        for ii, (data, label) in tqdm(enumerate(train_dataloader)):

            # train the model parameters
            input_batch = data.to(opt.device)
            label_batch = label.to(opt.device)

            optimizer.zero_grad()  # zero the gradients
            score = model(input_batch)
            print("网络输出的:", score)
            print("-------------------------------")
            print("label", label)
            print("-------------------------------")
            print("softmax后:", t.nn.functional.softmax(score.detach(), dim=1).detach().tolist())
            print("-------------------------------")
            loss = criterion(score, label_batch)
            loss.backward()  # backpropagate
            optimizer.step()  # update the parameters

            # update meters and visualize
            loss_meter.add(loss.item())

            # detaching here is the safer choice
            confusion_matrix.add(score.detach(), label_batch.detach())

            if (ii + 1) % opt.print_freq == 0:
                vis.plot('loss', loss_meter.value()[0])  # visualization could be skipped for now
                print('   loss: ', loss_meter.value()[0])

                # drop into debug mode if needed
                if os.path.exists(opt.debug_file):
                    import ipdb
                    ipdb.set_trace()
        model.save()

        # validate and visualize: compute metrics on the validation set
        val_cm, val_accuracy = val(model, val_dataloader)
        vis.plot('val_accuracy', val_accuracy)
        vis.log("epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm:{val_cm}".format(
            epoch=epoch, loss=loss_meter.value()[0], val_cm=str(val_cm.value()),
            train_cm=str(confusion_matrix.value()), lr=lr))

        # update the learning rate: decay it if the loss stops decreasing
        if loss_meter.value()[0] > previous_loss:
            lr = lr * opt.lr_decay
            # second way of lowering the learning rate: optimizer state such as the moment estimates is preserved
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        previous_loss = loss_meter.value()[0]

        print('Epoch', str(epoch), 'finished')
        print("Validation accuracy: ", str(val_accuracy))
        print('---' * 50)
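
On the "second way" comment above: rebuilding the optimizer at the lower rate would reset Adam's moment estimates, whereas mutating param_group['lr'] in place keeps them. A sketch of the two options (using model.get_optimizer as above; not part of the original code):

# option 1: rebuild the optimizer (moment estimates are lost)
# optimizer = model.get_optimizer(lr * opt.lr_decay, opt.weight_decay)
# option 2: update in place, as done above (optimizer state is kept)
# for param_group in optimizer.param_groups:
#     param_group['lr'] = lr * opt.lr_decay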
Example #4
def train(**kwargs):
    config.parse(kwargs)
    vis = Visualizer(port=2333, env=config.env)

    train_roots = [
        os.path.join(config.data_root, 'Features_Normal'),
        os.path.join(config.data_root, 'Features_Horizontal'),
        os.path.join(config.data_root, 'Features_Vertical'),
        os.path.join(config.data_root, 'Features_Horizontal_Vertical')
    ]
    val_roots = [os.path.join(config.data_root, 'Features')]

    train_data = Feature_Dataset(train_roots,
                                 config.train_paths,
                                 phase='train',
                                 balance=config.data_balance)
    val_data = Feature_Dataset(val_roots,
                               config.test_paths,
                               phase='val',
                               balance=config.data_balance)
    print('Training Feature Lists:', train_data.__len__(),
          'Validation Feature Lists:', val_data.__len__())

    train_dataloader = DataLoader(train_data,
                                  batch_size=1,
                                  shuffle=True,
                                  num_workers=config.num_workers)
    val_dataloader = DataLoader(val_data,
                                batch_size=1,
                                shuffle=False,
                                num_workers=config.num_workers)

    # prepare model
    model = BiLSTM_CRF(tag_to_ix=tag_to_ix,
                       embedding_dim=EMBEDDING_DIM,
                       hidden_dim=HIDDEN_DIM,
                       num_layers=NUM_LAYERS)

    if config.load_model_path:
        model.load(config.load_model_path)
    if config.use_gpu:
        model.cuda()

    model.train()

    # criterion and optimizer
    lr = config.lr
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    # metric
    loss_meter = meter.AverageValueMeter()
    previous_loss = 100000
    previous_acc = 0

    # train
    if not os.path.exists(os.path.join('checkpoints', model.model_name)):
        os.mkdir(os.path.join('checkpoints', model.model_name))

    for epoch in range(config.max_epoch):
        loss_meter.reset()
        train_cm = [[0] * 3, [0] * 3, [0] * 3]
        count = 0

        # train
        for i, (features, labels,
                feature_paths) in tqdm(enumerate(train_dataloader)):
            # prepare input
            target = torch.LongTensor([tag_to_ix[t[0]] for t in labels])

            feat = Variable(features.squeeze())
            # target = Variable(target)
            if config.use_gpu:
                feat = feat.cuda()
                # target = target.cuda()

            model.zero_grad()

            try:
                neg_log_likelihood = model.neg_log_likelihood(feat, target)
            except NameError:
                count += 1
                continue

            neg_log_likelihood.backward()
            optimizer.step()

            loss_meter.add(neg_log_likelihood.data[0])
            result = model(feat)
            for t, r in zip(target, result[1]):
                train_cm[t][r] += 1

            if i % config.print_freq == config.print_freq - 1:
                vis.plot('loss', loss_meter.value()[0])
                print('loss', loss_meter.value()[0])

        train_accuracy = 100. * sum(
            [train_cm[c][c]
             for c in range(config.num_classes)]) / np.sum(train_cm)
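        # Worked example of the formula above (illustrative counts): with
        # train_cm = [[50, 3, 2], [4, 40, 6], [1, 5, 39]] the diagonal sums to 129
        # out of 150 predictions, i.e. train_accuracy = 86.0.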
        val_cm, val_accuracy, val_loss = val(model, val_dataloader)

        if val_accuracy > previous_acc:
            if config.save_model_name:
                model.save(
                    os.path.join('checkpoints', model.model_name,
                                 config.save_model_name))
            else:
                model.save(
                    os.path.join('checkpoints', model.model_name,
                                 model.model_name + '_best_model.pth'))
            previous_acc = val_accuracy

        vis.plot_many({
            'train_accuracy': train_accuracy,
            'val_accuracy': val_accuracy
        })
        vis.log(
            "epoch: [{epoch}/{total_epoch}], lr: {lr}, loss: {loss}".format(
                epoch=epoch + 1,
                total_epoch=config.max_epoch,
                lr=lr,
                loss=loss_meter.value()[0]))
        vis.log('train_cm:')
        vis.log(train_cm)
        vis.log('val_cm')
        vis.log(val_cm)
        print('train_accuracy:', train_accuracy, 'val_accuracy:', val_accuracy)
        print("epoch: [{epoch}/{total_epoch}], lr: {lr}, loss: {loss}".format(
            epoch=epoch + 1,
            total_epoch=config.max_epoch,
            lr=lr,
            loss=loss_meter.value()[0]))
        print('train_cm:')
        print(train_cm)
        print('val_cm:')
        print(val_cm)
        print('Num of NameError:', count)

        # update learning rate
        if loss_meter.value()[0] > previous_loss:
            lr = lr * config.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        previous_loss = loss_meter.value()[0]
Example #5
def train(**kwargs):
    config.parse(kwargs)

    # ============================================ Visualization =============================================
    vis = Visualizer(port=2333, env=config.env)
    vis.log('Use config:')
    for k, v in config.__class__.__dict__.items():
        if not k.startswith('__'):
            vis.log(f"{k}: {getattr(config, k)}")

    # ============================================= Prepare Data =============================================
    train_data = SlideWindowDataset(config.train_paths,
                                    phase='train',
                                    useRGB=config.useRGB,
                                    usetrans=config.usetrans,
                                    balance=config.data_balance)
    val_data = SlideWindowDataset(config.test_paths,
                                  phase='val',
                                  useRGB=config.useRGB,
                                  usetrans=config.usetrans,
                                  balance=False)
    print('Training Images:', train_data.__len__(), 'Validation Images:',
          val_data.__len__())
    dist = train_data.dist()
    print('Train Data Distribution:', dist)

    train_dataloader = DataLoader(train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=config.num_workers)
    val_dataloader = DataLoader(val_data,
                                batch_size=config.batch_size,
                                shuffle=False,
                                num_workers=config.num_workers)

    # ============================================= Prepare Model ============================================
    # model = AlexNet(num_classes=config.num_classes)
    # model = Vgg16(num_classes=config.num_classes)
    # model = Modified_Vgg16(num_classes=config.num_classes)
    # model = ResNet18(num_classes=config.num_classes)
    model = ResNet50(num_classes=config.num_classes)
    # model = DenseNet121(num_classes=config.num_classes)
    # model = ShallowNet(num_classes=config.num_classes)
    # model = Customed_ShallowNet(num_classes=config.num_classes)

    # model = Modified_AGVgg16(num_classes=2)
    # model = AGResNet18(num_classes=2)
    print(model)

    if config.load_model_path:
        model.load(config.load_model_path)
    if config.use_gpu:
        model.cuda()
    if config.parallel:
        model = torch.nn.DataParallel(
            model, device_ids=[x for x in range(config.num_of_gpu)])

    # =========================================== Criterion and Optimizer =====================================
    # weight = torch.Tensor([1, 1])
    # weight = torch.Tensor([dist['1']/(dist['0']+dist['1']), dist['0']/(dist['0']+dist['1'])])  # the two entries are deliberately swapped; for more than two classes the reciprocals can be used
    # weight = torch.Tensor([1, 3.5])
    # weight = torch.Tensor([1, 5])
    weight = torch.Tensor([1, 7])

    vis.log(f'loss weight: {weight}')
    print('loss weight:', weight)
    weight = weight.cuda()
    criterion = torch.nn.CrossEntropyLoss(weight=weight)
    lr = config.lr
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=config.weight_decay)

    # ================================================== Metrics ===============================================
    softmax = functional.softmax
    loss_meter = meter.AverageValueMeter()
    epoch_loss = meter.AverageValueMeter()
    train_cm = meter.ConfusionMeter(config.num_classes)

    # ====================================== Saving and Recording Configuration =================================
    previous_auc = 0
    if config.parallel:
        save_model_dir = config.save_model_dir if config.save_model_dir else model.module.model_name
        save_model_name = config.save_model_name if config.save_model_name else model.module.model_name + '_best_model.pth'
    else:
        save_model_dir = config.save_model_dir if config.save_model_dir else model.model_name
        save_model_name = config.save_model_name if config.save_model_name else model.model_name + '_best_model.pth'
    save_epoch = 1  # records the epoch of the model that performs best on the validation set
    process_record = {
        'epoch_loss': [],
        'train_avg_se': [],
        'train_se_0': [],
        'train_se_1': [],
        'val_avg_se': [],
        'val_se_0': [],
        'val_se_1': [],
        'AUC': []
    }  # records the training curves so they can be plotted later

    # ================================================== Training ===============================================
    for epoch in range(config.max_epoch):
        print(
            f"epoch: [{epoch+1}/{config.max_epoch}] {config.save_model_name[:-4]} =================================="
        )
        train_cm.reset()
        epoch_loss.reset()

        # ****************************************** train ****************************************
        model.train()
        for i, (image, label, image_path) in tqdm(enumerate(train_dataloader)):
            loss_meter.reset()

            # ------------------------------------ prepare input ------------------------------------
            if config.use_gpu:
                image = image.cuda()
                label = label.cuda()

            # ---------------------------------- go through the model --------------------------------
            score = model(image)

            # ----------------------------------- backpropagate -------------------------------------
            optimizer.zero_grad()
            loss = criterion(score, label)
            loss.backward()
            optimizer.step()

            # ------------------------------------ record loss ------------------------------------
            loss_meter.add(loss.item())
            epoch_loss.add(loss.item())
            train_cm.add(softmax(score, dim=1).detach(), label.detach())

            if (i + 1) % config.print_freq == 0:
                vis.plot('loss', loss_meter.value()[0])

        train_se = [
            100. * train_cm.value()[0][0] /
            (train_cm.value()[0][0] + train_cm.value()[0][1]),
            100. * train_cm.value()[1][1] /
            (train_cm.value()[1][0] + train_cm.value()[1][1])
        ]
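        # e.g. train_cm.value() = [[90, 10], [14, 86]] gives train_se = [90.0, 86.0]:
        # the per-class sensitivity (recall), in percent, for class 0 and class 1.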

        # *************************************** validate ***************************************
        model.eval()
        if (epoch + 1) % 1 == 0:
            Best_T, val_cm, val_spse, val_accuracy, AUC = val(
                model, val_dataloader)

            # ------------------------------------ save model ------------------------------------
            if AUC > previous_auc and epoch + 1 > 5:
                if config.parallel:
                    if not os.path.exists(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name.split('.')[0])):
                        os.makedirs(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name.split('.')[0]))
                    model.module.save(
                        os.path.join('checkpoints', save_model_dir,
                                     save_model_name.split('.')[0],
                                     save_model_name))
                else:
                    if not os.path.exists(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name.split('.')[0])):
                        os.makedirs(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name.split('.')[0]))
                    model.save(
                        os.path.join('checkpoints', save_model_dir,
                                     save_model_name.split('.')[0],
                                     save_model_name))
                previous_auc = AUC
                save_epoch = epoch + 1

            # ---------------------------------- record and print ---------------------------------
            process_record['epoch_loss'].append(epoch_loss.value()[0])
            process_record['train_avg_se'].append(np.average(train_se))
            process_record['train_se_0'].append(train_se[0])
            process_record['train_se_1'].append(train_se[1])
            process_record['val_avg_se'].append(np.average(val_spse))
            process_record['val_se_0'].append(val_spse[0])
            process_record['val_se_1'].append(val_spse[1])
            process_record['AUC'].append(AUC)

            vis.plot_many({
                'epoch_loss': epoch_loss.value()[0],
                'train_avg_se': np.average(train_se),
                'train_se_0': train_se[0],
                'train_se_1': train_se[1],
                'val_avg_se': np.average(val_spse),
                'val_se_0': val_spse[0],
                'val_se_1': val_spse[1],
                'AUC': AUC
            })
            vis.log(
                f"epoch: [{epoch+1}/{config.max_epoch}] ========================================="
            )
            vis.log(
                f"lr: {optimizer.param_groups[0]['lr']}, loss: {round(loss_meter.value()[0], 5)}"
            )
            vis.log(
                f"train_avg_se: {round(np.average(train_se), 4)}, train_se_0: {round(train_se[0], 4)}, train_se_1: {round(train_se[1], 4)}"
            )
            vis.log(
                f"val_avg_se: {round(sum(val_spse)/len(val_spse), 4)}, val_se_0: {round(val_spse[0], 4)}, val_se_1: {round(val_spse[1], 4)}"
            )
            vis.log(f"AUC: {AUC}")
            vis.log(f'train_cm: {train_cm.value()}')
            vis.log(f'Best Threshold: {Best_T}')
            vis.log(f'val_cm: {val_cm}')
            print("lr:", optimizer.param_groups[0]['lr'], "loss:",
                  round(epoch_loss.value()[0], 5))
            print('train_avg_se:', round(np.average(train_se), 4),
                  'train_se_0:', round(train_se[0], 4), 'train_se_1:',
                  round(train_se[1], 4))
            print('val_avg_se:', round(np.average(val_spse), 4), 'val_se_0:',
                  round(val_spse[0], 4), 'val_se_1:', round(val_spse[1], 4))
            print('AUC:', AUC)
            print('train_cm:')
            print(train_cm.value())
            print('Best Threshold:', Best_T, 'val_cm:')
            print(val_cm)

            # ------------------------------------ save record ------------------------------------
            if os.path.exists(
                    os.path.join('checkpoints', save_model_dir,
                                 save_model_name.split('.')[0])):
                write_json(file=os.path.join('checkpoints', save_model_dir,
                                             save_model_name.split('.')[0],
                                             'process_record.json'),
                           content=process_record)

        # if (epoch+1) % 5 == 0:
        #     lr = lr * config.lr_decay
        #     for param_group in optimizer.param_groups:
        #         param_group['lr'] = lr

    vis.log(f"Best Epoch: {save_epoch}")
    print("Best Epoch:", save_epoch)
Example #6
def train(**kwargs):
    # step: configure
    opt._parse(**kwargs)
    device = t.device('cuda') if opt.use_gpu else t.device('cpu')
    if opt.env:
        vis = Visualizer(env=opt.env)
    # step: data
    data, word2ix, ix2word = get_data(opt)  # data: 2-D numpy array; word2ix / ix2word: dicts
    # t.from_numpy shares memory with the numpy array (a change to one is visible in the other),
    # whereas t.tensor would make an independent copy
    data = t.from_numpy(data)
    # passing a plain tensor straight to DataLoader works here thanks to duck typing
    dataloader = DataLoader(data,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers)

    # step: model && criterion && meter && optimizer

    model = PoetryModel(len(word2ix), opt.embedding_dim, opt.hidden_dim,
                        opt.num_layers)
    if opt.model_path:
        model.load_state_dict(t.load(opt.model_path))
    model.to(device)

    optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)

    criterion = nn.CrossEntropyLoss()

    loss_meter = meter.AverageValueMeter()

    # step: train
    for epoch in range(opt.epoch):
        loss_meter.reset()
        for ii, x in tqdm(enumerate(dataloader)):
            # the embedding layer requires LongTensor input
            # x is currently (batch_size, seq_len); the LSTM expects (seq_len, batch_size, embedding_dim)
            # transposing makes the storage non-contiguous, so .contiguous() is needed
            x = x.long().transpose(1, 0).contiguous()
            x = x.to(device)
            optimizer.zero_grad()
            input, target = x[:-1, :], x[1:, :]  # target: (seq_len, batch_size)
            # worth checking the sizes here at run time
            output, _ = model(input)  # output: (seq_len*batch_size, vocab_size)
            loss = criterion(output, target.view(-1))  # cross-entropy loss
            # worth revisiting exactly how the LSTM's inputs and outputs line up
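            # Shape walk-through (sizes illustrative): after the transpose x is (seq_len, batch_size),
            # e.g. (125, 128); input = x[:-1] and target = x[1:] are both (124, 128), so each position
            # is trained to predict the next token; output is (124*128, vocab_size), which matches
            # target.view(-1) in the cross-entropy above.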
            loss.backward()
            optimizer.step()
            # loss is a single-element tensor; loss.item() returns a plain Python number that
            # shares no memory (and no autograd history) with it.
            # When accumulating the loss for logging, the computation graph has to be dropped,
            # otherwise each iteration's graph is kept alive and memory keeps growing.
            # loss.data is still a tensor and is not a safe substitute here;
            # for accumulating losses, loss.item() cuts the graph completely.
            loss_meter.add(loss.item())
            # step: visualize and validate
            if (ii + 1) % opt.print_freq == 0 and opt.env:
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                vis.plot('loss', loss_meter.value()[0])
                # original poems
                # x tensor size (seq_len, batch_size)
                # nested list comprehension; poetries: [['我', '你'], [...]]

                poetries = [[ix2word[word_] for word_ in x[:, j_]]
                            for j_ in range(x.shape[1])]
                # origin_poetries =[]
                # origin_poetries_tmp = []
                # #  range(data_.shape[1]
                # for j_ in range(3):
                #     for word_ in x[:,j_].tolist():
                #         origin_poetries_tmp.append(ix2word[word_])
                #     origin_poetries.append(origin_poetries_tmp)
                #     origin_poetries_tmp = []
                vis.log('<br/>'.join(
                    [''.join(poetry) for poetry in poetries]),
                        win=u'origin_poem')

                # generated poems
                gen_poetris = []
                # use each of these characters as the first character of a poem, generating 8 poems to sanity-check the model
                # gen_poetris is a nested list; each inner list is one poem, e.g. [['我', '你'], [...]]
                for word in list(u'春江花月夜凉如水'):
                    gen_poetry = generate(model, word, ix2word, word2ix)
                    gen_poetris.append(gen_poetry)
                # gen_poetris is a nested list with the same structure as poetries
                vis.log('<br/>'.join(
                    [''.join(gen_poe) for gen_poe in gen_poetris]),
                        win=u'gen_poem')

        t.save(model.state_dict(),
               '{0}_{1}.pth'.format(opt.model_prefix, epoch))
Example #7
def train(**kwargs):
    """
    training
    :param kwargs:
    :return:
    """
    opt.parse(kwargs)
    vis = Visualizer(opt.env)
    # step 1: model
    model = getattr(models, opt.model)()

    # ipdb.set_trace()
    # if opt.load_model_path:
    #     model.load(opt.load_model_path)
    if opt.use_gpu:
        model.cuda()

    # step 2: data
    train_data = DogCat(opt.train_data_root, train=True)
    val_data = DogCat(opt.train_data_root, train=False)
    train_data_loader = DataLoader(train_data, opt.batch_size, shuffle=True,
                                   num_workers=opt.num_workers)
    val_data_loader = DataLoader(val_data, opt.batch_size, shuffle=False,
                                 num_workers=opt.num_workers)

    # step 3: loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    learning_rate = opt.learning_rate
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=opt.weight_decay)

    # step 4: statistical indicators
    loss_meter = meter.AverageValueMeter()
    confusion_matrix = meter.ConfusionMeter(2)
    previous_loss = 1e100

    for epoch in range(opt.max_epoch):
        loss_meter.reset()
        confusion_matrix.reset()

        for ii, (data, label) in enumerate(train_data_loader):
            print("epoch: {epoch}, batch: {batch}".format(epoch=epoch, batch=ii))
            # training model parameters
            input_ = Variable(data)
            target = Variable(label)
            if opt.use_gpu:
                input_ = input_.cuda()
                target = target.cuda()
            optimizer.zero_grad()
            score = model(input_)
            loss = criterion(score, target)
            loss.backward()
            optimizer.step()  # update the parameters

            # update statistical indicators and visualization
            loss_meter.add(loss.item())
            confusion_matrix.add(score.data, target.data)

            if ii % opt.print_freq == opt.print_freq - 1:
                vis.plot('loss', loss_meter.value()[0])
                # if necessary, step into debug mode
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

        model.save()

        # calculate statistical indicators in the validation set and visualization
        val_cm, val_accuracy = val(model, val_data_loader)
        vis.plot("val_accuracy", val_accuracy)
        vis.log("epoch: {epoch}, learning_rate: {learning_rate}, loss: {loss}, train_cm: {train_cm}, val_cm: {val_cm}".format(
            epoch=epoch, learning_rate=learning_rate, loss=loss_meter.value()[0],
            train_cm=str(confusion_matrix.value()), val_cm=str(val_cm.value())))

        if loss_meter.value()[0] > previous_loss:
            learning_rate = learning_rate * opt.learning_rate_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate

        previous_loss = loss_meter.value()[0]
Example #8
def train(**kwargs):
    config.parse(kwargs)

    # ============================================ Visualization =============================================
    vis = Visualizer(port=2333, env=config.env)
    vis.log('Use config:')
    for k, v in config.__class__.__dict__.items():
        if not k.startswith('__'):
            vis.log(f"{k}: {getattr(config, k)}")

    # ============================================= Prepare Data =============================================
    train_data_1 = SlideWindowDataset(config.train_paths, phase='train', useRGB=config.useRGB, usetrans=config.usetrans, balance=config.data_balance)
    train_data_2 = SlideWindowDataset(config.train_paths, phase='train', useRGB=config.useRGB, usetrans=config.usetrans, balance=config.data_balance)
    val_data = SlideWindowDataset(config.test_paths, phase='val', useRGB=config.useRGB, usetrans=config.usetrans, balance=False)
    print('Training Images:', train_data_1.__len__(), 'Validation Images:', val_data.__len__())
    dist = train_data_1.dist()
    print('Train Data Distribution:', dist)

    train_dataloader_1 = DataLoader(train_data_1, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)
    train_dataloader_2 = DataLoader(train_data_2, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)
    val_dataloader = DataLoader(val_data, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers)

    # ============================================= Prepare Model ============================================
    # model = PCResNet18(num_classes=config.num_classes)
    model = DualResNet18(num_classes=config.num_classes)
    print(model)

    if config.load_model_path:
        model.load(config.load_model_path)
    if config.use_gpu:
        model.cuda()
    if config.parallel:
        model = torch.nn.DataParallel(model, device_ids=[x for x in range(config.num_of_gpu)])

    # =========================================== Criterion and Optimizer =====================================
    # weight = torch.Tensor([1, 1])
    # weight = torch.Tensor([dist['1']/(dist['0']+dist['1']), dist['0']/(dist['0']+dist['1'])])  # the two entries are deliberately swapped; for more than two classes the reciprocals can be used
    # weight = torch.Tensor([1, 3.5])
    # weight = torch.Tensor([1, 5])
    weight = torch.Tensor([1, 7])
    vis.log(f'loss weight: {weight}')
    print('loss weight:', weight)
    weight = weight.cuda()

    criterion = torch.nn.CrossEntropyLoss(weight=weight)
    MSELoss = torch.nn.MSELoss()
    sycriterion = torch.nn.CrossEntropyLoss()

    lr = config.lr
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.weight_decay)

    # ================================================== Metrics ===============================================
    softmax = functional.softmax
    loss_meter = meter.AverageValueMeter()
    epoch_loss = meter.AverageValueMeter()
    mse_meter = meter.AverageValueMeter()
    epoch_mse = meter.AverageValueMeter()
    syloss_meter = meter.AverageValueMeter()
    epoch_syloss = meter.AverageValueMeter()
    total_loss_meter = meter.AverageValueMeter()
    epoch_total_loss = meter.AverageValueMeter()
    train_cm = meter.ConfusionMeter(config.num_classes)

    # ====================================== Saving and Recording Configuration =================================
    previous_auc = 0
    if config.parallel:
        save_model_dir = config.save_model_dir if config.save_model_dir else model.module.model_name
        save_model_name = config.save_model_name if config.save_model_name else model.module.model_name + '_best_model.pth'
    else:
        save_model_dir = config.save_model_dir if config.save_model_dir else model.model_name
        save_model_name = config.save_model_name if config.save_model_name else model.model_name + '_best_model.pth'
    save_epoch = 1  # records the epoch of the model that performs best on the validation set
    process_record = {'epoch_loss': [],
                      'train_avg_se': [], 'train_se_0': [], 'train_se_1': [],
                      'val_avg_se': [], 'val_se_0': [], 'val_se_1': [],
                      'AUC': []}  # records the training curves so they can be plotted later

    # ================================================== Training ===============================================
    for epoch in range(config.max_epoch):
        print(f"epoch: [{epoch+1}/{config.max_epoch}] {config.save_model_name[:-4]} ==================================")
        train_cm.reset()
        epoch_loss.reset()
        epoch_mse.reset()
        epoch_syloss.reset()
        epoch_total_loss.reset()

        # ****************************************** train ****************************************
        model.train()
        for i, (item1, item2) in tqdm(enumerate(zip(train_dataloader_1, train_dataloader_2))):
            loss_meter.reset()
            mse_meter.reset()
            syloss_meter.reset()
            total_loss_meter.reset()

            # ------------------------------------ prepare input ------------------------------------
            image1, label1, image_path1 = item1
            image2, label2, image_path2 = item2
            if config.use_gpu:
                image1 = image1.cuda()
                image2 = image2.cuda()
                label1 = label1.cuda()
                label2 = label2.cuda()

            # ---------------------------------- go through the model --------------------------------
            # score1, score2, logits1, logits2 = model(image1, image2)  # Pairwise Confusion Network
            score1, score2, score3 = model(image1, image2)  # Dual CNN

            # ----------------------------------- backpropagate -------------------------------------
            # add an L2 (MSE) term between the two branches' features
            # optimizer.zero_grad()
            # cls_loss1 = criterion(score1, label1)
            # cls_loss2 = criterion(score2, label2)
            #
            # ch_weight = torch.where(label1 == label2, torch.Tensor([0]).cuda(), torch.Tensor([1]).cuda())
            # ch_weight = ch_weight.view(logits1.size(0), -1)
            # mse = MSELoss(logits1 * ch_weight, logits2 * ch_weight)  # only pairs from different classes contribute; same-class pairs are zeroed out
            #
            # total_loss = cls_loss1 + cls_loss2 + 10 * mse
            # total_loss.backward()
            # optimizer.step()

            # add a loss on the two branches' logits that predicts whether the pair belongs to the same class
            optimizer.zero_grad()
            cls_loss1 = criterion(score1, label1)
            cls_loss2 = criterion(score2, label2)

            sylabel = torch.where(label1 == label2, torch.Tensor([0]).cuda(), torch.Tensor([1]).cuda()).long()
            sy_loss = sycriterion(score3, sylabel)
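            # e.g. label1 = [0, 1, 1] and label2 = [0, 0, 1] give sylabel = [0, 1, 0]:
            # 1 marks a pair drawn from different classes, so sy_loss trains score3 to
            # predict whether the two branches received the same class.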

            total_loss = cls_loss1 + cls_loss2 + 2 * sy_loss
            total_loss.backward()
            optimizer.step()

            # ------------------------------------ record loss ------------------------------------
            loss_meter.add((cls_loss1 + cls_loss2).item())
            # mse_meter.add(mse.item())
            # syloss_meter.add(sy_loss.item())
            # total_loss_meter.add(total_loss.item())

            epoch_loss.add((cls_loss1 + cls_loss2).item())
            # epoch_mse.add(mse.item())
            epoch_syloss.add(sy_loss.item())
            epoch_total_loss.add(total_loss.item())

            train_cm.add(softmax(score1, dim=1).detach(), label1.detach())

            if (i+1) % config.print_freq == 0:
                vis.plot('loss', loss_meter.value()[0])

        train_se = [100. * train_cm.value()[0][0] / (train_cm.value()[0][0] + train_cm.value()[0][1]),
                    100. * train_cm.value()[1][1] / (train_cm.value()[1][0] + train_cm.value()[1][1])]

        # *************************************** validate ***************************************
        model.eval()
        if (epoch + 1) % 1 == 0:
            Best_T, val_cm, val_spse, val_accuracy, AUC = val(model, val_dataloader)

            # ------------------------------------ save model ------------------------------------
            if AUC > previous_auc and epoch + 1 > 5:
                if config.parallel:
                    if not os.path.exists(os.path.join('checkpoints', save_model_dir, save_model_name.split('.')[0])):
                        os.makedirs(os.path.join('checkpoints', save_model_dir, save_model_name.split('.')[0]))
                    model.module.save(os.path.join('checkpoints', save_model_dir, save_model_name.split('.')[0], save_model_name))
                else:
                    if not os.path.exists(os.path.join('checkpoints', save_model_dir, save_model_name.split('.')[0])):
                        os.makedirs(os.path.join('checkpoints', save_model_dir, save_model_name.split('.')[0]))
                    model.save(os.path.join('checkpoints', save_model_dir, save_model_name.split('.')[0], save_model_name))
                previous_auc = AUC
                save_epoch = epoch + 1

            # ---------------------------------- record and print ---------------------------------
            process_record['epoch_loss'].append(epoch_loss.value()[0])
            process_record['train_avg_se'].append(np.average(train_se))
            process_record['train_se_0'].append(train_se[0])
            process_record['train_se_1'].append(train_se[1])
            process_record['val_avg_se'].append(np.average(val_spse))
            process_record['val_se_0'].append(val_spse[0])
            process_record['val_se_1'].append(val_spse[1])
            process_record['AUC'].append(AUC)

            # vis.plot('epoch_mse', epoch_mse.value()[0])
            vis.plot('epoch_syloss', epoch_syloss.value()[0])
            vis.plot_many({'epoch_loss': epoch_loss.value()[0], 'epoch_total_loss': epoch_total_loss.value()[0],
                           'train_avg_se': np.average(train_se), 'train_se_0': train_se[0], 'train_se_1': train_se[1],
                           'val_avg_se': np.average(val_spse), 'val_se_0': val_spse[0], 'val_se_1': val_spse[1],
                           'AUC': AUC})
            vis.log(f"epoch: [{epoch+1}/{config.max_epoch}] =========================================")
            vis.log(f"lr: {optimizer.param_groups[0]['lr']}, loss: {round(loss_meter.value()[0], 5)}")
            vis.log(f"train_avg_se: {round(np.average(train_se), 4)}, train_se_0: {round(train_se[0], 4)}, train_se_1: {round(train_se[1], 4)}")
            vis.log(f"val_avg_se: {round(sum(val_spse)/len(val_spse), 4)}, val_se_0: {round(val_spse[0], 4)}, val_se_1: {round(val_spse[1], 4)}")
            vis.log(f"AUC: {AUC}")
            vis.log(f'train_cm: {train_cm.value()}')
            vis.log(f'Best Threshold: {Best_T}')
            vis.log(f'val_cm: {val_cm}')
            print("lr:", optimizer.param_groups[0]['lr'], "loss:", round(epoch_loss.value()[0], 5))
            print('train_avg_se:', round(np.average(train_se), 4), 'train_se_0:', round(train_se[0], 4), 'train_se_1:', round(train_se[1], 4))
            print('val_avg_se:', round(np.average(val_spse), 4), 'val_se_0:', round(val_spse[0], 4), 'val_se_1:', round(val_spse[1], 4))
            print('AUC:', AUC)
            print('train_cm:')
            print(train_cm.value())
            print('Best Threshold:', Best_T, 'val_cm:')
            print(val_cm)

            # ------------------------------------ save record ------------------------------------
            if os.path.exists(os.path.join('checkpoints', save_model_dir, save_model_name.split('.')[0])):
                write_json(file=os.path.join('checkpoints', save_model_dir, save_model_name.split('.')[0], 'process_record.json'),
                           content=process_record)

        # if (epoch+1) % 5 == 0:
        #     lr = lr * config.lr_decay
        #     for param_group in optimizer.param_groups:
        #         param_group['lr'] = lr

    vis.log(f"Best Epoch: {save_epoch}")
    print("Best Epoch:", save_epoch)
Example #9
def train(**kwargs):
    config.parse(kwargs)
    vis = Visualizer(port=2333, env=config.env)
    vis.log('Use config:')
    for k, v in config.__class__.__dict__.items():
        if not k.startswith('__'):
            vis.log(f"{k}: {getattr(config, k)}")

    # prepare data
    train_data = VB_Dataset(config.train_paths,
                            phase='train',
                            useRGB=config.useRGB,
                            usetrans=config.usetrans,
                            padding=config.padding,
                            balance=config.data_balance)
    val_data = VB_Dataset(config.test_paths,
                          phase='val',
                          useRGB=config.useRGB,
                          usetrans=config.usetrans,
                          padding=config.padding,
                          balance=False)
    print('Training Images:', train_data.__len__(), 'Validation Images:',
          val_data.__len__())
    dist = train_data.dist()
    print('Train Data Distribution:', dist, 'Val Data Distribution:',
          val_data.dist())

    train_dataloader = DataLoader(train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=config.num_workers)
    val_dataloader = DataLoader(val_data,
                                batch_size=config.batch_size,
                                shuffle=False,
                                num_workers=config.num_workers)

    # prepare model
    # model = ResNet18(num_classes=config.num_classes)
    # model = Vgg16(num_classes=config.num_classes)
    # model = densenet_collapse(num_classes=config.num_classes)
    model = ShallowVgg(num_classes=config.num_classes)
    print(model)

    if config.load_model_path:
        model.load(config.load_model_path)
    if config.use_gpu:
        model.cuda()
    if config.parallel:
        model = torch.nn.DataParallel(
            model, device_ids=[x for x in range(config.num_of_gpu)])

    # criterion and optimizer
    # weight = torch.Tensor([1/dist['0'], 1/dist['1'], 1/dist['2'], 1/dist['3']])
    # weight = torch.Tensor([1/dist['0'], 1/dist['1']])
    # weight = torch.Tensor([dist['1'], dist['0']])
    # weight = torch.Tensor([1, 10])
    # vis.log(f'loss weight: {weight}')
    # print('loss weight:', weight)
    # weight = weight.cuda()

    # criterion = torch.nn.CrossEntropyLoss()
    criterion = LabelSmoothing(size=config.num_classes, smoothing=0.1)
    # criterion = torch.nn.CrossEntropyLoss(weight=weight)
    # criterion = FocalLoss(gamma=4, alpha=None)

    lr = config.lr
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=config.weight_decay)

    # metric
    softmax = functional.softmax
    log_softmax = functional.log_softmax
    loss_meter = meter.AverageValueMeter()
    epoch_loss = meter.AverageValueMeter()
    train_cm = meter.ConfusionMeter(config.num_classes)
    train_AUC = meter.AUCMeter()

    previous_avgse = 0
    # previous_AUC = 0
    if config.parallel:
        save_model_dir = config.save_model_dir if config.save_model_dir else model.module.model_name
        save_model_name = config.save_model_name if config.save_model_name else model.module.model_name + '_best_model.pth'
    else:
        save_model_dir = config.save_model_dir if config.save_model_dir else model.model_name
        save_model_name = config.save_model_name if config.save_model_name else model.model_name + '_best_model.pth'
    save_epoch = 1  # records the epoch of the model that performs best on the validation set
    # process_record = {'epoch_loss': [],
    #                   'train_avgse': [], 'train_se0': [], 'train_se1': [], 'train_se2': [], 'train_se3': [],
    #                   'val_avgse': [], 'val_se0': [], 'val_se1': [], 'val_se2': [], 'val_se3': []}
    process_record = {
        'epoch_loss': [],  # records the training curves so they can be plotted later
        'train_avgse': [],
        'train_se0': [],
        'train_se1': [],
        'val_avgse': [],
        'val_se0': [],
        'val_se1': [],
        'train_AUC': [],
        'val_AUC': []
    }

    # train
    for epoch in range(config.max_epoch):
        print(
            f"epoch: [{epoch+1}/{config.max_epoch}] {config.save_model_name[:-4]} =================================="
        )
        epoch_loss.reset()
        train_cm.reset()
        train_AUC.reset()

        # train
        model.train()
        for i, (image, label, image_path) in tqdm(enumerate(train_dataloader)):
            loss_meter.reset()

            # prepare input
            if config.use_gpu:
                image = image.cuda()
                label = label.cuda()

            # go through the model
            score = model(image)

            # backpropagate
            optimizer.zero_grad()
            # loss = criterion(score, label)
            loss = criterion(log_softmax(score, dim=1), label)
            loss.backward()
            optimizer.step()

            loss_meter.add(loss.item())
            epoch_loss.add(loss.item())
            train_cm.add(softmax(score, dim=1).data, label.data)
            positive_score = np.array([
                item[1]
                for item in softmax(score, dim=1).data.cpu().numpy().tolist()
            ])
            train_AUC.add(positive_score, label.data)
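            # e.g. softmax scores [[0.9, 0.1], [0.3, 0.7]] yield positive_score = [0.1, 0.7]:
            # the probability assigned to class 1, which AUCMeter treats as the positive class.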

            if (i + 1) % config.print_freq == 0:
                vis.plot('loss', loss_meter.value()[0])

        # print result
        # train_se = [100. * train_cm.value()[0][0] / (train_cm.value()[0][0] + train_cm.value()[0][1] + train_cm.value()[0][2] + train_cm.value()[0][3]),
        #             100. * train_cm.value()[1][1] / (train_cm.value()[1][0] + train_cm.value()[1][1] + train_cm.value()[1][2] + train_cm.value()[1][3]),
        #             100. * train_cm.value()[2][2] / (train_cm.value()[2][0] + train_cm.value()[2][1] + train_cm.value()[2][2] + train_cm.value()[2][3]),
        #             100. * train_cm.value()[3][3] / (train_cm.value()[3][0] + train_cm.value()[3][1] + train_cm.value()[3][2] + train_cm.value()[3][3])]
        train_se = [
            100. * train_cm.value()[0][0] /
            (train_cm.value()[0][0] + train_cm.value()[0][1]),
            100. * train_cm.value()[1][1] /
            (train_cm.value()[1][0] + train_cm.value()[1][1])
        ]

        # validate
        model.eval()
        if (epoch + 1) % 1 == 0:
            val_cm, val_se, val_accuracy, val_AUC = val_2class(
                model, val_dataloader)

            if np.average(val_se) > previous_avgse:  # save the model when the average sensitivity on the validation set improves
                # if val_AUC.value()[0] > previous_AUC:  # save when the validation AUC improves
                if config.parallel:
                    if not os.path.exists(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name.split('.')[0])):
                        os.makedirs(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name.split('.')[0]))
                    model.module.save(
                        os.path.join('checkpoints', save_model_dir,
                                     save_model_name.split('.')[0],
                                     save_model_name))
                else:
                    if not os.path.exists(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name.split('.')[0])):
                        os.makedirs(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name.split('.')[0]))
                    model.save(
                        os.path.join('checkpoints', save_model_dir,
                                     save_model_name.split('.')[0],
                                     save_model_name))
                previous_avgse = np.average(val_se)
                # previous_AUC = val_AUC.value()[0]
                save_epoch = epoch + 1

            process_record['epoch_loss'].append(epoch_loss.value()[0])
            process_record['train_avgse'].append(np.average(train_se))
            process_record['train_se0'].append(train_se[0])
            process_record['train_se1'].append(train_se[1])
            # process_record['train_se2'].append(train_se[2])
            # process_record['train_se3'].append(train_se[3])
            process_record['train_AUC'].append(train_AUC.value()[0])
            process_record['val_avgse'].append(np.average(val_se))
            process_record['val_se0'].append(val_se[0])
            process_record['val_se1'].append(val_se[1])
            # process_record['val_se2'].append(val_se[2])
            # process_record['val_se3'].append(val_se[3])
            process_record['val_AUC'].append(val_AUC.value()[0])

            # vis.plot_many({'epoch_loss': epoch_loss.value()[0],
            #                'train_avgse': np.average(train_se), 'train_se0': train_se[0], 'train_se1': train_se[1], 'train_se2': train_se[2], 'train_se3': train_se[3],
            #                'val_avgse': np.average(val_se), 'val_se0': val_se[0], 'val_se1': val_se[1], 'val_se2': val_se[2], 'val_se3': val_se[3]})
            # vis.log(f"epoch: [{epoch+1}/{config.max_epoch}] =========================================")
            # vis.log(f"lr: {optimizer.param_groups[0]['lr']}, loss: {round(loss_meter.value()[0], 5)}")
            # vis.log(f"train_avgse: {round(np.average(train_se), 4)}, train_se0: {round(train_se[0], 4)}, train_se1: {round(train_se[1], 4)}, train_se2: {round(train_se[2], 4)}, train_se3: {round(train_se[3], 4)},")
            # vis.log(f"val_avgse: {round(np.average(val_se), 4)}, val_se0: {round(val_se[0], 4)}, val_se1: {round(val_se[1], 4)}, val_se2: {round(val_se[2], 4)}, val_se3: {round(val_se[3], 4)}")
            # vis.log(f'train_cm: {train_cm.value()}')
            # vis.log(f'val_cm: {val_cm.value()}')
            # print("lr:", optimizer.param_groups[0]['lr'], "loss:", round(epoch_loss.value()[0], 5))
            # print('train_avgse:', round(np.average(train_se), 4), 'train_se0:', round(train_se[0], 4), 'train_se1:', round(train_se[1], 4), 'train_se2:', round(train_se[2], 4), 'train_se3:', round(train_se[3], 4))
            # print('val_avgse:', round(np.average(val_se), 4), 'val_se0:', round(val_se[0], 4), 'val_se1:', round(val_se[1], 4), 'val_se2:', round(val_se[2], 4), 'val_se3:', round(val_se[3], 4))
            # print('train_cm:')
            # print(train_cm.value())
            # print('val_cm:')
            # print(val_cm.value())

            vis.plot_many({
                'epoch_loss': epoch_loss.value()[0],
                'train_avgse': np.average(train_se),
                'train_se0': train_se[0],
                'train_se1': train_se[1],
                'val_avgse': np.average(val_se),
                'val_se0': val_se[0],
                'val_se1': val_se[1],
                'train_AUC': train_AUC.value()[0],
                'val_AUC': val_AUC.value()[0]
            })
            vis.log(
                f"epoch: [{epoch + 1}/{config.max_epoch}] ========================================="
            )
            vis.log(
                f"lr: {optimizer.param_groups[0]['lr']}, loss: {round(loss_meter.value()[0], 5)}"
            )
            vis.log(
                f"train_avgse: {round(np.average(train_se), 4)}, train_se0: {round(train_se[0], 4)}, train_se1: {round(train_se[1], 4)}"
            )
            vis.log(
                f"val_avgse: {round(np.average(val_se), 4)}, val_se0: {round(val_se[0], 4)}, val_se1: {round(val_se[1], 4)}"
            )
            vis.log(f'train_AUC: {train_AUC.value()[0]}')
            vis.log(f'val_AUC: {val_AUC.value()[0]}')
            vis.log(f'train_cm: {train_cm.value()}')
            vis.log(f'val_cm: {val_cm.value()}')
            print("lr:", optimizer.param_groups[0]['lr'], "loss:",
                  round(epoch_loss.value()[0], 5))
            print('train_avgse:', round(np.average(train_se), 4), 'train_se0:',
                  round(train_se[0], 4), 'train_se1:', round(train_se[1], 4))
            print('val_avgse:', round(np.average(val_se), 4), 'val_se0:',
                  round(val_se[0], 4), 'val_se1:', round(val_se[1], 4))
            print('train_AUC:',
                  train_AUC.value()[0], 'val_AUC:',
                  val_AUC.value()[0])
            print('train_cm:')
            print(train_cm.value())
            print('val_cm:')
            print(val_cm.value())

            if os.path.exists(
                    os.path.join('checkpoints', save_model_dir,
                                 save_model_name.split('.')[0])):
                write_json(file=os.path.join('checkpoints', save_model_dir,
                                             save_model_name.split('.')[0],
                                             'process_record.json'),
                           content=process_record)

        # if (epoch+1) % 5 == 0:
        #     lr = lr * config.lr_decay
        #     for param_group in optimizer.param_groups:
        #         param_group['lr'] = lr

    vis.log(f"Best Epoch: {save_epoch}")
    print("Best Epoch:", save_epoch)
Beispiel #10
0
def train(**kwargs):
    opt = DefaultConfig()
    opt.update(**kwargs)

    vis = Visualizer(opt['model'])
    logger = Logger()

    prefix = ''
    if opt['use_double_length']: prefix += '_2'
    print(prefix)
    if opt['use_char']:
        logger.info('Load char data starting...')
        opt['embed_num'] = opt['char_embed_num']
        embed_mat = np.load(opt['char_embed'])
        train_title = np.load(opt['train_title_char' + prefix])
        train_desc = np.load(opt['train_desc_char' + prefix])
        train_label = np.load(opt['train_label'])
        val_title = np.load(opt['val_title_char' + prefix])
        val_desc = np.load(opt['val_desc_char' + prefix])
        val_label = np.load(opt['val_label'])
        logger.info('Load char data finished!')
    elif opt['use_word']:
        logger.info('Load word data starting...')
        opt['embed_num'] = opt['word_embed_num']
        embed_mat = np.load(opt['word_embed'])
        train_title = np.load(opt['train_title_word' + prefix])
        train_desc = np.load(opt['train_desc_word' + prefix])
        train_label = np.load(opt['train_label'])
        val_title = np.load(opt['val_title_word' + prefix])
        val_desc = np.load(opt['val_desc_word' + prefix])
        val_label = np.load(opt['val_label'])
        logger.info('Load word data finished!')
    elif opt['use_char_word']:
        logger.info('Load char-word data starting...')
        embed_mat_char = np.load(opt['char_embed'])
        embed_mat_word = np.load(opt['word_embed'])
        embed_mat = np.vstack((embed_mat_char, embed_mat_word))
        train_title = np.load(opt['train_title_char' + prefix])
        train_desc = np.load(opt['train_desc_word' + prefix])
        train_label = np.load(opt['train_label'])
        val_title = np.load(opt['val_title_char' + prefix])
        val_desc = np.load(opt['val_desc_word' + prefix])
        val_label = np.load(opt['val_label'])
        logger.info('Load char-word data finished!')
    elif opt['use_word_char']:
        logger.info('Load word-char data starting...')
        embed_mat_char = np.load(opt['char_embed'])
        embed_mat_word = np.load(opt['word_embed'])
        embed_mat = np.vstack((embed_mat_char, embed_mat_word))
        train_title = np.load(opt['train_title_word' + prefix])
        train_desc = np.load(opt['train_desc_char' + prefix])
        train_label = np.load(opt['train_label'])
        val_title = np.load(opt['val_title_word' + prefix])
        val_desc = np.load(opt['val_desc_char' + prefix])
        val_label = np.load(opt['val_label'])
        logger.info('Load word-char data finished!')

    train_dataset = Dataset(title=train_title,
                            desc=train_desc,
                            label=train_label,
                            class_num=opt['class_num'])
    train_loader = data.DataLoader(train_dataset,
                                   shuffle=True,
                                   batch_size=opt['batch_size'])
    val_dataset = Dataset(title=val_title,
                          desc=val_desc,
                          label=val_label,
                          class_num=opt['class_num'])
    val_loader = data.DataLoader(val_dataset,
                                 shuffle=False,
                                 batch_size=opt['batch_size'])

    logger.info('Using model {}'.format(opt['model']))
    Model = getattr(models, opt['model'])
    model = Model(embed_mat, opt)
    print(model)

    loss_weight = torch.ones(opt['class_num'])
    if opt['boost']:
        if opt['base_layer'] != 0:
            cal_res = torch.load('{}/{}/layer_{}_cal_res_3.pt'.format(
                opt['model_dir'], opt['model'], opt['base_layer']),
                                 map_location=lambda storage, loc: storage)
            logger.info('Load cal_res successful!')
            loss_weight = torch.load('{}/{}/layer_{}_loss_weight_3.pt'.format(
                opt['model_dir'], opt['model'], opt['base_layer'] + 1),
                                     map_location=lambda storage, loc: storage)
        else:
            cal_res = torch.zeros(opt['val_num'], opt['class_num'])
        print('cur_layer:', opt['base_layer'] + 1,
              'loss_weight:', loss_weight.mean(), loss_weight.max(), loss_weight.min(), loss_weight.std())

    if opt['use_self_loss']:
        Loss = getattr(models, opt['loss_function'])
    else:
        Loss = getattr(nn, opt['loss_function'])

    if opt['load']:
        if opt.get('load_name', None) is None:
            model = load_model(model,
                               model_dir=opt['model_dir'],
                               model_name=opt['model'])
        else:
            model = load_model(model, model_dir=opt['model_dir'], model_name=opt['model'], \
                              name=opt['load_name'])

    if opt['cuda'] and opt['device'] is not None:
        torch.cuda.set_device(opt['device'])

    if opt['cuda']:
        model.cuda()
        loss_weight = loss_weight.cuda()

    # import sys
    # precision, recall, score = eval(val_loader, model, opt, save_res=True)
    # print precision, recall, score
    # sys.exit()

    loss_function = Loss(weight=loss_weight + 1 - loss_weight.mean())
    optimizer = torch.optim.Adam(model.parameters(), lr=opt['lr'])

    logger.info('Start running...')

    steps = 0
    model.train()
    base_epoch = opt['base_epoch']
    for epoch in range(1, opt['epochs'] + 1):
        for i, batch in enumerate(train_loader, 0):
            title, desc, label = batch
            title, desc, label = Variable(title), Variable(desc), Variable(
                label).float()
            if opt['cuda']:
                title, desc, label = title.cuda(), desc.cuda(), label.cuda()

            optimizer.zero_grad()
            logit = model(title, desc)

            loss = loss_function(logit, label)
            loss.backward()
            optimizer.step()

            steps += 1
            if steps % opt['log_interval'] == 0:
                corrects = ((logit.data >
                             opt['threshold']) == (label.data).byte()).sum()
                accuracy = 100.0 * corrects / (opt['batch_size'] *
                                               opt['class_num'])
                log_info = 'Steps[{:>8}] (epoch[{:>2}] / batch[{:>5}]) - loss: {:.6f}, acc: {:.4f} % ({} / {})'.format( \
                                steps, epoch + base_epoch, (i+1), loss.data[0], accuracy, \
                                corrects, opt['batch_size'] * opt['class_num'])
                logger.info(log_info)
                vis.plot('loss', loss.data[0])
                precision, recall, score = eval(batch,
                                                model,
                                                opt,
                                                isBatch=True)
                vis.plot('score', score)
        logger.info('Training epoch {} finished!'.format(epoch + base_epoch))
        precision, recall, score = eval(val_loader, model, opt)
        log_info = 'Epoch[{}] - score: {:.6f} (precision: {:.4f}, recall: {:.4f})'.format( \
                            epoch + base_epoch, score, precision, recall)
        vis.log(log_info)
        save_model(model, model_dir=opt['model_dir'], model_name=opt['model'], \
                    epoch=epoch+base_epoch, score=score)
        if epoch + base_epoch == 2:
            model.opt['static'] = False
        elif epoch + base_epoch == 4:
            for param_group in optimizer.param_groups:
                param_group['lr'] = opt['lr'] * opt['lr_decay']
        elif epoch + base_epoch >= 5:
            if opt['boost']:
                res, truth = eval(val_loader, model, opt, return_res=True)
                ori_score = get_score(cal_res, truth)
                cal_res += res
                cur_score = get_score(cal_res, truth)
                logger.info('Layer {}: {}, Layer {}: {}'.format(
                    opt['base_layer'], ori_score, opt['base_layer'] + 1,
                    cur_score))
                loss_weight = get_loss_weight(cal_res, truth)
                torch.save(
                    cal_res, '{}/{}/layer_{}_cal_res_3.pt'.format(
                        opt['model_dir'], opt['model'], opt['base_layer'] + 1))
                logger.info('Save cal_res successful!')
                torch.save(
                    loss_weight, '{}/{}/layer_{}_loss_weight_3.pt'.format(
                        opt['model_dir'], opt['model'], opt['base_layer'] + 2))
            break
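
# A small standalone check (the tensor values are hypothetical) of the weight rescaling
# used for the loss above: `w + 1 - w.mean()` shifts the per-class weights so their mean
# becomes (approximately) 1 while keeping the gaps between classes unchanged.
import torch

w = torch.tensor([0.5, 1.0, 2.5])
shifted = w + 1 - w.mean()   # tensor([0.1667, 0.6667, 2.1667]); shifted.mean() ~= 1.0
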
def train_pair(**kwargs):
    config.parse(kwargs)
    vis = Visualizer(port=2333, env=config.env)
    vis.log('Use config:')
    for k, v in config.__class__.__dict__.items():
        if not k.startswith('__'):
            vis.log(f"{k}: {getattr(config, k)}")

    # prepare data
    train_data = PairSWDataset(config.train_paths, phase='train', useRGB=config.useRGB, usetrans=config.usetrans, balance=config.data_balance)
    valpair_data = PairSWDataset(config.test_paths, phase='val_pair', useRGB=config.useRGB, usetrans=config.usetrans, balance=False)
    print('Training Samples:', train_data.__len__(), 'ValPair Samples:', valpair_data.__len__())
    dist = train_data.dist()
    print('Train Data Distribution:', dist)

    train_dataloader = DataLoader(train_data, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)
    valpair_dataloader = DataLoader(valpair_data, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers)

    # prepare model
    model = SiameseNet(num_classes=config.num_classes)
    print(model)

    if config.load_model_path:
        model.load(config.load_model_path)
    if config.use_gpu:
        model.cuda()
    if config.parallel:
        model = torch.nn.DataParallel(model, device_ids=[x for x in range(config.num_of_gpu)])

    model.train()

    # criterion and optimizer
    weight_pair = torch.Tensor([1, 1.5])
    vis.log(f'pair loss weight: {weight_pair}')
    print('pair loss weight:', weight_pair)
    weight_pair = weight_pair.cuda()
    pair_criterion = torch.nn.CrossEntropyLoss(weight=weight_pair)

    lr = config.lr
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.weight_decay)

    # metric
    softmax = functional.softmax
    pair_loss_meter = meter.AverageValueMeter()
    pair_epoch_loss = meter.AverageValueMeter()

    pair_train_cm = meter.ConfusionMeter(config.num_classes)
    # previous_loss = 100
    pair_previous_avg_se = 0

    # train
    if config.parallel:
        if not os.path.exists(os.path.join('checkpoints', model.module.model_name)):
            os.mkdir(os.path.join('checkpoints', model.module.model_name))
    else:
        if not os.path.exists(os.path.join('checkpoints', model.model_name)):
            os.mkdir(os.path.join('checkpoints', model.model_name))

    for epoch in range(config.max_epoch):
        print(f"epoch: [{epoch+1}/{config.max_epoch}] =============================================")
        pair_train_cm.reset()
        pair_epoch_loss.reset()

        # train
        for i, (image_1, image_2, label_1, label_2, label_res, _, _) in tqdm(enumerate(train_dataloader)):
            pair_loss_meter.reset()

            # prepare input
            image_1 = Variable(image_1)
            image_2 = Variable(image_2)
            target_res = Variable(label_res)

            if config.use_gpu:
                image_1 = image_1.cuda()
                image_2 = image_2.cuda()
                target_res = target_res.cuda()

            # go through the model
            score_1, score_2, score_res = model(image_1, image_2)

            # backpropagate
            optimizer.zero_grad()
            pair_loss = pair_criterion(score_res, target_res)
            pair_loss.backward()
            optimizer.step()

            pair_loss_meter.add(pair_loss.data[0])
            pair_epoch_loss.add(pair_loss.data[0])

            pair_train_cm.add(softmax(score_res, dim=1).data, target_res.data)

            if (i+1) % config.print_freq == 0:
                vis.plot('loss', pair_loss_meter.value()[0])

        # print result
        pair_train_se = [100. * pair_train_cm.value()[0][0] / (pair_train_cm.value()[0][0] + pair_train_cm.value()[0][1]),
                         100. * pair_train_cm.value()[1][1] / (pair_train_cm.value()[1][0] + pair_train_cm.value()[1][1])]
        model.eval()
        pair_val_cm, pair_val_accuracy, pair_val_se = val_pair(model, valpair_dataloader)

        if np.average(pair_val_se) > pair_previous_avg_se:  # save the model whenever the average sensitivity on the validation set improves
            if config.parallel:
                save_model_dir = config.save_model_dir if config.save_model_dir else model.module.model_name
                save_model_name = config.save_model_name if config.save_model_name else model.module.model_name + '_best_model.pth'
                if not os.path.exists(os.path.join('checkpoints', save_model_dir)):
                    os.makedirs(os.path.join('checkpoints', save_model_dir))
                model.module.save(os.path.join('checkpoints', save_model_dir, save_model_name))
            else:
                save_model_dir = config.save_model_dir if config.save_model_dir else model.model_name
                save_model_name = config.save_model_name if config.save_model_name else model.model_name + '_best_model.pth'
                if not os.path.exists(os.path.join('checkpoints', save_model_dir)):
                    os.makedirs(os.path.join('checkpoints', save_model_dir))
                model.save(os.path.join('checkpoints', save_model_dir, save_model_name))
            pair_previous_avg_se = np.average(pair_val_se)

        if epoch+1 == config.max_epoch:  # save the model from the last epoch
            if config.parallel:
                save_model_dir = config.save_model_dir if config.save_model_dir else model.module.model_name
                save_model_name = config.save_model_name.split('.pth')[0]+'_last.pth' if config.save_model_name else model.module.model_name + '_last_model.pth'
            else:
                save_model_dir = config.save_model_dir if config.save_model_dir else model.model_name
                save_model_name = config.save_model_name.split('.pth')[0]+'_last.pth' if config.save_model_name else model.model_name + '_last_model.pth'
            if not os.path.exists(os.path.join('checkpoints', save_model_dir)):
                os.makedirs(os.path.join('checkpoints', save_model_dir))
            model.save(os.path.join('checkpoints', save_model_dir, save_model_name))

        vis.plot_many({'epoch_loss': pair_epoch_loss.value()[0],
                       'pair_train_avg_se': np.average(pair_train_se), 'pair_train_se_0': pair_train_se[0], 'pair_train_se_1': pair_train_se[1],
                       'pair_val_avg_se': np.average(pair_val_se), 'pair_val_se_0': pair_val_se[0], 'pair_val_se_1': pair_val_se[1]})
        vis.log(f"epoch: [{epoch+1}/{config.max_epoch}] ===============================================")
        vis.log(f"lr: {lr}, loss: {round(pair_epoch_loss.value()[0], 5)}")
        vis.log(f"pair_train_avg_se: {round(np.average(pair_train_se), 4)}, pair_train_se_0: {round(pair_train_se[0], 4)}, pair_train_se_1: {round(pair_train_se[1], 4)}")
        vis.log(f"pair_val_avg_se: {round(sum(pair_val_se) / len(pair_val_se), 4)}, pair_val_se_0: {round(pair_val_se[0], 4)}, pair_val_se_1: {round(pair_val_se[1], 4)}")
        vis.log(f'pair_train_cm: {pair_train_cm.value()}')
        vis.log(f'pair_val_cm: {pair_val_cm.value()}')
        print("lr:", lr, "loss:", round(pair_epoch_loss.value()[0], 5))
        print('pair_train_avg_se:', round(np.average(pair_train_se), 4), 'pair_train_se_0:', round(pair_train_se[0], 4), 'pair_train_se_1:', round(pair_train_se[1], 4))
        print('pair_val_avg_se:', round(np.average(pair_val_se), 4), 'pair_val_se_0:', round(pair_val_se[0], 4), 'pair_val_se_1:', round(pair_val_se[1], 4))
        print('pair_train_cm:')
        print(pair_train_cm.value())
        print('pair_val_cm:')
        print(pair_val_cm.value())

        # update learning rate
        # if loss_meter.value()[0] > previous_loss:
        #     lr = lr * config.lr_decay
        #     for param_group in optimizer.param_groups:
        #         param_group['lr'] = lr
        # previous_loss = loss_meter.value()[0]
        if (epoch+1) % 5 == 0:
            lr = lr * config.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
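
# The manual schedule above (multiply the learning rate by config.lr_decay every 5 epochs)
# can also be written with the built-in StepLR scheduler. A minimal sketch with hypothetical
# argument values, not a drop-in replacement for the code above:
import torch

def make_optimizer_and_scheduler(model, lr=1e-3, gamma=0.5):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # StepLR multiplies the learning rate by `gamma` every `step_size` epochs, mirroring
    # the "if (epoch+1) % 5 == 0" branch above; call scheduler.step() once per epoch.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=gamma)
    return optimizer, scheduler
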
Beispiel #12
0
def train(**kwargs):
    opt.parse(kwargs)
    if opt.use_visdom:
        vis = Visualizer(opt.env)

    # step 1: configure model
    # model = densenet169(pretrained=True)
    # model = DenseNet169(num_classes=2)
    # model = ResNet152(num_classes=2)
    model = getattr(models, opt.model)()
    if opt.load_model_path:
        model.load(opt.load_model_path)
    if opt.use_gpu:
        print('CUDA MODEL!')
        model.cuda()

    model.train()

    # step 2: data
    train_data = MURA_Dataset(opt.data_root,
                              opt.train_image_paths,
                              train=True,
                              test=False)
    val_data = MURA_Dataset(opt.data_root,
                            opt.test_image_paths,
                            train=False,
                            test=False)

    print('Training images:', len(train_data), 'Validation images:',
          len(val_data))

    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers)
    val_dataloader = DataLoader(val_data,
                                batch_size=opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers)

    # step 3: criterion and optimizer
    A = 21935
    N = 14873
    weight = t.Tensor([A / (A + N), N / (A + N)])
    if opt.use_gpu:
        weight = weight.cuda()

    criterion = t.nn.CrossEntropyLoss(weight=weight)
    # criterion = FocalLoss(alpha=weight, class_num=2)
    lr = opt.lr
    optimizer = t.optim.Adam(model.parameters(),
                             lr=lr,
                             weight_decay=opt.weight_decay)

    # step 4: meters
    loss_meter = meter.AverageValueMeter()
    confusion_matrix = meter.ConfusionMeter(2)
    previous_loss = 1e10

    # step 5: train

    if not os.path.exists(os.path.join('checkpoints', model.model_name)):
        os.mkdir(os.path.join('checkpoints', model.model_name))
    prefix = time.strftime('%m%d')
    if not os.path.exists(os.path.join('checkpoints', model.model_name,
                                       prefix)):
        os.mkdir(os.path.join('checkpoints', model.model_name, prefix))

    s = t.nn.Softmax()
    for epoch in range(opt.max_epoch):

        loss_meter.reset()
        confusion_matrix.reset()

        for ii, (data, label, _,
                 body_part) in tqdm(enumerate(train_dataloader)):

            # train model
            input = Variable(data)
            target = Variable(label)
            # body_part = Variable(body_part)
            if opt.use_gpu:
                input = input.cuda()
                target = target.cuda()
                # body_part = body_part.cuda()

            optimizer.zero_grad()
            if opt.model.startswith('MultiBranch'):
                score = model(input, body_part)
            else:
                score = model(input)
            loss = criterion(score, target)
            loss.backward()
            optimizer.step()

            # meters update and visualize
            loss_meter.add(loss.data[0])
            confusion_matrix.add(s(Variable(score.data)).data, target.data)

            if ii % opt.print_freq == opt.print_freq - 1:
                if opt.use_visdom:
                    vis.plot('loss', loss_meter.value()[0])
                # print('loss', loss_meter.value()[0])

                # debug
                if os.path.exists(opt.debug_file):
                    import ipdb
                    ipdb.set_trace()

        ck_name = f'epoch_{epoch}_{str(opt)}.pth'
        model.save(
            os.path.join('checkpoints', model.model_name, prefix, ck_name))
        # model.save()

        # validate and visualize
        val_cm, val_accuracy, val_loss = val(model, val_dataloader)

        cm = confusion_matrix.value()

        if opt.use_visdom:
            vis.plot('val_accuracy', val_accuracy)
            vis.log(
                "epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm:{val_cm},train_acc:{train_acc}, "
                "val_acc:{val_acc}".format(
                    epoch=epoch,
                    loss=loss_meter.value()[0],
                    val_cm=str(val_cm.value()),
                    train_cm=str(confusion_matrix.value()),
                    lr=lr,
                    train_acc=str(100. * (cm[0][0] + cm[1][1]) / (cm.sum())),
                    val_acc=str(100. *
                                (val_cm.value()[0][0] + val_cm.value()[1][1]) /
                                (val_cm.value().sum()))))
        print('val_accuracy: ', val_accuracy)
        print(
            "epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm:{val_cm},train_acc:{train_acc}, "
            "val_acc:{val_acc}".format(
                epoch=epoch,
                loss=loss_meter.value()[0],
                val_cm=str(val_cm.value()),
                train_cm=str(confusion_matrix.value()),
                lr=lr,
                train_acc=100. * (cm[0][0] + cm[1][1]) / (cm.sum()),
                val_acc=100. * (val_cm.value()[0][0] + val_cm.value()[1][1]) /
                (val_cm.value().sum())))

        # update learning rate
        if loss_meter.value()[0] > previous_loss:
            # if val_loss > previous_loss:
            lr = lr * opt.lr_decay
            # second way of lowering the learning rate: optimizer state such as momentum is not lost
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        # previous_loss = val_loss
        previous_loss = loss_meter.value()[0]
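
# A standalone sketch of count-based class weights for CrossEntropyLoss, in the spirit of
# the A / N weighting above (the counts and the class-to-count mapping below are
# hypothetical). One common convention is to weight each class by the other class's share
# of the data, so the rarer class contributes more to the loss.
import torch

def two_class_weights(count_class0, count_class1):
    total = float(count_class0 + count_class1)
    return torch.tensor([count_class1 / total, count_class0 / total])

weight = two_class_weights(21935, 14873)             # tensor([0.4041, 0.5959])
criterion = torch.nn.CrossEntropyLoss(weight=weight)
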
Beispiel #13
0
def train(**kwargs):
    vis = Visualizer(port=config.vis_port)
    model = nn.DataParallel(ResUNet(1, 1, base=32, depth=6)).cuda().train()

    handle1 = model.module.register_forward_hook(get_output('input_layer'))
    handle3 = model.module.register_forward_hook(get_output('output'))

    if config.checkpoint:
        try:
            if isinstance(model, nn.DataParallel):
                model.module.load(get_pth(model, config.checkpoint))
            else:
                model.load(get_pth(model, config.checkpoint))
            print("Load model successfully")
        except:
            pass

    train_lst, val_lst = data_split_2D(main(),
                                       config.Train_percent,
                                       shuffle=True)
    train_data = UNET2DDataset(train_lst)
    train_dataloader = DataLoader(train_data,
                                  batch_size=config.TrainBatchSize,
                                  shuffle=True,
                                  num_workers=config.Num_workers)

    val_data = UNET2DDataset(val_lst)
    val_dataloader = DataLoader(val_data,
                                batch_size=config.ValBatchSize,
                                shuffle=True)

    train_matrix = SegmentationMetrix()
    lr = config.lr

    # focalDice_criterion = Losses.FocalDice(alpha=config.Loss_alpha, beta=config.Loss_beta).to(config.Device)
    # focal_criterion = Losses.FocalLoss().to(config.Device)
    dice_criterion = Losses.DiceLoss().to(config.Device)
    BCEDice_criterion = Losses.BCEDice().to(config.Device)
    bce_criterion = Losses.BCELoss().to(config.Device)
    criterion = dice_criterion

    optimizer = optim.Adam(params=model.parameters(),
                           lr=lr,
                           weight_decay=config.Weight_decay)
    # scheduler = lr_scheduler.StepLR(optimizer, 10, config.lr_decay)  # *0.1/epoch

    previous_loss = 0
    prev_ValIou = 0
    patient = config.Patient
    debug_patient = config.Debug_Patient
    for epoch in range(config.Max_epoch):
        train_matrix.reset()
        start_time = datetime.now()
        train_acc = 0
        train_prec = 0
        train_sensi = 0
        loss_counter = 0
        for ii, (data, label) in enumerate(tqdm(train_dataloader)):
            input = V(data.to(config.Device, dtype=torch.float))
            mask = V(label.to(config.Device, dtype=torch.float))

            optimizer.zero_grad()
            score = model(input)
            loss = criterion(score, mask)

            # ----------------clip gradient-------------
            # nn.utils.clip_grad_norm(model.parameters(), max_norm=20, norm_type=2.0)

            if torch.isnan(loss):
                # print("Loss is NaN!")
                # pdb.set_trace()
                continue

            loss_counter += loss.item()
            loss.backward()
            try:
                optimizer.step()
            except:
                for param_group in optimizer.param_groups:
                    print(param_group.keys())
                    print([type(value) for value in param_group.values()])
                    print('learning rate:', param_group['lr'])
                    print('eps:', param_group['eps'])
                    print('params:', param_group['params'])
                    print('weight_decay:', param_group['weight_decay'])
                pass

            # train valuation:
            train_matrix.genConfusionMat(score.clone(), label.clone())
            train_acc += train_matrix.pixelAcc()
            train_prec += train_matrix.precision()
            train_sensi += train_matrix.sensitive()

            if ii % config.Print_freq == config.Print_freq - 1:
                vis.log('train loss in epoch:' + criterion.name + ": " +
                        str(loss.item()))
                vis.plot('train loss in epoch', loss.item())
                vis.img_many(features)
                # for name, weight in model.module.named_parameters():
                #     if weight.requires_grad:
                #         print(name, "-grad mean: ", weight.grad.mean(dim=0))
                #         print(name, "-grad min: ", weight.grad.min(dim=0))
                #         print(name, "-grad max: ", weight.grad.max(dim=0))

        end_time = datetime.now()
        h, remainder = divmod((end_time - start_time).seconds, 3600)
        m, s = divmod(remainder, 60)

        # ----------------validation-------------------
        val_loss, val_iou, val_pixelAcc, val_precision, val_sensitive = val(
            model, val_dataloader)
        avg_loss = loss_counter / len(train_dataloader)
        vis.log('val_iou:' + str(val_iou))
        vis.log('val_acc:' + str(val_pixelAcc))
        vis.log('val_precision:' + str(val_precision))
        vis.log('val_sensitive:' + str(val_sensitive))
        vis.plot('val_loss', val_loss)
        vis.plot('val_iou', val_iou)
        vis.plot('val_acc', val_pixelAcc)
        vis.plot('val_precision', val_precision)
        vis.plot('val_sensitive', val_sensitive)
        vis.plot('avg train loss', avg_loss)

        epoch_str = (
            'Epoch: {}, Train Loss: {:.5f}, Train Acc: {:.5f}, Train Mean Precision: {:.5f}, \
        Valid Loss: {:.5f}, Valid Acc: {:.5f}, Valid Mean IU: {:.5f} '.format(
                epoch, avg_loss, train_acc / len(train_dataloader),
                train_prec / len(train_dataloader), val_loss, val_pixelAcc,
                val_iou))
        time_str = 'Time: {:.0f}:{:.0f}:{:.0f}'.format(h, m, s)
        print(epoch_str + time_str + ' lr: {}'.format(lr))

        if avg_loss >= previous_loss and previous_loss != 0:
            debug_patient -= 1
            if patient == 0:
                patient = config.Patient
                lr = lr * config.lr_decay
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
            else:
                patient -= 1

            if debug_patient == 0:
                pdb.set_trace()
                debug_patient = config.Debug_Patient

        previous_loss = avg_loss
        if val_iou > prev_ValIou and prev_ValIou != 0:
            model.module.save()
        prev_ValIou = val_iou

    handle1.remove()
    handle3.remove()
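
# A self-contained sketch of the forward-hook pattern used above. get_output and the
# features dict in the original script are not shown here, so the names below are
# assumptions: the hook stores a module's output under a key for later inspection, and
# the returned handle unregisters it again.
import torch
import torch.nn as nn

captured = {}

def capture_output(name):
    def hook(module, inputs, output):
        captured[name] = output.detach()
    return hook

net = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU())
handle = net[0].register_forward_hook(capture_output('first_conv'))
_ = net(torch.randn(1, 1, 32, 32))
# captured['first_conv'].shape == torch.Size([1, 8, 32, 32])
handle.remove()
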
Beispiel #14
0
def train(**kwargs):
    
    # step: configure

    opt._parse(**kwargs)
    device = t.device('cuda') if opt.use_gpu else t.device('cpu')
    vis = Visualizer(env=opt.env)
    
    # step: data  -- load the data produced by the data-loading step; everything is loaded only once and can also be accessed via self

    dataloader = get_dataloader(opt)
    _data = dataloader.dataset.data
    # word2ix, ix2word = _data.word2ix, _data.ix2word    
    # ix2id, id2ix = _data.ix2id, _data.id2ix
    word2ix, ix2word = _data['word2ix'], _data['ix2word']
    ix2id, id2ix, end = _data['ix2id'], _data['id2ix'], _data['end']
    eos_id = word2ix[end]
    
    # step: model 
    # Looking at the author's model code: opt is saved together with the model checkpoint, apparently so it can be reused at generation time.
    # To avoid problems later, the model is constructed with the same arguments as in the author's version; what each argument does is revisited at generation time.
    # word2ix and ix2word are used when building the dataset and, strictly speaking, should have nothing to do with the model itself.
    model = CaptionModel(opt, len(ix2word) )
    if opt.model_ckpt:
        model.load(opt.model_ckpt)
    model.to(device)
    
    # step: meter criterion optimizer
    loss_meter = meter.AverageValueMeter()
    criterion = t.nn.CrossEntropyLoss()
    optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)
    model.save()
    # step: train
    for epoch in range(opt.max_epoch):
        loss_meter.reset()
        
        for ii,(imgs, (captions, lengths), indexes) in tqdm.tqdm(enumerate(dataloader), total = len(dataloader)):
            
            optimizer.zero_grad()
            imgs = imgs.to(device)
            captions =  captions.to(device)
            input_captions = captions[:-1]
            target_captions = pack_padded_sequence(captions, lengths).data # len
            score, _ = model(imgs, input_captions, lengths) # len*vocab
            loss = criterion(score, target_captions)
            loss.backward()
            optimizer.step()
            loss_meter.add(loss.data.item())
            
            # step: visualize
            if (ii+1) % opt.print_freq == 0:
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                    
                vis.plot('loss', loss_meter.value()[0])
                # picture+caption
                # indexes is used here because, to visualise the picture, we need to know which image this sample is; the model input is just the 2048-d feature and carries no image identity.
                # It also shows why ordering matters: when extracting image features, files are not read directly but looked up through id2ix, so captions and images can both be matched via id2ix.
                # I would probably not have thought of mapping ids to positions myself and might simply have stored all image names in a list,
                # but a dict is better than a list here: with ix2id and id2ix one can map back from an image id to its index, whereas a list only goes from index to name.

                img_path = os.path.join(opt.img_path, ix2id[indexes[0]]) 
                raw_img = Image.open(img_path).convert('RGB')
                raw_img = tv.transforms.ToTensor()(raw_img)
                
                # captions_np = np.array(captions) # this is to avoid accidentally keeping the computation graph alive: before doing anything else with the model's inputs or outputs, convert them explicitly to graph-free tensors, e.g. detach(), with t.no_grad(), t_.data.tolist(), t_.data, and so on
                raw_caption = captions.data[:,0].tolist()  #
               
                raw_caption = ''.join([ix2word[i] for i in raw_caption])
                # vis.img('raw_img', raw_img, caption=raw_caption)
                
                info = '<br>'.join([ix2id[indexes[0]],raw_caption])
                vis.log(u'raw_caption', info, False)
                results, scores = model.generate(img=imgs.data[0], eos_id=eos_id)
                cap_sentences = [ ''.join([ix2word[ix.item()] for ix in sentence]) for sentence in results]
                info = '<br>'.join(cap_sentences)
                info = '<br>'.join([ix2id[indexes[0]],info])
                vis.log(u'val', info, False)               
 
        model.save()
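
# A small sketch (toy tensors, hypothetical values) of why pack_padded_sequence is used
# above to build the caption targets: it flattens a padded, time-major (seq_len, batch)
# tensor into a single vector containing only the valid tokens, which lines up row-for-row
# with the (len, vocab) score matrix fed to CrossEntropyLoss.
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# two captions padded to length 4 (time-major, batch second), true lengths 4 and 2
captions = torch.tensor([[5, 9],
                         [6, 10],
                         [7, 0],
                         [8, 0]])
targets = pack_padded_sequence(captions, lengths=[4, 2]).data
# targets == tensor([5, 9, 6, 10, 7, 8]) -- the padding positions are dropped
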
Beispiel #15
0
def train(**kwargs):
    # update the configuration from command-line arguments
    opt.parse(kwargs)
    vis = Visualizer(opt.env)
    # step1: model
    model = getattr(models, opt.model)()
    '''
	model_ft = torchvision.models.vgg16_bn(pretrained = True)
	pretrained_dict = model_ft.state_dict()
	model_dict = model.state_dict()
	# drop the keys in pretrained_dict that are not in model_dict
	pretrained_dict =  {k: v for k, v in pretrained_dict.items() 
					if k in model_dict}
	model_dict.update(pretrained_dict)
	model.load_state_dict(model_dict)
	'''
    if opt.load_model_path:
        model.load(opt.load_model_path)
    if opt.use_gpu:
        model.cuda()
        summary(model, (3, 224, 224))
    print(opt)
    # step2: data
    train_data = myData(
        filelists=train_filelists,
        #transform = data_transforms['train'],
        scale=opt.cropscale,
        transform=None,
        test=False,
        data_source='none')
    val_data = myData(
        filelists=test_filelists,
        #transform =data_transforms['val'],
        transform=None,
        scale=opt.cropscale,
        test=False,
        data_source='none')

    train_loader = DataLoader(dataset=train_data,
                              batch_size=opt.batch_size,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_data,
                            batch_size=opt.batch_size // 2,
                            shuffle=False)

    dataloaders = {'train': train_loader, 'val': val_loader}
    dataset_sizes = {'train': len(train_data), 'val': len(val_data)}

    # step3: loss function and optimizer
    criterion = FocalLoss(2)
    #criterion = torch.nn.CrossEntropyLoss()
    lr = opt.lr
    #optimizer = t.optim.Adam(model.parameters(),
    #                       lr = lr,
    #                       weight_decay = opt.weight_decay)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=opt.lr,
                                momentum=0.9,
                                weight_decay=opt.weight_decay)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.5)
    # decay the learning rate by a factor of 0.5 every 8 epochs
    # step4: metrics: the smoothed loss and the confusion matrix

    confusion_matrix = meter.ConfusionMeter(2)
    train_loss = meter.AverageValueMeter()  # added for visualization
    val_loss = meter.AverageValueMeter()
    train_acc = meter.AverageValueMeter()  # added for visualization
    val_acc = meter.AverageValueMeter()
    previous_loss = 1e100
    best_acc = 0.0
    # train
    for epoch in range(opt.max_epoch):
        print('Epoch {}/{}'.format(epoch, opt.max_epoch - 1))
        print('-' * 10)
        train_loss.reset()
        train_acc.reset()
        running_loss = 0.0
        running_corrects = 0
        exp_lr_scheduler.step()
        for step, batch in enumerate(
                tqdm(train_loader, desc='Train On Anti-spoofing',
                     unit='batch')):
            inputs, labels = batch

            if opt.use_gpu:
                inputs = Variable(inputs.cuda())
                labels = Variable(labels.cuda())
            else:
                inputs = Variable(inputs)
                labels = Variable(labels)
            optimizer.zero_grad()  #zero the parameter gradients
            with torch.set_grad_enabled(True):
                outputs = model(inputs)
                #print(outputs.shape)
                _, preds = torch.max(outputs, 1)

                loss0 = criterion(outputs, labels)
                loss = loss0
                loss.backward()  #backward of gradient
                optimizer.step()  #strategy to drop
                if step % 20 == 0:
                    pass
                    #print('epoch:%d/%d step:%d/%d loss: %.4f loss0: %.4f loss1: %.4f'%(epoch, opt.max_epoch, step, len(train_loader),
                    #loss.item(),loss0.item(),loss1.item()))
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
            '''
			if step%opt.print_freq==opt.print_freq-1:
				vis.plot('loss', train_loss.value()[0])
			   
			   # drop into debug mode if needed
			   if os.path.exists(opt.debug_file):
				   import ipdb;
				   ipdb.set_trace()	
			'''
        epoch_loss = running_loss / dataset_sizes['train']
        epoch_acc = running_corrects.double() / float(dataset_sizes['train'])
        print('Train Loss: {:.8f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
        train_loss.add(epoch_loss)
        train_acc.add(epoch_acc)

        val_loss.reset()
        val_acc.reset()
        val_cm, v_loss, v_accuracy = val(model, val_loader,
                                         dataset_sizes['val'])
        print('Val Loss: {:.8f} Acc: {:.4f}'.format(v_loss, v_accuracy))
        val_loss.add(v_loss)
        val_acc.add(v_accuracy)


        vis.plot_many_stack({'train_loss':train_loss.value()[0],\
            'val_loss':val_loss.value()[0]},win_name ="Loss")
        vis.plot_many_stack({'train_acc':train_acc.value()[0],\
            'val_acc':val_acc.value()[0]},win_name = 'Acc')
        vis.log("epoch:{epoch},lr:{lr},\
				train_loss:{train_loss},train_acc:{train_acc},\
				val_loss:{val_loss},val_acc:{val_acc},\
				train_cm:{train_cm},val_cm:{val_cm}".format(
            epoch=epoch,
            train_loss=train_loss.value()[0],
            train_acc=train_acc.value()[0],
            val_loss=val_loss.value()[0],
            val_acc=val_acc.value()[0],
            train_cm=str(confusion_matrix.value()),
            val_cm=str(val_cm.value()),
            lr=lr))
        '''
		if v_loss > previous_loss:          
			lr = lr * opt.lr_decay
			for param_group in optimizer.param_groups:
				param_group['lr'] = lr
		'''
        vis.plot_many_stack({'lr': lr}, win_name='lr')
        previous_loss = val_loss.value()[0]
        if v_accuracy > best_acc:
            best_acc = v_accuracy
            best_acc_epoch = epoch
            #best_model_wts = model.state_dict()
            os.system('mkdir -p %s' % (os.path.join('checkpoints', opt.model)))
            model.save(name='checkpoints/' + opt.model + '/' + str(epoch) +
                       '.pth')
            print('Epoch: {:d} Val Loss: {:.8f} Acc: {:.4f}'.format(
                epoch, v_loss, v_accuracy),
                  file=open('result/val.txt', 'a'))
        #model.load_state_dict(best_model_wts)
    print('Best val Epoch: {},Best val Acc: {:4f}'.format(
        best_acc_epoch, best_acc))
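
# FocalLoss above is a project-specific class whose definition is not shown here. As
# background, a minimal focal loss can be written as below; this is the common
# (1 - p_t)**gamma formulation and is not necessarily identical to the FocalLoss used
# in this script.
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0):
    # per-sample cross entropy, then down-weight easy examples by (1 - p_t) ** gamma
    ce = F.cross_entropy(logits, targets, reduction='none')
    p_t = torch.exp(-ce)   # probability the model assigns to the true class
    return ((1.0 - p_t) ** gamma * ce).mean()
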
def multitask_train(**kwargs):
    config.parse(kwargs)
    vis = Visualizer(port=2333, env=config.env)

    # prepare data
    train_data = MultiLabel_Dataset(config.data_root,
                                    config.train_paths,
                                    phase='train',
                                    balance=config.data_balance)
    val_data = MultiLabel_Dataset(config.data_root,
                                  config.test_paths,
                                  phase='val',
                                  balance=config.data_balance)
    print('Training Images:', train_data.__len__(), 'Validation Images:',
          val_data.__len__())

    train_dataloader = DataLoader(train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=config.num_workers)
    val_dataloader = DataLoader(val_data,
                                batch_size=config.batch_size,
                                shuffle=False,
                                num_workers=config.num_workers)

    # prepare model
    model = MultiTask_DenseNet121(num_classes=2)  # each branch is a binary classifier
    # model = CheXPre_MultiTask_DenseNet121(num_classes=2)  # each branch is a binary classifier

    if config.load_model_path:
        model.load(config.load_model_path)
    if config.use_gpu:
        model.cuda()

    model.train()

    # criterion and optimizer
    # F, T1, T2 = 3500, 3078, 3565  # image counts used for the weights: normal, osteoblastic, osteolytic
    # weight_1 = torch.FloatTensor([T1/(F+T1+T2), (F+T2)/(F+T1+T2)]).cuda()  # the weight tensor also needs to be on the GPU
    # weight_2 = torch.FloatTensor([T2/(F+T1+T2), (F+T1)/(F+T1+T2)]).cuda()
    # criterion_1 = torch.nn.CrossEntropyLoss(weight=weight_1)
    # criterion_2 = torch.nn.CrossEntropyLoss(weight=weight_2)
    criterion_1 = torch.nn.CrossEntropyLoss()
    criterion_2 = torch.nn.CrossEntropyLoss()
    lr = config.lr
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=config.weight_decay)

    # metrics
    softmax = functional.softmax
    loss_meter_1 = meter.AverageValueMeter()
    loss_meter_2 = meter.AverageValueMeter()
    loss_meter_total = meter.AverageValueMeter()
    train_cm_1 = meter.ConfusionMeter(2)  # each branch is binary; the combined task has three classes
    train_cm_2 = meter.ConfusionMeter(2)
    train_cm_total = meter.ConfusionMeter(3)
    previous_loss = 100
    previous_acc = 0

    # train
    if not os.path.exists(os.path.join('checkpoints', model.model_name)):
        os.mkdir(os.path.join('checkpoints', model.model_name))

    for epoch in range(config.max_epoch):
        loss_meter_1.reset()
        loss_meter_2.reset()
        loss_meter_total.reset()
        train_cm_1.reset()
        train_cm_2.reset()
        train_cm_total.reset()

        # train
        for i, (image, label_1, label_2, label,
                image_path) in tqdm(enumerate(train_dataloader)):
            # prepare input
            img = Variable(image)
            target_1 = Variable(label_1)
            target_2 = Variable(label_2)
            target = Variable(label)
            if config.use_gpu:
                img = img.cuda()
                target_1 = target_1.cuda()
                target_2 = target_2.cuda()
                target = target.cuda()

            # go through the model
            score_1, score_2 = model(img)

            # backpropagate
            optimizer.zero_grad()
            loss_1 = criterion_1(score_1, target_1)
            loss_2 = criterion_2(score_2, target_2)
            loss = loss_1 + loss_2
            # loss.backward()
            # optimizer.step()
            loss_1.backward(
                retain_graph=True)  # summing the two losses and backpropagating once worked worse here; backpropagating them separately works better
            optimizer.step()  # a possible reason: with separate updates the momentum term is applied twice, which makes it easier to escape local optima
            loss_2.backward()
            optimizer.step()
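            # Note on the two backward passes above: loss_1 and loss_2 share one forward
            # graph, so the first backward() keeps it alive with retain_graph=True for the
            # second backward() to reuse; calling optimizer.step() after each backward
            # applies two parameter updates per batch instead of one.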

            # calculate loss and confusion matrix
            loss_meter_1.add(loss_1.data[0])
            loss_meter_2.add(loss_2.data[0])
            loss_meter_total.add(loss.data[0])

            p_1, p_2 = softmax(score_1, dim=1), softmax(score_2, dim=1)
            c = []

            # -----------------------------------------------------------------------
            for j in range(p_1.data.size()[0]):  # merge the two branch outputs into the final prediction
                if p_1.data[j][1] < 0.5 and p_2.data[j][1] < 0.5:
                    c.append([1, 0, 0])
                else:
                    if p_1.data[j][1] > p_2.data[j][1]:
                        c.append([0, 1, 0])
                    else:
                        c.append([0, 0, 1])
            # -----------------------------------------------------------------------
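            # Decision rule above: if both branches put their positive-class probability
            # below 0.5 the sample is assigned class 0; otherwise the branch with the
            # larger positive probability decides between class 1 and class 2. The one-hot
            # rows collected in c are what train_cm_total consumes below.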

            train_cm_1.add(p_1.data, target_1.data)
            train_cm_2.add(p_2.data, target_2.data)
            train_cm_total.add(torch.FloatTensor(c), target.data)

            if i % config.print_freq == config.print_freq - 1:
                vis.plot_many({
                    'loss_1': loss_meter_1.value()[0],
                    'loss_2': loss_meter_2.value()[0],
                    'loss_total': loss_meter_total.value()[0]
                })
                print('loss_1:',
                      loss_meter_1.value()[0], 'loss_2:',
                      loss_meter_2.value()[0], 'loss_total:',
                      loss_meter_total.value()[0])

        # print result
        train_accuracy_1 = 100. * sum(
            [train_cm_1.value()[c][c]
             for c in range(2)]) / train_cm_1.value().sum()
        train_accuracy_2 = 100. * sum(
            [train_cm_2.value()[c][c]
             for c in range(2)]) / train_cm_2.value().sum()
        train_accuracy_total = 100. * sum(
            [train_cm_total.value()[c][c]
             for c in range(3)]) / train_cm_total.value().sum()

        val_cm_1, val_accuracy_1, val_loss_1, val_cm_2, val_accuracy_2, val_loss_2, val_cm_total, val_accuracy_total, val_loss_total = multitask_val(
            model, val_dataloader)

        if val_accuracy_total > previous_acc:
            if config.save_model_name:
                model.save(
                    os.path.join('checkpoints', model.model_name,
                                 config.save_model_name))
            else:
                model.save(
                    os.path.join('checkpoints', model.model_name,
                                 model.model_name + '_best_model.pth'))
            previous_acc = val_accuracy_total

        vis.plot_many({
            'train_accuracy_1': train_accuracy_1,
            'val_accuracy_1': val_accuracy_1,
            'train_accuracy_2': train_accuracy_2,
            'val_accuracy_2': val_accuracy_2,
            'total_train_accuracy': train_accuracy_total,
            'total_val_accuracy': val_accuracy_total
        })
        vis.log(
            "epoch: [{epoch}/{total_epoch}], lr: {lr}, loss_1: {loss_1}, loss_2: {loss_2}, loss_total: {loss_total}"
            .format(epoch=epoch + 1,
                    total_epoch=config.max_epoch,
                    lr=lr,
                    loss_1=loss_meter_1.value()[0],
                    loss_2=loss_meter_2.value()[0],
                    loss_total=loss_meter_total.value()[0]))
        vis.log('train_cm_1:' + str(train_cm_1.value()) + ' train_cm_2:' +
                str(train_cm_2.value()) + ' train_cm_total:' +
                str(train_cm_total.value()))
        vis.log('val_cm_1:' + str(val_cm_1.value()) + ' val_cm_2:' +
                str(val_cm_2.value()) + ' val_cm_total:' +
                str(val_cm_total.value()))

        print('train_accuracy_1:', train_accuracy_1, 'val_accuracy_1:',
              val_accuracy_1, 'train_accuracy_2:', train_accuracy_2,
              'val_accuracy_2:', val_accuracy_2, 'total_train_accuracy:',
              train_accuracy_total, 'total_val_accuracy:', val_accuracy_total)
        print(
            "epoch: [{epoch}/{total_epoch}], lr: {lr}, loss_1: {loss_1}, loss_2: {loss_2}, loss_total: {loss_total}"
            .format(epoch=epoch + 1,
                    total_epoch=config.max_epoch,
                    lr=lr,
                    loss_1=loss_meter_1.value()[0],
                    loss_2=loss_meter_2.value()[0],
                    loss_total=loss_meter_total.value()[0]))
        print('train_cm_1:\n' + str(train_cm_1.value()) + '\ntrain_cm_2:\n' +
              str(train_cm_2.value()) + '\ntrain_cm_total:\n' +
              str(train_cm_total.value()))
        print('val_cm_1:\n' + str(val_cm_1.value()) + '\nval_cm_2:\n' +
              str(val_cm_2.value()) + '\nval_cm_total:\n' +
              str(val_cm_total.value()))

        # update learning rate
        if loss_meter_total.value()[0] > previous_loss:  # the two branch losses could also be checked separately here
            lr = lr * config.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        previous_loss = loss_meter_total.value()[0]
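
# A vectorised alternative (illustrative only, not used by the script) to the per-sample
# loop inside multitask_train that merges the two binary branches into a three-class
# one-hot prediction, following the same decision rule.
import torch

def merge_two_branches(p_1, p_2):
    """p_1, p_2: (N, 2) softmax outputs of the two binary branches."""
    merged = torch.zeros(p_1.size(0), 3)
    neither = (p_1[:, 1] < 0.5) & (p_2[:, 1] < 0.5)
    branch1_wins = ~neither & (p_1[:, 1] > p_2[:, 1])
    branch2_wins = ~neither & ~branch1_wins
    merged[neither, 0] = 1
    merged[branch1_wins, 1] = 1
    merged[branch2_wins, 2] = 1
    return merged
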
def train(**kwargs):
    config.parse(kwargs)

    # ============================================ Visualization =============================================
    vis = Visualizer(port=2333, env=config.env)
    vis.log('Use config:')
    for k, v in config.__class__.__dict__.items():
        if not k.startswith('__'):
            vis.log(f"{k}: {getattr(config, k)}")

    # ============================================= Prepare Data =============================================
    train_data = SlideWindowDataset(config.train_paths,
                                    phase='train',
                                    useRGB=config.useRGB,
                                    usetrans=config.usetrans,
                                    balance=config.data_balance)
    val_data = SlideWindowDataset(config.test_paths,
                                  phase='val',
                                  useRGB=config.useRGB,
                                  usetrans=config.usetrans,
                                  balance=False)
    print('Training Images:', train_data.__len__(), 'Validation Images:',
          val_data.__len__())
    dist = train_data.dist()
    print('Train Data Distribution:', dist)

    train_dataloader = DataLoader(train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=config.num_workers)
    val_dataloader = DataLoader(val_data,
                                batch_size=config.batch_size,
                                shuffle=False,
                                num_workers=config.num_workers)

    # ============================================= Prepare Model ============================================
    model = UNet_Classifier(num_classes=config.num_classes)
    print(model)

    if config.load_model_path:
        model.load(config.load_model_path)
        print('Model loaded')
    if config.use_gpu:
        model.cuda()
    if config.parallel:
        model = torch.nn.DataParallel(
            model, device_ids=[x for x in range(config.num_of_gpu)])

    # =========================================== Criterion and Optimizer =====================================
    # weight = torch.Tensor([1, 1])
    # weight = torch.Tensor([dist['1']/(dist['0']+dist['1']), dist['0']/(dist['0']+dist['1'])])  # the weights swap the two class frequencies; with more than two classes the reciprocals can be used
    # weight = torch.Tensor([1, 3.5])
    # weight = torch.Tensor([1, 5])
    weight = torch.Tensor([1, 7])

    vis.log(f'loss weight: {weight}')
    print('loss weight:', weight)
    weight = weight.cuda()
    criterion = torch.nn.CrossEntropyLoss(weight=weight)
    lr = config.lr
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=config.weight_decay)

    # ================================================== Metrics ===============================================
    softmax = functional.softmax
    loss_meter_edge = meter.AverageValueMeter()
    epoch_loss_edge = meter.AverageValueMeter()
    loss_meter_cls = meter.AverageValueMeter()
    epoch_loss_cls = meter.AverageValueMeter()
    loss_meter = meter.AverageValueMeter()
    epoch_loss = meter.AverageValueMeter()
    train_cm = meter.ConfusionMeter(config.num_classes)

    # ====================================== Saving and Recording Configuration =================================
    previous_auc = 0
    if config.parallel:
        save_model_dir = config.save_model_dir if config.save_model_dir else model.module.model_name
        save_model_name = config.save_model_name if config.save_model_name else model.module.model_name + '_best_model.pth'
    else:
        save_model_dir = config.save_model_dir if config.save_model_dir else model.model_name
        save_model_name = config.save_model_name if config.save_model_name else model.model_name + '_best_model.pth'
    save_epoch = 1  # records the epoch of the model that performs best on the validation set
    process_record = {
        'epoch_loss': [],
        'epoch_loss_edge': [],
        'epoch_loss_cls': [],
        'train_avg_se': [],
        'train_se_0': [],
        'train_se_1': [],
        'val_avg_se': [],
        'val_se_0': [],
        'val_se_1': [],
        'AUC': [],
        'DICE': []
    }  # records the training curves so they can be plotted later

    # ================================================== Training ===============================================
    for epoch in range(config.max_epoch):
        print(
            f"epoch: [{epoch + 1}/{config.max_epoch}] {config.save_model_name[:-4]} =================================="
        )
        train_cm.reset()
        epoch_loss.reset()
        dice = []

        # ****************************************** train ****************************************
        model.train()
        for i, (image, label, edge_mask,
                image_path) in tqdm(enumerate(train_dataloader)):
            loss_meter.reset()

            # ------------------------------------ prepare input ------------------------------------
            if config.use_gpu:
                image = image.cuda()
                label = label.cuda()
                edge_mask = edge_mask.cuda()

            # ---------------------------------- go through the model --------------------------------
            score, score_mask = model(x=image)

            # ----------------------------------- backpropagate -------------------------------------
            optimizer.zero_grad()

            # classification loss
            loss_cls = criterion(score, label)
            # loss on the pixels covered by the edge mask
            log_prob_mask = functional.logsigmoid(score_mask)
            count_edge = torch.sum(edge_mask, dim=(1, 2, 3), keepdim=True)
            loss_edge = -1 * torch.mean(
                torch.sum(
                    edge_mask * log_prob_mask, dim=(1, 2, 3), keepdim=True) /
                (count_edge + 1e-8))

            # loss on the pixels outside the edge mask
            r_prob_mask = torch.Tensor([1.0
                                        ]).cuda() - torch.sigmoid(score_mask)
            r_edge_mask = torch.Tensor([1.0]).cuda() - edge_mask
            log_rprob_mask = torch.log(r_prob_mask + 1e-5)
            count_redge = torch.sum(r_edge_mask, dim=(1, 2, 3), keepdim=True)
            loss_redge = -1 * torch.mean(
                torch.sum(r_edge_mask * log_rprob_mask,
                          dim=(1, 2, 3),
                          keepdim=True) / (count_redge + 1e-8))

            # weights computed from the numbers of foreground and background pixels
            w1 = torch.sum(count_edge).item() / (torch.sum(count_edge).item() +
                                                 torch.sum(count_redge).item())
            w2 = torch.sum(count_redge).item() / (
                torch.sum(count_edge).item() + torch.sum(count_redge).item())
            loss = loss_cls + w1 * loss_edge + w2 * loss_redge
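            # NOTE (added): loss_edge and loss_redge are the two halves of a
            # per-pixel binary cross-entropy on the edge map,
            #   BCE = -[ y*log(sigmoid(s)) + (1-y)*log(1-sigmoid(s)) ],
            # each averaged only over its own pixel group (edge / non-edge).
            # Re-weighting them by w1 and w2 (each group's share of all pixels)
            # roughly recovers an ordinary per-pixel mean BCE while keeping the
            # two terms separately loggable below.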

            loss.backward()
            optimizer.step()

            # ------------------------------------ record loss ------------------------------------
            loss_meter_edge.add((w1 * loss_edge + w2 * loss_redge).item())
            epoch_loss_edge.add((w1 * loss_edge + w2 * loss_redge).item())
            loss_meter_cls.add(loss_cls.item())
            epoch_loss_cls.add(loss_cls.item())
            loss_meter.add(loss.item())
            epoch_loss.add(loss.item())
            train_cm.add(softmax(score, dim=1).detach(), label.detach())
            dice.append(
                dice_coeff(input=(score_mask > 0.5).float(),
                           target=edge_mask[:, 0, :, :]).item())
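            # NOTE (added): dice_coeff is assumed to compute 2*|X∩Y| / (|X|+|Y|)
            # for the binarised prediction X and the edge mask Y. score_mask
            # holds raw logits, so "> 0.5" corresponds to a sigmoid probability
            # of about 0.62 rather than 0.5; kept as in the original code.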

            if (i + 1) % config.print_freq == 0:
                vis.plot_many({
                    'loss': loss_meter.value()[0],
                    'loss_edge': loss_meter_edge.value()[0],
                    'loss_cls': loss_meter_cls.value()[0]
                })

        train_se = [
            100. * train_cm.value()[0][0] /
            (train_cm.value()[0][0] + train_cm.value()[0][1]),
            100. * train_cm.value()[1][1] /
            (train_cm.value()[1][0] + train_cm.value()[1][1])
        ]
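        # NOTE (added): with torchnet's ConfusionMeter, row i holds the samples
        # whose ground truth is class i, so each entry of train_se is that
        # class's sensitivity (recall) in percent: 100 * TP / (TP + FN).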
        train_dice = sum(dice) / len(dice)

        # *************************************** validate ***************************************
        model.eval()
        if (epoch + 1) % 1 == 0:
            Best_T, val_cm, val_spse, val_accuracy, AUC, val_dice = val(
                model, val_dataloader)

            # ------------------------------------ save model ------------------------------------
            if AUC > previous_auc and epoch + 1 > 5:  # after the first 5 epochs, save the model whenever the validation AUC improves
                save_dir = os.path.join('checkpoints', save_model_dir,
                                        save_model_name[:-4])
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                if config.parallel:
                    model.module.save(os.path.join(save_dir, save_model_name))
                else:
                    model.save(os.path.join(save_dir, save_model_name))
                previous_auc = AUC
                save_epoch = epoch + 1
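                # NOTE (added): checkpoints therefore land in
                # checkpoints/<save_model_dir>/<save_model_name minus ".pth">/<save_model_name>,
                # and previous_auc / save_epoch track the best validation AUC seen so far.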

            # ---------------------------------- record and print ---------------------------------
            process_record['epoch_loss'].append(epoch_loss.value()[0])
            process_record['epoch_loss_edge'].append(
                epoch_loss_edge.value()[0])
            process_record['epoch_loss_cls'].append(epoch_loss_cls.value()[0])
            process_record['train_avg_se'].append(np.average(train_se))
            process_record['train_se_0'].append(train_se[0])
            process_record['train_se_1'].append(train_se[1])
            process_record['val_avg_se'].append(np.average(val_spse))
            process_record['val_se_0'].append(val_spse[0])
            process_record['val_se_1'].append(val_spse[1])
            process_record['AUC'].append(AUC)
            process_record['DICE'].append(val_dice)

            vis.plot_many({
                'epoch_loss': epoch_loss.value()[0],
                'epoch_loss_edge': epoch_loss_edge.value()[0],
                'epoch_loss_cls': epoch_loss_cls.value()[0],
                'train_avg_se': np.average(train_se),
                'train_se_0': train_se[0],
                'train_se_1': train_se[1],
                'val_avg_se': np.average(val_spse),
                'val_se_0': val_spse[0],
                'val_se_1': val_spse[1],
                'AUC': AUC,
                'train_dice': train_dice,
                'val_dice': val_dice
            })
            vis.log(
                f"epoch: [{epoch + 1}/{config.max_epoch}] ==============================================="
            )
            vis.log(
                f"lr: {optimizer.param_groups[0]['lr']}, loss: {round(loss_meter.value()[0], 5)}"
            )
            vis.log(
                f"train_avg_se: {round(np.average(train_se), 4)}, train_se_0: {round(train_se[0], 4)}, train_se_1: {round(train_se[1], 4)}"
            )
            vis.log(f"train_dice: {round(train_dice, 4)}")
            vis.log(
                f"val_avg_se: {round(sum(val_spse) / len(val_spse), 4)}, val_se_0: {round(val_spse[0], 4)}, val_se_1: {round(val_spse[1], 4)}"
            )
            vis.log(f"val_dice: {round(val_dice, 4)}")
            vis.log(f"AUC: {AUC}")
            vis.log(f'train_cm: {train_cm.value()}')
            vis.log(f'Best Threshold: {Best_T}')
            vis.log(f'val_cm: {val_cm}')
            print("lr:", optimizer.param_groups[0]['lr'], "loss:",
                  round(epoch_loss.value()[0], 5))
            print('train_avg_se:', round(np.average(train_se), 4),
                  'train_se_0:', round(train_se[0], 4), 'train_se_1:',
                  round(train_se[1], 4))
            print('train_dice:', train_dice)
            print('val_avg_se:', round(np.average(val_spse), 4), 'val_se_0:',
                  round(val_spse[0], 4), 'val_se_1:', round(val_spse[1], 4))
            print('val_dice:', val_dice)
            print('AUC:', AUC)
            print('train_cm:')
            print(train_cm.value())
            print('Best Threshold:', Best_T, 'val_cm:')
            print(val_cm)

            # ------------------------------------ save record ------------------------------------
            if os.path.exists(
                    os.path.join('checkpoints', save_model_dir,
                                 save_model_name[:-4])):
                write_json(file=os.path.join('checkpoints', save_model_dir,
                                             save_model_name[:-4],
                                             'process_record.json'),
                           content=process_record)
        # if (epoch+1) % 20 == 0:
        #     lr = lr * config.lr_decay
        #     for param_group in optimizer.param_groups:
        #         param_group['lr'] = lr

    vis.log(f"Best Epoch: {save_epoch}")
    print("Best Epoch:", save_epoch)