Example #1
def train(opt):
    seq = iaa.Sequential([
        iaa.CropToFixedSize(opt.fineSize, opt.fineSize),
    ])
    dataset_train = ImageDataset(opt.source_root_train,
                                 opt.gt_root_train,
                                 transform=seq)
    dataset_test = ImageDataset(opt.source_root_test,
                                opt.gt_root_test,
                                transform=seq)
    dataloader_train = DataLoader(dataset_train,
                                  batch_size=opt.batchSize,
                                  shuffle=True,
                                  num_workers=opt.nThreads)
    dataloader_test = DataLoader(dataset_test,
                                 batch_size=opt.batchSize,
                                 shuffle=False,
                                 num_workers=opt.nThreads)
    model = StainNet(opt.input_nc, opt.output_nc, opt.n_layer, opt.channels)
    model = nn.DataParallel(model).cuda()
    optimizer = SGD(model.parameters(), lr=opt.lr)
    loss_function = torch.nn.L1Loss()
    lrschedulr = lr_scheduler.CosineAnnealingLR(optimizer, opt.epoch)
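    # cosine annealing: the learning rate decays from opt.lr toward 0 over opt.epoch
    # epochs (lrschedulr.step() is called once per epoch at the end of the loop below)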
    vis = Visualizer(env=opt.name)
    best_psnr = 0
    for i in range(opt.epoch):
        for j, (source_image,
                target_image) in tqdm(enumerate(dataloader_train)):
            target_image = target_image.cuda()
            source_image = source_image.cuda()
            output = model(source_image)
            loss = loss_function(output, target_image)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (j + 1) % opt.display_freq == 0:
                vis.plot("loss", float(loss))
                vis.img("target image", target_image[0] * 0.5 + 0.5)
                vis.img("source image", source_image[0] * 0.5 + 0.5)
                vis.img("output", (output[0] * 0.5 + 0.5).clamp(0, 1))
        if (i + 1) % 5 == 0:
            test_result = test(model, dataloader_test)
            vis.plot_many(test_result)
            if best_psnr < test_result["psnr"]:
                save_path = "{}/{}_best_psnr_layer{}_ch{}.pth".format(
                    opt.checkpoints_dir, opt.name, opt.n_layer, opt.channels)
                best_psnr = test_result["psnr"]
                torch.save(model.module.state_dict(), save_path)
                print(save_path, test_result)
        lrschedulr.step()
        print("lrschedulr=", lrschedulr.get_last_lr())
def multitask_train(**kwargs):
    config.parse(kwargs)
    vis = Visualizer(port=2333, env=config.env)

    # prepare data
    train_data = MultiLabel_Dataset(config.data_root,
                                    config.train_paths,
                                    phase='train',
                                    balance=config.data_balance)
    val_data = MultiLabel_Dataset(config.data_root,
                                  config.test_paths,
                                  phase='val',
                                  balance=config.data_balance)
    print('Training Images:', train_data.__len__(), 'Validation Images:',
          val_data.__len__())

    train_dataloader = DataLoader(train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=config.num_workers)
    val_dataloader = DataLoader(val_data,
                                batch_size=config.batch_size,
                                shuffle=False,
                                num_workers=config.num_workers)

    # prepare model
    model = MultiTask_DenseNet121(num_classes=2)  # each branch is a binary (2-class) classifier
    # model = CheXPre_MultiTask_DenseNet121(num_classes=2)  # each branch is a binary (2-class) classifier

    if config.load_model_path:
        model.load(config.load_model_path)
    if config.use_gpu:
        model.cuda()

    model.train()

    # criterion and optimizer
    # F, T1, T2 = 3500, 3078, 3565  # class weights: image counts of normal, osteoblastic, and osteolytic cases
    # weight_1 = torch.FloatTensor([T1/(F+T1+T2), (F+T2)/(F+T1+T2)]).cuda()  # the weight tensor must also be on the GPU
    # weight_2 = torch.FloatTensor([T2/(F+T1+T2), (F+T1)/(F+T1+T2)]).cuda()
    # criterion_1 = torch.nn.CrossEntropyLoss(weight=weight_1)
    # criterion_2 = torch.nn.CrossEntropyLoss(weight=weight_2)
    criterion_1 = torch.nn.CrossEntropyLoss()
    criterion_2 = torch.nn.CrossEntropyLoss()
    lr = config.lr
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=config.weight_decay)

    # metrics
    softmax = functional.softmax
    loss_meter_1 = meter.AverageValueMeter()
    loss_meter_2 = meter.AverageValueMeter()
    loss_meter_total = meter.AverageValueMeter()
    train_cm_1 = meter.ConfusionMeter(2)  # each branch is binary; the combined task is 3-class
    train_cm_2 = meter.ConfusionMeter(2)
    train_cm_total = meter.ConfusionMeter(3)
    previous_loss = 100
    previous_acc = 0

    # train
    if not os.path.exists(os.path.join('checkpoints', model.model_name)):
        os.mkdir(os.path.join('checkpoints', model.model_name))

    for epoch in range(config.max_epoch):
        loss_meter_1.reset()
        loss_meter_2.reset()
        loss_meter_total.reset()
        train_cm_1.reset()
        train_cm_2.reset()
        train_cm_total.reset()

        # train
        for i, (image, label_1, label_2, label,
                image_path) in tqdm(enumerate(train_dataloader)):
            # prepare input
            img = Variable(image)
            target_1 = Variable(label_1)
            target_2 = Variable(label_2)
            target = Variable(label)
            if config.use_gpu:
                img = img.cuda()
                target_1 = target_1.cuda()
                target_2 = target_2.cuda()
                target = target.cuda()

            # go through the model
            score_1, score_2 = model(img)

            # backpropagate
            optimizer.zero_grad()
            loss_1 = criterion_1(score_1, target_1)
            loss_2 = criterion_2(score_2, target_2)
            loss = loss_1 + loss_2
            # loss.backward()
            # optimizer.step()
            # Backpropagating the two losses separately (rather than their sum) works
            # better here, possibly because the momentum update is then applied twice,
            # which makes it easier to escape local optima.
            loss_1.backward(retain_graph=True)
            optimizer.step()
            loss_2.backward()
            optimizer.step()

            # calculate loss and confusion matrix
            loss_meter_1.add(loss_1.item())
            loss_meter_2.add(loss_2.item())
            loss_meter_total.add(loss.item())

            p_1, p_2 = softmax(score_1, dim=1), softmax(score_2, dim=1)
            c = []

            # -----------------------------------------------------------------------
            # merge the two branch outputs into the final 3-class prediction
            for j in range(p_1.data.size()[0]):
                if p_1.data[j][1] < 0.5 and p_2.data[j][1] < 0.5:
                    c.append([1, 0, 0])
                else:
                    if p_1.data[j][1] > p_2.data[j][1]:
                        c.append([0, 1, 0])
                    else:
                        c.append([0, 0, 1])
            # -----------------------------------------------------------------------
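            # A commented-out vectorized sketch equivalent to the loop above
            # (same 0.5 threshold and the same tie-breaking rule):
            # pred = torch.where(
            #     (p_1.data[:, 1] < 0.5) & (p_2.data[:, 1] < 0.5),
            #     torch.zeros(p_1.size(0), dtype=torch.long, device=p_1.device),
            #     1 + (p_2.data[:, 1] >= p_1.data[:, 1]).long())
            # c = torch.nn.functional.one_hot(pred, num_classes=3).float()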

            train_cm_1.add(p_1.data, target_1.data)
            train_cm_2.add(p_2.data, target_2.data)
            train_cm_total.add(torch.FloatTensor(c), target.data)

            if i % config.print_freq == config.print_freq - 1:
                vis.plot_many({
                    'loss_1': loss_meter_1.value()[0],
                    'loss_2': loss_meter_2.value()[0],
                    'loss_total': loss_meter_total.value()[0]
                })
                print('loss_1:',
                      loss_meter_1.value()[0], 'loss_2:',
                      loss_meter_2.value()[0], 'loss_total:',
                      loss_meter_total.value()[0])

        # print result
        train_accuracy_1 = 100. * sum(
            [train_cm_1.value()[c][c]
             for c in range(2)]) / train_cm_1.value().sum()
        train_accuracy_2 = 100. * sum(
            [train_cm_2.value()[c][c]
             for c in range(2)]) / train_cm_2.value().sum()
        train_accuracy_total = 100. * sum(
            [train_cm_total.value()[c][c]
             for c in range(3)]) / train_cm_total.value().sum()

        val_cm_1, val_accuracy_1, val_loss_1, val_cm_2, val_accuracy_2, val_loss_2, val_cm_total, val_accuracy_total, val_loss_total = multitask_val(
            model, val_dataloader)

        if val_accuracy_total > previous_acc:
            if config.save_model_name:
                model.save(
                    os.path.join('checkpoints', model.model_name,
                                 config.save_model_name))
            else:
                model.save(
                    os.path.join('checkpoints', model.model_name,
                                 model.model_name + '_best_model.pth'))
            previous_acc = val_accuracy_total

        vis.plot_many({
            'train_accuracy_1': train_accuracy_1,
            'val_accuracy_1': val_accuracy_1,
            'train_accuracy_2': train_accuracy_2,
            'val_accuracy_2': val_accuracy_2,
            'total_train_accuracy': train_accuracy_total,
            'total_val_accuracy': val_accuracy_total
        })
        vis.log(
            "epoch: [{epoch}/{total_epoch}], lr: {lr}, loss_1: {loss_1}, loss_2: {loss_2}, loss_total: {loss_total}"
            .format(epoch=epoch + 1,
                    total_epoch=config.max_epoch,
                    lr=lr,
                    loss_1=loss_meter_1.value()[0],
                    loss_2=loss_meter_2.value()[0],
                    loss_total=loss_meter_total.value()[0]))
        vis.log('train_cm_1:' + str(train_cm_1.value()) + ' train_cm_2:' +
                str(train_cm_2.value()) + ' train_cm_total:' +
                str(train_cm_total.value()))
        vis.log('val_cm_1:' + str(val_cm_1.value()) + ' val_cm_2:' +
                str(val_cm_2.value()) + ' val_cm_total:' +
                str(val_cm_total.value()))

        print('train_accuracy_1:', train_accuracy_1, 'val_accuracy_1:',
              val_accuracy_1, 'train_accuracy_2:', train_accuracy_2,
              'val_accuracy_2:', val_accuracy_2, 'total_train_accuracy:',
              train_accuracy_total, 'total_val_accuracy:', val_accuracy_total)
        print(
            "epoch: [{epoch}/{total_epoch}], lr: {lr}, loss_1: {loss_1}, loss_2: {loss_2}, loss_total: {loss_total}"
            .format(epoch=epoch + 1,
                    total_epoch=config.max_epoch,
                    lr=lr,
                    loss_1=loss_meter_1.value()[0],
                    loss_2=loss_meter_2.value()[0],
                    loss_total=loss_meter_total.value()[0]))
        print('train_cm_1:\n' + str(train_cm_1.value()) + '\ntrain_cm_2:\n' +
              str(train_cm_2.value()) + '\ntrain_cm_total:\n' +
              str(train_cm_total.value()))
        print('val_cm_1:\n' + str(val_cm_1.value()) + '\nval_cm_2:\n' +
              str(val_cm_2.value()) + '\nval_cm_total:\n' +
              str(val_cm_total.value()))

        # update learning rate
        if loss_meter_total.value()[0] > previous_loss:  # the two branch losses could also be checked separately
            lr = lr * config.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        previous_loss = loss_meter_total.value()[0]
def train(**kwargs):
    config.parse(kwargs)
    vis = Visualizer(port=2333, env=config.env)

    # prepare data
    train_data = Vertebrae_Dataset(config.data_root,
                                   config.train_paths,
                                   phase='train',
                                   balance=config.data_balance)
    val_data = Vertebrae_Dataset(config.data_root,
                                 config.test_paths,
                                 phase='val',
                                 balance=config.data_balance)
    # train_data = FrameDiff_Dataset(config.data_root, config.train_paths, phase='train', balance=config.data_balance)
    # val_data = FrameDiff_Dataset(config.data_root, config.test_paths, phase='val', balance=config.data_balance)
    print('Training Images:', train_data.__len__(), 'Validation Images:',
          val_data.__len__())

    train_dataloader = DataLoader(train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=config.num_workers)
    val_dataloader = DataLoader(val_data,
                                batch_size=config.batch_size,
                                shuffle=False,
                                num_workers=config.num_workers)

    # prepare model
    # model = ResNet34(num_classes=config.num_classes)
    # model = DenseNet121(num_classes=config.num_classes)
    # model = CheXPre_DenseNet121(num_classes=config.num_classes)
    # model = MultiResDenseNet121(num_classes=config.num_classes)
    # model = Vgg19(num_classes=config.num_classes)
    model = MultiResVgg19(num_classes=config.num_classes)

    if config.load_model_path:
        model.load(config.load_model_path)
    if config.use_gpu:
        model.cuda()
    if config.parallel:
        model = torch.nn.DataParallel(
            model, device_ids=[x for x in range(config.num_of_gpu)])

    model.train()

    # criterion and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    lr = config.lr
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=config.weight_decay)

    # metric
    softmax = functional.softmax
    loss_meter = meter.AverageValueMeter()
    train_cm = meter.ConfusionMeter(config.num_classes)
    previous_loss = 100
    previous_acc = 0

    # train
    if config.parallel:
        if not os.path.exists(
                os.path.join('checkpoints', model.module.model_name)):
            os.mkdir(os.path.join('checkpoints', model.module.model_name))
    else:
        if not os.path.exists(os.path.join('checkpoints', model.model_name)):
            os.mkdir(os.path.join('checkpoints', model.model_name))

    for epoch in range(config.max_epoch):
        loss_meter.reset()
        train_cm.reset()

        # train
        for i, (image, label, image_path) in tqdm(enumerate(train_dataloader)):
            # prepare input
            img = Variable(image)
            target = Variable(label)
            if config.use_gpu:
                img = img.cuda()
                target = target.cuda()

            # go through the model
            score = model(img)

            # backpropagate
            optimizer.zero_grad()
            loss = criterion(score, target)
            loss.backward()
            optimizer.step()

            loss_meter.add(loss.item())
            train_cm.add(softmax(score, dim=1).data, target.data)

            if i % config.print_freq == config.print_freq - 1:
                vis.plot('loss', loss_meter.value()[0])
                print('loss', loss_meter.value()[0])

        # print result
        train_accuracy = 100. * sum(
            [train_cm.value()[c][c]
             for c in range(config.num_classes)]) / train_cm.value().sum()
        val_cm, val_accuracy, val_loss = val(model, val_dataloader)

        if val_accuracy > previous_acc:
            if config.parallel:
                if config.save_model_name:
                    model.save(
                        os.path.join('checkpoints', model.module.model_name,
                                     config.save_model_name))
                else:
                    model.save(
                        os.path.join(
                            'checkpoints', model.module.model_name,
                            model.module.model_name + '_best_model.pth'))
            else:
                if config.save_model_name:
                    model.save(
                        os.path.join('checkpoints', model.model_name,
                                     config.save_model_name))
                else:
                    model.save(
                        os.path.join('checkpoints', model.model_name,
                                     model.model_name + '_best_model.pth'))
            previous_acc = val_accuracy

        vis.plot_many({
            'train_accuracy': train_accuracy,
            'val_accuracy': val_accuracy
        })
        vis.log(
            "epoch: [{epoch}/{total_epoch}], lr: {lr}, loss: {loss}".format(
                epoch=epoch + 1,
                total_epoch=config.max_epoch,
                lr=lr,
                loss=loss_meter.value()[0]))
        vis.log('train_cm:')
        vis.log(train_cm.value())
        vis.log('val_cm')
        vis.log(val_cm.value())
        print('train_accuracy:', train_accuracy, 'val_accuracy:', val_accuracy)
        print("epoch: [{epoch}/{total_epoch}], lr: {lr}, loss: {loss}".format(
            epoch=epoch + 1,
            total_epoch=config.max_epoch,
            lr=lr,
            loss=loss_meter.value()[0]))
        print('train_cm:')
        print(train_cm.value())
        print('val_cm:')
        print(val_cm.value())

        # update learning rate
        if loss_meter.value()[0] > previous_loss:
            lr = lr * config.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        previous_loss = loss_meter.value()[0]
Example #4
def train(**kwargs):
    config.parse(kwargs)
    vis = Visualizer(port=2333, env=config.env)

    train_roots = [
        os.path.join(config.data_root, 'Features_Normal'),
        os.path.join(config.data_root, 'Features_Horizontal'),
        os.path.join(config.data_root, 'Features_Vertical'),
        os.path.join(config.data_root, 'Features_Horizontal_Vertical')
    ]
    val_roots = [os.path.join(config.data_root, 'Features')]

    train_data = Feature_Dataset(train_roots,
                                 config.train_paths,
                                 phase='train',
                                 balance=config.data_balance)
    val_data = Feature_Dataset(val_roots,
                               config.test_paths,
                               phase='val',
                               balance=config.data_balance)
    print('Training Feature Lists:', train_data.__len__(),
          'Validation Feature Lists:', val_data.__len__())

    train_dataloader = DataLoader(train_data,
                                  batch_size=1,
                                  shuffle=True,
                                  num_workers=config.num_workers)
    val_dataloader = DataLoader(val_data,
                                batch_size=1,
                                shuffle=False,
                                num_workers=config.num_workers)

    # prepare model
    model = BiLSTM_CRF(tag_to_ix=tag_to_ix,
                       embedding_dim=EMBEDDING_DIM,
                       hidden_dim=HIDDEN_DIM,
                       num_layers=NUM_LAYERS)

    if config.load_model_path:
        model.load(config.load_model_path)
    if config.use_gpu:
        model.cuda()

    model.train()

    # criterion and optimizer
    lr = config.lr
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    # metric
    loss_meter = meter.AverageValueMeter()
    previous_loss = 100000
    previous_acc = 0

    # train
    if not os.path.exists(os.path.join('checkpoints', model.model_name)):
        os.mkdir(os.path.join('checkpoints', model.model_name))

    for epoch in range(config.max_epoch):
        loss_meter.reset()
        train_cm = [[0] * 3, [0] * 3, [0] * 3]
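        # manual 3x3 confusion matrix, indexed as train_cm[true_tag][predicted_tag]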
        count = 0

        # train
        for i, (features, labels,
                feature_paths) in tqdm(enumerate(train_dataloader)):
            # prepare input
            target = torch.LongTensor([tag_to_ix[t[0]] for t in labels])

            feat = Variable(features.squeeze())
            # target = Variable(target)
            if config.use_gpu:
                feat = feat.cuda()
                # target = target.cuda()

            model.zero_grad()

            try:
                neg_log_likelihood = model.neg_log_likelihood(feat, target)
            except NameError:
                count += 1
                continue

            neg_log_likelihood.backward()
            optimizer.step()

            loss_meter.add(neg_log_likelihood.item())
            result = model(feat)
            for t, r in zip(target, result[1]):
                train_cm[t][r] += 1

            if i % config.print_freq == config.print_freq - 1:
                vis.plot('loss', loss_meter.value()[0])
                print('loss', loss_meter.value()[0])

        train_accuracy = 100. * sum(
            [train_cm[c][c]
             for c in range(config.num_classes)]) / np.sum(train_cm)
        val_cm, val_accuracy, val_loss = val(model, val_dataloader)

        if val_accuracy > previous_acc:
            if config.save_model_name:
                model.save(
                    os.path.join('checkpoints', model.model_name,
                                 config.save_model_name))
            else:
                model.save(
                    os.path.join('checkpoints', model.model_name,
                                 model.model_name + '_best_model.pth'))
            previous_acc = val_accuracy

        vis.plot_many({
            'train_accuracy': train_accuracy,
            'val_accuracy': val_accuracy
        })
        vis.log(
            "epoch: [{epoch}/{total_epoch}], lr: {lr}, loss: {loss}".format(
                epoch=epoch + 1,
                total_epoch=config.max_epoch,
                lr=lr,
                loss=loss_meter.value()[0]))
        vis.log('train_cm:')
        vis.log(train_cm)
        vis.log('val_cm')
        vis.log(val_cm)
        print('train_accuracy:', train_accuracy, 'val_accuracy:', val_accuracy)
        print("epoch: [{epoch}/{total_epoch}], lr: {lr}, loss: {loss}".format(
            epoch=epoch + 1,
            total_epoch=config.max_epoch,
            lr=lr,
            loss=loss_meter.value()[0]))
        print('train_cm:')
        print(train_cm)
        print('val_cm:')
        print(val_cm)
        print('Num of NameError:', count)

        # update learning rate
        if loss_meter.value()[0] > previous_loss:
            lr = lr * config.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        previous_loss = loss_meter.value()[0]
Example #5
def train(**kwargs):
    opt.parse(kwargs)

    if opt.vis_env:
        vis = Visualizer(opt.vis_env, port=opt.vis_port)

    if opt.device is None or opt.device == 'cpu':
        opt.device = torch.device('cpu')
    else:
        opt.device = torch.device(opt.device)

    images, tags, labels = load_data(opt.data_path, type=opt.dataset)

    train_data = Dataset(opt, images, tags, labels)
    train_dataloader = DataLoader(train_data,
                                  batch_size=opt.batch_size,
                                  shuffle=True)

    # valid or test data
    x_query_data = Dataset(opt, images, tags, labels, test='image.query')
    x_db_data = Dataset(opt, images, tags, labels, test='image.db')
    y_query_data = Dataset(opt, images, tags, labels, test='text.query')
    y_db_data = Dataset(opt, images, tags, labels, test='text.db')

    x_query_dataloader = DataLoader(x_query_data,
                                    opt.batch_size,
                                    shuffle=False)
    x_db_dataloader = DataLoader(x_db_data, opt.batch_size, shuffle=False)
    y_query_dataloader = DataLoader(y_query_data,
                                    opt.batch_size,
                                    shuffle=False)
    y_db_dataloader = DataLoader(y_db_data, opt.batch_size, shuffle=False)

    query_labels, db_labels = x_query_data.get_labels()
    query_labels = query_labels.to(opt.device)
    db_labels = db_labels.to(opt.device)

    if opt.load_model_path:
        pretrain_model = None
    elif opt.pretrain_model_path:
        pretrain_model = load_pretrain_model(opt.pretrain_model_path)
    else:
        pretrain_model = None

    model = AGAH(opt.bit,
                 opt.tag_dim,
                 opt.num_label,
                 opt.emb_dim,
                 lambd=opt.lambd,
                 pretrain_model=pretrain_model).to(opt.device)

    load_model(model, opt.load_model_path)

    optimizer = Adamax([{
        'params': model.img_module.parameters(),
        'lr': opt.lr
    }, {
        'params': model.txt_module.parameters()
    }, {
        'params': model.hash_module.parameters()
    }, {
        'params': model.classifier.parameters()
    }],
                       lr=opt.lr * 10,
                       weight_decay=0.0005)

    optimizer_dis = {
        'img':
        Adamax(model.img_discriminator.parameters(),
               lr=opt.lr * 10,
               betas=(0.5, 0.9),
               weight_decay=0.0001),
        'txt':
        Adamax(model.txt_discriminator.parameters(),
               lr=opt.lr * 10,
               betas=(0.5, 0.9),
               weight_decay=0.0001)
    }

    criterion_tri_cos = TripletAllLoss(dis_metric='cos', reduction='sum')
    criterion_bce = nn.BCELoss(reduction='sum')

    loss = []

    max_mapi2t = 0.
    max_mapt2i = 0.

    FEATURE_I = torch.randn(opt.training_size, opt.emb_dim).to(opt.device)
    FEATURE_T = torch.randn(opt.training_size, opt.emb_dim).to(opt.device)

    U = torch.randn(opt.training_size, opt.bit).to(opt.device)
    V = torch.randn(opt.training_size, opt.bit).to(opt.device)

    FEATURE_MAP = torch.randn(opt.num_label, opt.emb_dim).to(opt.device)
    CODE_MAP = torch.sign(torch.randn(opt.num_label, opt.bit)).to(opt.device)

    train_labels = train_data.get_labels().to(opt.device)

    mapt2i_list = []
    mapi2t_list = []
    train_times = []

    for epoch in range(opt.max_epoch):
        t1 = time.time()
        for i, (ind, x, y, l) in tqdm(enumerate(train_dataloader)):
            imgs = x.to(opt.device)
            tags = y.to(opt.device)
            labels = l.to(opt.device)

            batch_size = len(ind)

            h_x, h_y, f_x, f_y, x_class, y_class = model(
                imgs, tags, FEATURE_MAP)

            FEATURE_I[ind] = f_x.data
            FEATURE_T[ind] = f_y.data
            U[ind] = h_x.data
            V[ind] = h_y.data

            #####
            # train txt discriminator
            #####
            D_txt_real = model.dis_txt(f_y.detach())
            D_txt_real = -D_txt_real.mean()
            optimizer_dis['txt'].zero_grad()
            D_txt_real.backward()

            # train with fake
            D_txt_fake = model.dis_txt(f_x.detach())
            D_txt_fake = D_txt_fake.mean()
            D_txt_fake.backward()

            # train with gradient penalty
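            # (WGAN-GP style penalty: evaluate the discriminator on random interpolations
            #  between real and fake features and push its gradient norm there toward 1)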
            alpha = torch.rand(batch_size, opt.emb_dim).to(opt.device)
            interpolates = alpha * f_y.detach() + (1 - alpha) * f_x.detach()
            interpolates.requires_grad_()
            disc_interpolates = model.dis_txt(interpolates)
            gradients = autograd.grad(outputs=disc_interpolates,
                                      inputs=interpolates,
                                      grad_outputs=torch.ones(
                                          disc_interpolates.size()).to(
                                              opt.device),
                                      create_graph=True,
                                      retain_graph=True,
                                      only_inputs=True)[0]
            gradients = gradients.view(gradients.size(0), -1)
            # 10 is gradient penalty hyperparameter
            txt_gradient_penalty = (
                (gradients.norm(2, dim=1) - 1)**2).mean() * 10
            txt_gradient_penalty.backward()

            loss_D_txt = D_txt_real - D_txt_fake
            optimizer_dis['txt'].step()

            #####
            # train img discriminator
            #####
            D_img_real = model.dis_img(f_x.detach())
            D_img_real = -D_img_real.mean()
            optimizer_dis['img'].zero_grad()
            D_img_real.backward()

            # train with fake
            D_img_fake = model.dis_img(f_y.detach())
            D_img_fake = D_img_fake.mean()
            D_img_fake.backward()

            # train with gradient penalty
            alpha = torch.rand(batch_size, opt.emb_dim).to(opt.device)
            interpolates = alpha * f_x.detach() + (1 - alpha) * f_y.detach()
            interpolates.requires_grad_()
            disc_interpolates = model.dis_img(interpolates)
            gradients = autograd.grad(outputs=disc_interpolates,
                                      inputs=interpolates,
                                      grad_outputs=torch.ones(
                                          disc_interpolates.size()).to(
                                              opt.device),
                                      create_graph=True,
                                      retain_graph=True,
                                      only_inputs=True)[0]
            gradients = gradients.view(gradients.size(0), -1)
            # 10 is gradient penalty hyperparameter
            img_gradient_penalty = (
                (gradients.norm(2, dim=1) - 1)**2).mean() * 10
            img_gradient_penalty.backward()

            loss_D_img = D_img_real - D_img_fake
            optimizer_dis['img'].step()

            #####
            # train generators
            #####
            # update img network (to generate txt features)
            domain_output = model.dis_txt(f_x)
            loss_G_txt = -domain_output.mean()

            # update txt network (to generate img features)
            domain_output = model.dis_img(f_y)
            loss_G_img = -domain_output.mean()

            loss_adver = loss_G_txt + loss_G_img

            loss1 = criterion_tri_cos(h_x,
                                      labels,
                                      target=h_y,
                                      margin=opt.margin)
            loss2 = criterion_tri_cos(h_y,
                                      labels,
                                      target=h_x,
                                      margin=opt.margin)

            theta1 = F.cosine_similarity(torch.abs(h_x),
                                         torch.ones_like(h_x).to(opt.device))
            theta2 = F.cosine_similarity(torch.abs(h_y),
                                         torch.ones_like(h_y).to(opt.device))
            loss3 = torch.sum(1 / (1 + torch.exp(theta1))) + torch.sum(
                1 / (1 + torch.exp(theta2)))

            loss_class = criterion_bce(x_class, labels) + criterion_bce(
                y_class, labels)

            theta_code_x = h_x.mm(CODE_MAP.t())  # size: (batch, num_label)
            theta_code_y = h_y.mm(CODE_MAP.t())
            loss_code_map = torch.sum(torch.pow(theta_code_x - opt.bit * (labels * 2 - 1), 2)) + \
                            torch.sum(torch.pow(theta_code_y - opt.bit * (labels * 2 - 1), 2))

            loss_quant = torch.sum(torch.pow(
                h_x - torch.sign(h_x), 2)) + torch.sum(
                    torch.pow(h_y - torch.sign(h_y), 2))
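            # loss_quant above is a quantization loss: it pulls the continuous hash
            # outputs h_x, h_y toward their binarized codes sign(h)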

            # err = loss1 + loss2 + loss3 + 0.5 * loss_class + 0.5 * (loss_f1 + loss_f2)
            err = loss1 + loss2 + opt.alpha * loss3 + opt.beta * loss_class + opt.gamma * loss_code_map + \
                  opt.eta * loss_quant + opt.mu * loss_adver

            optimizer.zero_grad()
            err.backward()
            optimizer.step()

            loss.append(err.item())

        CODE_MAP = update_code_map(U, V, CODE_MAP, train_labels)
        FEATURE_MAP = update_feature_map(FEATURE_I, FEATURE_T, train_labels)
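        # U/V and FEATURE_I/FEATURE_T cache the hash codes and features of every training
        # sample (updated per batch above); the label-wise code map and feature map are
        # refreshed from these caches once per epoch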

        print('...epoch: %3d, loss: %3.3f' % (epoch + 1, loss[-1]))
        delta_t = time.time() - t1

        if opt.vis_env:
            vis.plot('loss', loss[-1])

        # validate
        if opt.valid and (epoch + 1) % opt.valid_freq == 0:
            mapi2t, mapt2i = valid(model, x_query_dataloader, x_db_dataloader,
                                   y_query_dataloader, y_db_dataloader,
                                   query_labels, db_labels, FEATURE_MAP)
            print(
                '...epoch: %3d, valid MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f'
                % (epoch + 1, mapi2t, mapt2i))

            mapi2t_list.append(mapi2t)
            mapt2i_list.append(mapt2i)
            train_times.append(delta_t)

            if opt.vis_env:
                d = {'mapi2t': mapi2t, 'mapt2i': mapt2i}
                vis.plot_many(d)

            if mapt2i >= max_mapt2i and mapi2t >= max_mapi2t:
                max_mapi2t = mapi2t
                max_mapt2i = mapt2i
                save_model(model)
                path = 'checkpoints/' + opt.dataset + '_' + str(opt.bit)
                with torch.cuda.device(opt.device):
                    torch.save(FEATURE_MAP,
                               os.path.join(path, 'feature_map.pth'))

        if epoch % 100 == 0:
            for params in optimizer.param_groups:
                params['lr'] = max(params['lr'] * 0.6, 1e-6)

    if not opt.valid:
        save_model(model)

    print('...training procedure finished')
    if opt.valid:
        print('   max MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' %
              (max_mapi2t, max_mapt2i))
    else:
        mapi2t, mapt2i = valid(model, x_query_dataloader, x_db_dataloader,
                               y_query_dataloader, y_db_dataloader,
                               query_labels, db_labels, FEATURE_MAP)
        print('   max MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' %
              (mapi2t, mapt2i))

    path = 'checkpoints/' + opt.dataset + '_' + str(opt.bit)
    with open(os.path.join(path, 'result.pkl'), 'wb') as f:
        pickle.dump([train_times, mapi2t_list, mapt2i_list], f)
Example #6
def train(**kwargs):
    config.parse(kwargs)

    # ============================================ Visualization =============================================
    vis = Visualizer(port=2333, env=config.env)
    vis.log('Use config:')
    for k, v in config.__class__.__dict__.items():
        if not k.startswith('__'):
            vis.log(f"{k}: {getattr(config, k)}")

    # ============================================= Prepare Data =============================================
    train_data = SlideWindowDataset(config.train_paths,
                                    phase='train',
                                    useRGB=config.useRGB,
                                    usetrans=config.usetrans,
                                    balance=config.data_balance)
    val_data = SlideWindowDataset(config.test_paths,
                                  phase='val',
                                  useRGB=config.useRGB,
                                  usetrans=config.usetrans,
                                  balance=False)
    print('Training Images:', train_data.__len__(), 'Validation Images:',
          val_data.__len__())
    dist = train_data.dist()
    print('Train Data Distribution:', dist)

    train_dataloader = DataLoader(train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=config.num_workers)
    val_dataloader = DataLoader(val_data,
                                batch_size=config.batch_size,
                                shuffle=False,
                                num_workers=config.num_workers)

    # ============================================= Prepare Model ============================================
    # model = AlexNet(num_classes=config.num_classes)
    # model = Vgg16(num_classes=config.num_classes)
    # model = Modified_Vgg16(num_classes=config.num_classes)
    # model = ResNet18(num_classes=config.num_classes)
    model = ResNet50(num_classes=config.num_classes)
    # model = DenseNet121(num_classes=config.num_classes)
    # model = ShallowNet(num_classes=config.num_classes)
    # model = Customed_ShallowNet(num_classes=config.num_classes)

    # model = Modified_AGVgg16(num_classes=2)
    # model = AGResNet18(num_classes=2)
    print(model)

    if config.load_model_path:
        model.load(config.load_model_path)
    if config.use_gpu:
        model.cuda()
    if config.parallel:
        model = torch.nn.DataParallel(
            model, device_ids=[x for x in range(config.num_of_gpu)])

    # =========================================== Criterion and Optimizer =====================================
    # weight = torch.Tensor([1, 1])
    # weight = torch.Tensor([dist['1']/(dist['0']+dist['1']), dist['0']/(dist['0']+dist['1'])])  # the two class weights are deliberately swapped; for more than two classes, reciprocals of the class frequencies can be used
    # weight = torch.Tensor([1, 3.5])
    # weight = torch.Tensor([1, 5])
    weight = torch.Tensor([1, 7])

    vis.log(f'loss weight: {weight}')
    print('loss weight:', weight)
    weight = weight.cuda()
    criterion = torch.nn.CrossEntropyLoss(weight=weight)
    lr = config.lr
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=config.weight_decay)

    # ================================================== Metrics ===============================================
    softmax = functional.softmax
    loss_meter = meter.AverageValueMeter()
    epoch_loss = meter.AverageValueMeter()
    train_cm = meter.ConfusionMeter(config.num_classes)

    # ====================================== Saving and Recording Configuration =================================
    previous_auc = 0
    if config.parallel:
        save_model_dir = config.save_model_dir if config.save_model_dir else model.module.model_name
        save_model_name = config.save_model_name if config.save_model_name else model.module.model_name + '_best_model.pth'
    else:
        save_model_dir = config.save_model_dir if config.save_model_dir else model.model_name
        save_model_name = config.save_model_name if config.save_model_name else model.model_name + '_best_model.pth'
    save_epoch = 1  # records the epoch of the best model on the validation set
    process_record = {
        'epoch_loss': [],
        'train_avg_se': [],
        'train_se_0': [],
        'train_se_1': [],
        'val_avg_se': [],
        'val_se_0': [],
        'val_se_1': [],
        'AUC': []
    }  # records the training curves for later plotting

    # ================================================== Training ===============================================
    for epoch in range(config.max_epoch):
        print(
            f"epoch: [{epoch+1}/{config.max_epoch}] {config.save_model_name[:-4]} =================================="
        )
        train_cm.reset()
        epoch_loss.reset()

        # ****************************************** train ****************************************
        model.train()
        for i, (image, label, image_path) in tqdm(enumerate(train_dataloader)):
            loss_meter.reset()

            # ------------------------------------ prepare input ------------------------------------
            if config.use_gpu:
                image = image.cuda()
                label = label.cuda()

            # ---------------------------------- go through the model --------------------------------
            score = model(image)

            # ----------------------------------- backpropagate -------------------------------------
            optimizer.zero_grad()
            loss = criterion(score, label)
            loss.backward()
            optimizer.step()

            # ------------------------------------ record loss ------------------------------------
            loss_meter.add(loss.item())
            epoch_loss.add(loss.item())
            train_cm.add(softmax(score, dim=1).detach(), label.detach())

            if (i + 1) % config.print_freq == 0:
                vis.plot('loss', loss_meter.value()[0])

        train_se = [
            100. * train_cm.value()[0][0] /
            (train_cm.value()[0][0] + train_cm.value()[0][1]),
            100. * train_cm.value()[1][1] /
            (train_cm.value()[1][0] + train_cm.value()[1][1])
        ]
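        # train_se holds the per-class sensitivity (recall), read from the rows of the
        # 2x2 confusion matrix (rows are ground-truth classes)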

        # *************************************** validate ***************************************
        model.eval()
        if (epoch + 1) % 1 == 0:
            Best_T, val_cm, val_spse, val_accuracy, AUC = val(
                model, val_dataloader)

            # ------------------------------------ save model ------------------------------------
            if AUC > previous_auc and epoch + 1 > 5:
                if config.parallel:
                    if not os.path.exists(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name.split('.')[0])):
                        os.makedirs(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name.split('.')[0]))
                    model.module.save(
                        os.path.join('checkpoints', save_model_dir,
                                     save_model_name.split('.')[0],
                                     save_model_name))
                else:
                    if not os.path.exists(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name.split('.')[0])):
                        os.makedirs(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name.split('.')[0]))
                    model.save(
                        os.path.join('checkpoints', save_model_dir,
                                     save_model_name.split('.')[0],
                                     save_model_name))
                previous_auc = AUC
                save_epoch = epoch + 1

            # ---------------------------------- record and print ---------------------------------
            process_record['epoch_loss'].append(epoch_loss.value()[0])
            process_record['train_avg_se'].append(np.average(train_se))
            process_record['train_se_0'].append(train_se[0])
            process_record['train_se_1'].append(train_se[1])
            process_record['val_avg_se'].append(np.average(val_spse))
            process_record['val_se_0'].append(val_spse[0])
            process_record['val_se_1'].append(val_spse[1])
            process_record['AUC'].append(AUC)

            vis.plot_many({
                'epoch_loss': epoch_loss.value()[0],
                'train_avg_se': np.average(train_se),
                'train_se_0': train_se[0],
                'train_se_1': train_se[1],
                'val_avg_se': np.average(val_spse),
                'val_se_0': val_spse[0],
                'val_se_1': val_spse[1],
                'AUC': AUC
            })
            vis.log(
                f"epoch: [{epoch+1}/{config.max_epoch}] ========================================="
            )
            vis.log(
                f"lr: {optimizer.param_groups[0]['lr']}, loss: {round(loss_meter.value()[0], 5)}"
            )
            vis.log(
                f"train_avg_se: {round(np.average(train_se), 4)}, train_se_0: {round(train_se[0], 4)}, train_se_1: {round(train_se[1], 4)}"
            )
            vis.log(
                f"val_avg_se: {round(sum(val_spse)/len(val_spse), 4)}, val_se_0: {round(val_spse[0], 4)}, val_se_1: {round(val_spse[1], 4)}"
            )
            vis.log(f"AUC: {AUC}")
            vis.log(f'train_cm: {train_cm.value()}')
            vis.log(f'Best Threshold: {Best_T}')
            vis.log(f'val_cm: {val_cm}')
            print("lr:", optimizer.param_groups[0]['lr'], "loss:",
                  round(epoch_loss.value()[0], 5))
            print('train_avg_se:', round(np.average(train_se), 4),
                  'train_se_0:', round(train_se[0], 4), 'train_se_1:',
                  round(train_se[1], 4))
            print('val_avg_se:', round(np.average(val_spse), 4), 'val_se_0:',
                  round(val_spse[0], 4), 'val_se_1:', round(val_spse[1], 4))
            print('AUC:', AUC)
            print('train_cm:')
            print(train_cm.value())
            print('Best Threshold:', Best_T, 'val_cm:')
            print(val_cm)

            # ------------------------------------ save record ------------------------------------
            if os.path.exists(
                    os.path.join('checkpoints', save_model_dir,
                                 save_model_name.split('.')[0])):
                write_json(file=os.path.join('checkpoints', save_model_dir,
                                             save_model_name.split('.')[0],
                                             'process_record.json'),
                           content=process_record)

        # if (epoch+1) % 5 == 0:
        #     lr = lr * config.lr_decay
        #     for param_group in optimizer.param_groups:
        #         param_group['lr'] = lr

    vis.log(f"Best Epoch: {save_epoch}")
    print("Best Epoch:", save_epoch)
Example #7
def train(**kwargs):
    config.parse(kwargs)
    vis = Visualizer(port=2333, env=config.env)
    vis.log('Use config:')
    for k, v in config.__class__.__dict__.items():
        if not k.startswith('__'):
            vis.log(f"{k}: {getattr(config, k)}")

    # prepare data
    train_data = VB_Dataset(config.train_paths,
                            phase='train',
                            useRGB=config.useRGB,
                            usetrans=config.usetrans,
                            padding=config.padding,
                            balance=config.data_balance)
    val_data = VB_Dataset(config.test_paths,
                          phase='val',
                          useRGB=config.useRGB,
                          usetrans=config.usetrans,
                          padding=config.padding,
                          balance=False)
    print('Training Images:', train_data.__len__(), 'Validation Images:',
          val_data.__len__())
    dist = train_data.dist()
    print('Train Data Distribution:', dist, 'Val Data Distribution:',
          val_data.dist())

    train_dataloader = DataLoader(train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=config.num_workers)
    val_dataloader = DataLoader(val_data,
                                batch_size=config.batch_size,
                                shuffle=False,
                                num_workers=config.num_workers)

    # prepare model
    # model = ResNet18(num_classes=config.num_classes)
    # model = Vgg16(num_classes=config.num_classes)
    # model = densenet_collapse(num_classes=config.num_classes)
    model = ShallowVgg(num_classes=config.num_classes)
    print(model)

    if config.load_model_path:
        model.load(config.load_model_path)
    if config.use_gpu:
        model.cuda()
    if config.parallel:
        model = torch.nn.DataParallel(
            model, device_ids=[x for x in range(config.num_of_gpu)])

    # criterion and optimizer
    # weight = torch.Tensor([1/dist['0'], 1/dist['1'], 1/dist['2'], 1/dist['3']])
    # weight = torch.Tensor([1/dist['0'], 1/dist['1']])
    # weight = torch.Tensor([dist['1'], dist['0']])
    # weight = torch.Tensor([1, 10])
    # vis.log(f'loss weight: {weight}')
    # print('loss weight:', weight)
    # weight = weight.cuda()

    # criterion = torch.nn.CrossEntropyLoss()
    criterion = LabelSmoothing(size=config.num_classes, smoothing=0.1)
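    # (LabelSmoothing here is assumed to expect log-probabilities, which is why the
    #  training loop below applies log_softmax before the criterion, unlike
    #  CrossEntropyLoss which takes raw logits)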
    # criterion = torch.nn.CrossEntropyLoss(weight=weight)
    # criterion = FocalLoss(gamma=4, alpha=None)

    lr = config.lr
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=config.weight_decay)

    # metric
    softmax = functional.softmax
    log_softmax = functional.log_softmax
    loss_meter = meter.AverageValueMeter()
    epoch_loss = meter.AverageValueMeter()
    train_cm = meter.ConfusionMeter(config.num_classes)
    train_AUC = meter.AUCMeter()

    previous_avgse = 0
    # previous_AUC = 0
    if config.parallel:
        save_model_dir = config.save_model_dir if config.save_model_dir else model.module.model_name
        save_model_name = config.save_model_name if config.save_model_name else model.module.model_name + '_best_model.pth'
    else:
        save_model_dir = config.save_model_dir if config.save_model_dir else model.model_name
        save_model_name = config.save_model_name if config.save_model_name else model.model_name + '_best_model.pth'
    save_epoch = 1  # records the epoch of the best model on the validation set
    # process_record = {'epoch_loss': [],  # records the training curves for later plotting
    #                   'train_avgse': [], 'train_se0': [], 'train_se1': [], 'train_se2': [], 'train_se3': [],
    #                   'val_avgse': [], 'val_se0': [], 'val_se1': [], 'val_se2': [], 'val_se3': []}
    process_record = {
        'epoch_loss': [],  # records the training curves for later plotting
        'train_avgse': [],
        'train_se0': [],
        'train_se1': [],
        'val_avgse': [],
        'val_se0': [],
        'val_se1': [],
        'train_AUC': [],
        'val_AUC': []
    }

    # train
    for epoch in range(config.max_epoch):
        print(
            f"epoch: [{epoch+1}/{config.max_epoch}] {config.save_model_name[:-4]} =================================="
        )
        epoch_loss.reset()
        train_cm.reset()
        train_AUC.reset()

        # train
        model.train()
        for i, (image, label, image_path) in tqdm(enumerate(train_dataloader)):
            loss_meter.reset()

            # prepare input
            if config.use_gpu:
                image = image.cuda()
                label = label.cuda()

            # go through the model
            score = model(image)

            # backpropagate
            optimizer.zero_grad()
            # loss = criterion(score, label)
            loss = criterion(log_softmax(score, dim=1), label)
            loss.backward()
            optimizer.step()

            loss_meter.add(loss.item())
            epoch_loss.add(loss.item())
            train_cm.add(softmax(score, dim=1).data, label.data)
            positive_score = np.array([
                item[1]
                for item in softmax(score, dim=1).data.cpu().numpy().tolist()
            ])
            train_AUC.add(positive_score, label.data)

            if (i + 1) % config.print_freq == 0:
                vis.plot('loss', loss_meter.value()[0])

        # print result
        # train_se = [100. * train_cm.value()[0][0] / (train_cm.value()[0][0] + train_cm.value()[0][1] + train_cm.value()[0][2] + train_cm.value()[0][3]),
        #             100. * train_cm.value()[1][1] / (train_cm.value()[1][0] + train_cm.value()[1][1] + train_cm.value()[1][2] + train_cm.value()[1][3]),
        #             100. * train_cm.value()[2][2] / (train_cm.value()[2][0] + train_cm.value()[2][1] + train_cm.value()[2][2] + train_cm.value()[2][3]),
        #             100. * train_cm.value()[3][3] / (train_cm.value()[3][0] + train_cm.value()[3][1] + train_cm.value()[3][2] + train_cm.value()[3][3])]
        train_se = [
            100. * train_cm.value()[0][0] /
            (train_cm.value()[0][0] + train_cm.value()[0][1]),
            100. * train_cm.value()[1][1] /
            (train_cm.value()[1][0] + train_cm.value()[1][1])
        ]

        # validate
        model.eval()
        if (epoch + 1) % 1 == 0:
            val_cm, val_se, val_accuracy, val_AUC = val_2class(
                model, val_dataloader)

            # save the model when the average sensitivity on the validation set improves
            if np.average(val_se) > previous_avgse:
                # if val_AUC.value()[0] > previous_AUC:  # or save when the validation AUC improves
                if config.parallel:
                    if not os.path.exists(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name.split('.')[0])):
                        os.makedirs(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name.split('.')[0]))
                    model.module.save(
                        os.path.join('checkpoints', save_model_dir,
                                     save_model_name.split('.')[0],
                                     save_model_name))
                else:
                    if not os.path.exists(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name.split('.')[0])):
                        os.makedirs(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name.split('.')[0]))
                    model.save(
                        os.path.join('checkpoints', save_model_dir,
                                     save_model_name.split('.')[0],
                                     save_model_name))
                previous_avgse = np.average(val_se)
                # previous_AUC = val_AUC.value()[0]
                save_epoch = epoch + 1

            process_record['epoch_loss'].append(epoch_loss.value()[0])
            process_record['train_avgse'].append(np.average(train_se))
            process_record['train_se0'].append(train_se[0])
            process_record['train_se1'].append(train_se[1])
            # process_record['train_se2'].append(train_se[2])
            # process_record['train_se3'].append(train_se[3])
            process_record['train_AUC'].append(train_AUC.value()[0])
            process_record['val_avgse'].append(np.average(val_se))
            process_record['val_se0'].append(val_se[0])
            process_record['val_se1'].append(val_se[1])
            # process_record['val_se2'].append(val_se[2])
            # process_record['val_se3'].append(val_se[3])
            process_record['val_AUC'].append(val_AUC.value()[0])

            # vis.plot_many({'epoch_loss': epoch_loss.value()[0],
            #                'train_avgse': np.average(train_se), 'train_se0': train_se[0], 'train_se1': train_se[1], 'train_se2': train_se[2], 'train_se3': train_se[3],
            #                'val_avgse': np.average(val_se), 'val_se0': val_se[0], 'val_se1': val_se[1], 'val_se2': val_se[2], 'val_se3': val_se[3]})
            # vis.log(f"epoch: [{epoch+1}/{config.max_epoch}] =========================================")
            # vis.log(f"lr: {optimizer.param_groups[0]['lr']}, loss: {round(loss_meter.value()[0], 5)}")
            # vis.log(f"train_avgse: {round(np.average(train_se), 4)}, train_se0: {round(train_se[0], 4)}, train_se1: {round(train_se[1], 4)}, train_se2: {round(train_se[2], 4)}, train_se3: {round(train_se[3], 4)},")
            # vis.log(f"val_avgse: {round(np.average(val_se), 4)}, val_se0: {round(val_se[0], 4)}, val_se1: {round(val_se[1], 4)}, val_se2: {round(val_se[2], 4)}, val_se3: {round(val_se[3], 4)}")
            # vis.log(f'train_cm: {train_cm.value()}')
            # vis.log(f'val_cm: {val_cm.value()}')
            # print("lr:", optimizer.param_groups[0]['lr'], "loss:", round(epoch_loss.value()[0], 5))
            # print('train_avgse:', round(np.average(train_se), 4), 'train_se0:', round(train_se[0], 4), 'train_se1:', round(train_se[1], 4), 'train_se2:', round(train_se[2], 4), 'train_se3:', round(train_se[3], 4))
            # print('val_avgse:', round(np.average(val_se), 4), 'val_se0:', round(val_se[0], 4), 'val_se1:', round(val_se[1], 4), 'val_se2:', round(val_se[2], 4), 'val_se3:', round(val_se[3], 4))
            # print('train_cm:')
            # print(train_cm.value())
            # print('val_cm:')
            # print(val_cm.value())

            vis.plot_many({
                'epoch_loss': epoch_loss.value()[0],
                'train_avgse': np.average(train_se),
                'train_se0': train_se[0],
                'train_se1': train_se[1],
                'val_avgse': np.average(val_se),
                'val_se0': val_se[0],
                'val_se1': val_se[1],
                'train_AUC': train_AUC.value()[0],
                'val_AUC': val_AUC.value()[0]
            })
            vis.log(
                f"epoch: [{epoch + 1}/{config.max_epoch}] ========================================="
            )
            vis.log(
                f"lr: {optimizer.param_groups[0]['lr']}, loss: {round(loss_meter.value()[0], 5)}"
            )
            vis.log(
                f"train_avgse: {round(np.average(train_se), 4)}, train_se0: {round(train_se[0], 4)}, train_se1: {round(train_se[1], 4)}"
            )
            vis.log(
                f"val_avgse: {round(np.average(val_se), 4)}, val_se0: {round(val_se[0], 4)}, val_se1: {round(val_se[1], 4)}"
            )
            vis.log(f'train_AUC: {train_AUC.value()[0]}')
            vis.log(f'val_AUC: {val_AUC.value()[0]}')
            vis.log(f'train_cm: {train_cm.value()}')
            vis.log(f'val_cm: {val_cm.value()}')
            print("lr:", optimizer.param_groups[0]['lr'], "loss:",
                  round(epoch_loss.value()[0], 5))
            print('train_avgse:', round(np.average(train_se), 4), 'train_se0:',
                  round(train_se[0], 4), 'train_se1:', round(train_se[1], 4))
            print('val_avgse:', round(np.average(val_se), 4), 'val_se0:',
                  round(val_se[0], 4), 'val_se1:', round(val_se[1], 4))
            print('train_AUC:',
                  train_AUC.value()[0], 'val_AUC:',
                  val_AUC.value()[0])
            print('train_cm:')
            print(train_cm.value())
            print('val_cm:')
            print(val_cm.value())

            if os.path.exists(
                    os.path.join('checkpoints', save_model_dir,
                                 save_model_name.split('.')[0])):
                write_json(file=os.path.join('checkpoints', save_model_dir,
                                             save_model_name.split('.')[0],
                                             'process_record.json'),
                           content=process_record)

        # if (epoch+1) % 5 == 0:
        #     lr = lr * config.lr_decay
        #     for param_group in optimizer.param_groups:
        #         param_group['lr'] = lr

    vis.log(f"Best Epoch: {save_epoch}")
    print("Best Epoch:", save_epoch)
Example #8
def train_pair(**kwargs):
    config.parse(kwargs)
    vis = Visualizer(port=2333, env=config.env)
    vis.log('Use config:')
    for k, v in config.__class__.__dict__.items():
        if not k.startswith('__'):
            vis.log(f"{k}: {getattr(config, k)}")

    # prepare data
    train_data = PairSWDataset(config.train_paths, phase='train', useRGB=config.useRGB, usetrans=config.usetrans, balance=config.data_balance)
    valpair_data = PairSWDataset(config.test_paths, phase='val_pair', useRGB=config.useRGB, usetrans=config.usetrans, balance=False)
    print('Training Samples:', train_data.__len__(), 'ValPair Samples:', valpair_data.__len__())
    dist = train_data.dist()
    print('Train Data Distribution:', dist)

    train_dataloader = DataLoader(train_data, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)
    valpair_dataloader = DataLoader(valpair_data, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers)

    # prepare model
    model = SiameseNet(num_classes=config.num_classes)
    print(model)

    if config.load_model_path:
        model.load(config.load_model_path)
    if config.use_gpu:
        model.cuda()
    if config.parallel:
        model = torch.nn.DataParallel(model, device_ids=[x for x in range(config.num_of_gpu)])

    model.train()

    # criterion and optimizer
    weight_pair = torch.Tensor([1, 1.5])
    vis.log(f'pair loss weight: {weight_pair}')
    print('pair loss weight:', weight_pair)
    weight_pair = weight_pair.cuda()
    pair_criterion = torch.nn.CrossEntropyLoss(weight=weight_pair)

    lr = config.lr
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.weight_decay)

    # metric
    softmax = functional.softmax
    pair_loss_meter = meter.AverageValueMeter()
    pair_epoch_loss = meter.AverageValueMeter()

    pair_train_cm = meter.ConfusionMeter(config.num_classes)
    # previous_loss = 100
    pair_previous_avg_se = 0

    # train
    if config.parallel:
        if not os.path.exists(os.path.join('checkpoints', model.module.model_name)):
            os.mkdir(os.path.join('checkpoints', model.module.model_name))
    else:
        if not os.path.exists(os.path.join('checkpoints', model.model_name)):
            os.mkdir(os.path.join('checkpoints', model.model_name))

    for epoch in range(config.max_epoch):
        print(f"epoch: [{epoch+1}/{config.max_epoch}] =============================================")
        pair_train_cm.reset()
        pair_epoch_loss.reset()

        # train
        for i, (image_1, image_2, label_1, label_2, label_res, _, _) in tqdm(enumerate(train_dataloader)):
            pair_loss_meter.reset()

            # prepare input
            image_1 = Variable(image_1)
            image_2 = Variable(image_2)
            target_res = Variable(label_res)

            if config.use_gpu:
                image_1 = image_1.cuda()
                image_2 = image_2.cuda()
                target_res = target_res.cuda()

            # go through the model
            score_1, score_2, score_res = model(image_1, image_2)

            # backpropagate
            optimizer.zero_grad()
            pair_loss = pair_criterion(score_res, target_res)
            pair_loss.backward()
            optimizer.step()

            pair_loss_meter.add(pair_loss.item())
            pair_epoch_loss.add(pair_loss.item())

            pair_train_cm.add(softmax(score_res, dim=1).data, target_res.data)

            if (i+1) % config.print_freq == 0:
                vis.plot('loss', pair_loss_meter.value()[0])

        # print result
        pair_train_se = [100. * pair_train_cm.value()[0][0] / (pair_train_cm.value()[0][0] + pair_train_cm.value()[0][1]),
                         100. * pair_train_cm.value()[1][1] / (pair_train_cm.value()[1][0] + pair_train_cm.value()[1][1])]
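        # pair_train_se holds the per-class sensitivities (recall) read off the 2x2
        # confusion matrix: each class's diagonal count divided by its row sum, as a
        # percentage; their average is the balanced accuracy.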
        model.eval()
        pair_val_cm, pair_val_accuracy, pair_val_se = val_pair(model, valpair_dataloader)

        if np.average(pair_val_se) > pair_previous_avg_se:  # save the model when the average sensitivity on the validation set improves
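            # With DataParallel the underlying network lives in model.module, so the
            # custom save() method has to be reached through .module; the else branch
            # calls it on the model directly.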
            if config.parallel:
                save_model_dir = config.save_model_dir if config.save_model_dir else model.module.model_name
                save_model_name = config.save_model_name if config.save_model_name else model.module.model_name + '_best_model.pth'
                if not os.path.exists(os.path.join('checkpoints', save_model_dir)):
                    os.makedirs(os.path.join('checkpoints', save_model_dir))
                model.module.save(os.path.join('checkpoints', save_model_dir, save_model_name))
            else:
                save_model_dir = config.save_model_dir if config.save_model_dir else model.model_name
                save_model_name = config.save_model_name if config.save_model_name else model.model_name + '_best_model.pth'
                if not os.path.exists(os.path.join('checkpoints', save_model_dir)):
                    os.makedirs(os.path.join('checkpoints', save_model_dir))
                model.save(os.path.join('checkpoints', save_model_dir, save_model_name))
            pair_previous_avg_se = np.average(pair_val_se)

        if epoch+1 == config.max_epoch:  # also save the model from the final epoch
            if config.parallel:
                save_model_dir = config.save_model_dir if config.save_model_dir else model.module.model_name
                save_model_name = config.save_model_name.split('.pth')[0]+'_last.pth' if config.save_model_name else model.module.model_name + '_last_model.pth'
            else:
                save_model_dir = config.save_model_dir if config.save_model_dir else model.model_name
                save_model_name = config.save_model_name.split('.pth')[0]+'_last.pth' if config.save_model_name else model.model_name + '_last_model.pth'
            if not os.path.exists(os.path.join('checkpoints', save_model_dir)):
                os.makedirs(os.path.join('checkpoints', save_model_dir))
            model.save(os.path.join('checkpoints', save_model_dir, save_model_name))

        vis.plot_many({'epoch_loss': pair_epoch_loss.value()[0],
                       'pair_train_avg_se': np.average(pair_train_se), 'pair_train_se_0': pair_train_se[0], 'pair_train_se_1': pair_train_se[1],
                       'pair_val_avg_se': np.average(pair_val_se), 'pair_val_se_0': pair_val_se[0], 'pair_val_se_1': pair_val_se[1]})
        vis.log(f"epoch: [{epoch+1}/{config.max_epoch}] ===============================================")
        vis.log(f"lr: {lr}, loss: {round(pair_epoch_loss.value()[0], 5)}")
        vis.log(f"pair_train_avg_se: {round(np.average(pair_train_se), 4)}, pair_train_se_0: {round(pair_train_se[0], 4)}, pair_train_se_1: {round(pair_train_se[1], 4)}")
        vis.log(f"pair_val_avg_se: {round(sum(pair_val_se) / len(pair_val_se), 4)}, pair_val_se_0: {round(pair_val_se[0], 4)}, pair_val_se_1: {round(pair_val_se[1], 4)}")
        vis.log(f'pair_train_cm: {pair_train_cm.value()}')
        vis.log(f'pair_val_cm: {pair_val_cm.value()}')
        print("lr:", lr, "loss:", round(pair_epoch_loss.value()[0], 5))
        print('pair_train_avg_se:', round(np.average(pair_train_se), 4), 'pair_train_se_0:', round(pair_train_se[0], 4), 'pair_train_se_1:', round(pair_train_se[1], 4))
        print('pair_val_avg_se:', round(np.average(pair_val_se), 4), 'pair_val_se_0:', round(pair_val_se[0], 4), 'pair_val_se_1:', round(pair_val_se[1], 4))
        print('pair_train_cm:')
        print(pair_train_cm.value())
        print('pair_val_cm:')
        print(pair_val_cm.value())

        # update learning rate
        # if loss_meter.value()[0] > previous_loss:
        #     lr = lr * config.lr_decay
        #     for param_group in optimizer.param_groups:
        #         param_group['lr'] = lr
        # previous_loss = loss_meter.value()[0]
        if (epoch+1) % 5 == 0:
            lr = lr * config.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
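
The manual schedule at the end of train_pair (multiply the learning rate by config.lr_decay every 5 epochs) is equivalent to PyTorch's built-in StepLR; a minimal sketch of the same setup, assuming the optimizer and config created above:

from torch.optim import lr_scheduler

scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=config.lr_decay)
for epoch in range(config.max_epoch):
    ...  # one epoch of training and validation
    scheduler.step()  # lr is multiplied by gamma once every 5 epochs
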
Example #9
def train(**kwargs):
    config.parse(kwargs)

    # ============================================ Visualization =============================================
    vis = Visualizer(port=2333, env=config.env)
    vis.log('Use config:')
    for k, v in config.__class__.__dict__.items():
        if not k.startswith('__'):
            vis.log(f"{k}: {getattr(config, k)}")

    # ============================================= Prepare Data =============================================
    train_data_1 = SlideWindowDataset(config.train_paths, phase='train', useRGB=config.useRGB, usetrans=config.usetrans, balance=config.data_balance)
    train_data_2 = SlideWindowDataset(config.train_paths, phase='train', useRGB=config.useRGB, usetrans=config.usetrans, balance=config.data_balance)
    val_data = SlideWindowDataset(config.test_paths, phase='val', useRGB=config.useRGB, usetrans=config.usetrans, balance=False)
    print('Training Images:', train_data_1.__len__(), 'Validation Images:', val_data.__len__())
    dist = train_data_1.dist()
    print('Train Data Distribution:', dist)

    train_dataloader_1 = DataLoader(train_data_1, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)
    train_dataloader_2 = DataLoader(train_data_2, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)
    val_dataloader = DataLoader(val_data, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers)

    # ============================================= Prepare Model ============================================
    # model = PCResNet18(num_classes=config.num_classes)
    model = DualResNet18(num_classes=config.num_classes)
    print(model)

    if config.load_model_path:
        model.load(config.load_model_path)
    if config.use_gpu:
        model.cuda()
    if config.parallel:
        model = torch.nn.DataParallel(model, device_ids=[x for x in range(config.num_of_gpu)])

    # =========================================== Criterion and Optimizer =====================================
    # weight = torch.Tensor([1, 1])
    # weight = torch.Tensor([dist['1']/(dist['0']+dist['1']), dist['0']/(dist['0']+dist['1'])])  # the two class frequencies are swapped to form the weights; with more than two classes, use the reciprocals
    # weight = torch.Tensor([1, 3.5])
    # weight = torch.Tensor([1, 5])
    weight = torch.Tensor([1, 7])
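    # A sketch of the inverse-frequency weighting described in the commented-out
    # line above (assumption: `dist` maps class labels '0' and '1' to their sample
    # counts, as printed from train_data_1.dist() earlier):
    #   weight = torch.Tensor([dist['1'], dist['0']]) / (dist['0'] + dist['1'])
    # Each class is weighted by the other class's frequency, so the minority class
    # contributes more to the cross-entropy loss; the [1, 7] used above is a
    # hand-tuned alternative.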
    vis.log(f'loss weight: {weight}')
    print('loss weight:', weight)
    weight = weight.cuda()

    criterion = torch.nn.CrossEntropyLoss(weight=weight)
    MSELoss = torch.nn.MSELoss()
    sycriterion = torch.nn.CrossEntropyLoss()

    lr = config.lr
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.weight_decay)

    # ================================================== Metrics ===============================================
    softmax = functional.softmax
    loss_meter = meter.AverageValueMeter()
    epoch_loss = meter.AverageValueMeter()
    mse_meter = meter.AverageValueMeter()
    epoch_mse = meter.AverageValueMeter()
    syloss_meter = meter.AverageValueMeter()
    epoch_syloss = meter.AverageValueMeter()
    total_loss_meter = meter.AverageValueMeter()
    epoch_total_loss = meter.AverageValueMeter()
    train_cm = meter.ConfusionMeter(config.num_classes)

    # ====================================== Saving and Recording Configuration =================================
    previous_auc = 0
    if config.parallel:
        save_model_dir = config.save_model_dir if config.save_model_dir else model.module.model_name
        save_model_name = config.save_model_name if config.save_model_name else model.module.model_name + '_best_model.pth'
    else:
        save_model_dir = config.save_model_dir if config.save_model_dir else model.model_name
        save_model_name = config.save_model_name if config.save_model_name else model.model_name + '_best_model.pth'
    save_epoch = 1  # records the epoch of the model that performs best on the validation set
    process_record = {'epoch_loss': [],
                      'train_avg_se': [], 'train_se_0': [], 'train_se_1': [],
                      'val_avg_se': [], 'val_se_0': [], 'val_se_1': [],
                      'AUC': []}  # records the training curves so they can be plotted later

    # ================================================== Training ===============================================
    for epoch in range(config.max_epoch):
        print(f"epoch: [{epoch+1}/{config.max_epoch}] {config.save_model_name[:-4]} ==================================")
        train_cm.reset()
        epoch_loss.reset()
        epoch_mse.reset()
        epoch_syloss.reset()
        epoch_total_loss.reset()

        # ****************************************** train ****************************************
        model.train()
        for i, (item1, item2) in tqdm(enumerate(zip(train_dataloader_1, train_dataloader_2))):
            loss_meter.reset()
            mse_meter.reset()
            syloss_meter.reset()
            total_loss_meter.reset()

            # ------------------------------------ prepare input ------------------------------------
            image1, label1, image_path1 = item1
            image2, label2, image_path2 = item2
            if config.use_gpu:
                image1 = image1.cuda()
                image2 = image2.cuda()
                label1 = label1.cuda()
                label2 = label2.cuda()

            # ---------------------------------- go through the model --------------------------------
            # score1, score2, logits1, logits2 = model(image1, image2)  # Pairwise Confusion Network
            score1, score2, score3 = model(image1, image2)  # Dual CNN

            # ----------------------------------- backpropagate -------------------------------------
            # add an L2 (MSE) term between the two branches' features
            # optimizer.zero_grad()
            # cls_loss1 = criterion(score1, label1)
            # cls_loss2 = criterion(score2, label2)
            #
            # ch_weight = torch.where(label1 == label2, torch.Tensor([0]).cuda(), torch.Tensor([1]).cuda())
            # ch_weight = ch_weight.view(logits1.size(0), -1)
            # mse = MSELoss(logits1 * ch_weight, logits2 * ch_weight)  # only penalize pairs from different classes; same-class pairs are zeroed out
            #
            # total_loss = cls_loss1 + cls_loss2 + 10 * mse
            # total_loss.backward()
            # optimizer.step()

            # add a loss on the two branches' logits that predicts whether the pair belongs to the same class
            optimizer.zero_grad()
            cls_loss1 = criterion(score1, label1)
            cls_loss2 = criterion(score2, label2)

            sylabel = torch.where(label1 == label2, torch.Tensor([0]).cuda(), torch.Tensor([1]).cuda()).long()
            sy_loss = sycriterion(score3, sylabel)

            total_loss = cls_loss1 + cls_loss2 + 2 * sy_loss
            total_loss.backward()
            optimizer.step()

            # ------------------------------------ record loss ------------------------------------
            loss_meter.add((cls_loss1 + cls_loss2).item())
            # mse_meter.add(mse.item())
            # syloss_meter.add(sy_loss.item())
            # total_loss_meter.add(total_loss.item())

            epoch_loss.add((cls_loss1 + cls_loss2).item())
            # epoch_mse.add(mse.item())
            epoch_syloss.add(sy_loss.item())
            epoch_total_loss.add(total_loss.item())

            train_cm.add(softmax(score1, dim=1).detach(), label1.detach())

            if (i+1) % config.print_freq == 0:
                vis.plot('loss', loss_meter.value()[0])

        train_se = [100. * train_cm.value()[0][0] / (train_cm.value()[0][0] + train_cm.value()[0][1]),
                    100. * train_cm.value()[1][1] / (train_cm.value()[1][0] + train_cm.value()[1][1])]

        # *************************************** validate ***************************************
        model.eval()
        if (epoch + 1) % 1 == 0:
            Best_T, val_cm, val_spse, val_accuracy, AUC = val(model, val_dataloader)

            # ------------------------------------ save model ------------------------------------
            if AUC > previous_auc and epoch + 1 > 5:
                if config.parallel:
                    if not os.path.exists(os.path.join('checkpoints', save_model_dir, save_model_name.split('.')[0])):
                        os.makedirs(os.path.join('checkpoints', save_model_dir, save_model_name.split('.')[0]))
                    model.module.save(os.path.join('checkpoints', save_model_dir, save_model_name.split('.')[0], save_model_name))
                else:
                    if not os.path.exists(os.path.join('checkpoints', save_model_dir, save_model_name.split('.')[0])):
                        os.makedirs(os.path.join('checkpoints', save_model_dir, save_model_name.split('.')[0]))
                    model.save(os.path.join('checkpoints', save_model_dir, save_model_name.split('.')[0], save_model_name))
                previous_auc = AUC
                save_epoch = epoch + 1

            # ---------------------------------- record and print ---------------------------------
            process_record['epoch_loss'].append(epoch_loss.value()[0])
            process_record['train_avg_se'].append(np.average(train_se))
            process_record['train_se_0'].append(train_se[0])
            process_record['train_se_1'].append(train_se[1])
            process_record['val_avg_se'].append(np.average(val_spse))
            process_record['val_se_0'].append(val_spse[0])
            process_record['val_se_1'].append(val_spse[1])
            process_record['AUC'].append(AUC)

            # vis.plot('epoch_mse', epoch_mse.value()[0])
            vis.plot('epoch_syloss', epoch_syloss.value()[0])
            vis.plot_many({'epoch_loss': epoch_loss.value()[0], 'epoch_total_loss': epoch_total_loss.value()[0],
                           'train_avg_se': np.average(train_se), 'train_se_0': train_se[0], 'train_se_1': train_se[1],
                           'val_avg_se': np.average(val_spse), 'val_se_0': val_spse[0], 'val_se_1': val_spse[1],
                           'AUC': AUC})
            vis.log(f"epoch: [{epoch+1}/{config.max_epoch}] =========================================")
            vis.log(f"lr: {optimizer.param_groups[0]['lr']}, loss: {round(loss_meter.value()[0], 5)}")
            vis.log(f"train_avg_se: {round(np.average(train_se), 4)}, train_se_0: {round(train_se[0], 4)}, train_se_1: {round(train_se[1], 4)}")
            vis.log(f"val_avg_se: {round(sum(val_spse)/len(val_spse), 4)}, val_se_0: {round(val_spse[0], 4)}, val_se_1: {round(val_spse[1], 4)}")
            vis.log(f"AUC: {AUC}")
            vis.log(f'train_cm: {train_cm.value()}')
            vis.log(f'Best Threshold: {Best_T}')
            vis.log(f'val_cm: {val_cm}')
            print("lr:", optimizer.param_groups[0]['lr'], "loss:", round(epoch_loss.value()[0], 5))
            print('train_avg_se:', round(np.average(train_se), 4), 'train_se_0:', round(train_se[0], 4), 'train_se_1:', round(train_se[1], 4))
            print('val_avg_se:', round(np.average(val_spse), 4), 'val_se_0:', round(val_spse[0], 4), 'val_se_1:', round(val_spse[1], 4))
            print('AUC:', AUC)
            print('train_cm:')
            print(train_cm.value())
            print('Best Threshold:', Best_T, 'val_cm:')
            print(val_cm)

            # ------------------------------------ save record ------------------------------------
            if os.path.exists(os.path.join('checkpoints', save_model_dir, save_model_name.split('.')[0])):
                write_json(file=os.path.join('checkpoints', save_model_dir, save_model_name.split('.')[0], 'process_record.json'),
                           content=process_record)

        # if (epoch+1) % 5 == 0:
        #     lr = lr * config.lr_decay
        #     for param_group in optimizer.param_groups:
        #         param_group['lr'] = lr

    vis.log(f"Best Epoch: {save_epoch}")
    print("Best Epoch:", save_epoch)
def train(**kwargs):
    config.parse(kwargs)

    # ============================================ Visualization =============================================
    vis = Visualizer(port=2333, env=config.env)
    vis.log('Use config:')
    for k, v in config.__class__.__dict__.items():
        if not k.startswith('__'):
            vis.log(f"{k}: {getattr(config, k)}")

    # ============================================= Prepare Data =============================================
    train_data = SlideWindowDataset(config.train_paths,
                                    phase='train',
                                    useRGB=config.useRGB,
                                    usetrans=config.usetrans,
                                    balance=config.data_balance)
    val_data = SlideWindowDataset(config.test_paths,
                                  phase='val',
                                  useRGB=config.useRGB,
                                  usetrans=config.usetrans,
                                  balance=False)
    print('Training Images:', train_data.__len__(), 'Validation Images:',
          val_data.__len__())
    dist = train_data.dist()
    print('Train Data Distribution:', dist)

    train_dataloader = DataLoader(train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=config.num_workers)
    val_dataloader = DataLoader(val_data,
                                batch_size=config.batch_size,
                                shuffle=False,
                                num_workers=config.num_workers)

    # ============================================= Prepare Model ============================================
    model = UNet_Classifier(num_classes=config.num_classes)
    print(model)

    if config.load_model_path:
        model.load(config.load_model_path)
        print('Model loaded')
    if config.use_gpu:
        model.cuda()
    if config.parallel:
        model = torch.nn.DataParallel(
            model, device_ids=[x for x in range(config.num_of_gpu)])

    # =========================================== Criterion and Optimizer =====================================
    # weight = torch.Tensor([1, 1])
    # weight = torch.Tensor([dist['1']/(dist['0']+dist['1']), dist['0']/(dist['0']+dist['1'])])  # the two class frequencies are swapped to form the weights; with more than two classes, use the reciprocals
    # weight = torch.Tensor([1, 3.5])
    # weight = torch.Tensor([1, 5])
    weight = torch.Tensor([1, 7])

    vis.log(f'loss weight: {weight}')
    print('loss weight:', weight)
    weight = weight.cuda()
    criterion = torch.nn.CrossEntropyLoss(weight=weight)
    lr = config.lr
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=config.weight_decay)

    # ================================================== Metrics ===============================================
    softmax = functional.softmax
    loss_meter_edge = meter.AverageValueMeter()
    epoch_loss_edge = meter.AverageValueMeter()
    loss_meter_cls = meter.AverageValueMeter()
    epoch_loss_cls = meter.AverageValueMeter()
    loss_meter = meter.AverageValueMeter()
    epoch_loss = meter.AverageValueMeter()
    train_cm = meter.ConfusionMeter(config.num_classes)

    # ====================================== Saving and Recording Configuration =================================
    previous_auc = 0
    if config.parallel:
        save_model_dir = config.save_model_dir if config.save_model_dir else model.module.model_name
        save_model_name = config.save_model_name if config.save_model_name else model.module.model_name + '_best_model.pth'
    else:
        save_model_dir = config.save_model_dir if config.save_model_dir else model.model_name
        save_model_name = config.save_model_name if config.save_model_name else model.model_name + '_best_model.pth'
    save_epoch = 1  # records the epoch of the model that performs best on the validation set
    process_record = {
        'epoch_loss': [],
        'epoch_loss_edge': [],
        'epoch_loss_cls': [],
        'train_avg_se': [],
        'train_se_0': [],
        'train_se_1': [],
        'val_avg_se': [],
        'val_se_0': [],
        'val_se_1': [],
        'AUC': [],
        'DICE': []
    }  # records the training curves so they can be plotted later

    # ================================================== Training ===============================================
    for epoch in range(config.max_epoch):
        print(
            f"epoch: [{epoch + 1}/{config.max_epoch}] {config.save_model_name[:-4]} =================================="
        )
        train_cm.reset()
        epoch_loss.reset()
        dice = []

        # ****************************************** train ****************************************
        model.train()
        for i, (image, label, edge_mask,
                image_path) in tqdm(enumerate(train_dataloader)):
            loss_meter.reset()

            # ------------------------------------ prepare input ------------------------------------
            if config.use_gpu:
                image = image.cuda()
                label = label.cuda()
                edge_mask = edge_mask.cuda()

            # ---------------------------------- go through the model --------------------------------
            score, score_mask = model(x=image)

            # ----------------------------------- backpropagate -------------------------------------
            optimizer.zero_grad()

            # classification loss
            loss_cls = criterion(score, label)
            # loss over pixels inside the edge mask
            log_prob_mask = functional.logsigmoid(score_mask)
            count_edge = torch.sum(edge_mask, dim=(1, 2, 3), keepdim=True)
            loss_edge = -1 * torch.mean(
                torch.sum(
                    edge_mask * log_prob_mask, dim=(1, 2, 3), keepdim=True) /
                (count_edge + 1e-8))

            # loss over pixels outside the edge mask
            r_prob_mask = 1.0 - torch.sigmoid(score_mask)
            r_edge_mask = 1.0 - edge_mask
            log_rprob_mask = torch.log(r_prob_mask + 1e-5)
            count_redge = torch.sum(r_edge_mask, dim=(1, 2, 3), keepdim=True)
            loss_redge = -1 * torch.mean(
                torch.sum(r_edge_mask * log_rprob_mask,
                          dim=(1, 2, 3),
                          keepdim=True) / (count_redge + 1e-8))

            # weights are computed from the numbers of edge (foreground) and non-edge (background) pixels
            w1 = torch.sum(count_edge).item() / (torch.sum(count_edge).item() +
                                                 torch.sum(count_redge).item())
            w2 = torch.sum(count_redge).item() / (
                torch.sum(count_edge).item() + torch.sum(count_redge).item())
            loss = loss_cls + w1 * loss_edge + w2 * loss_redge

            loss.backward()
            optimizer.step()

            # ------------------------------------ record loss ------------------------------------
            loss_meter_edge.add((w1 * loss_edge + w2 * loss_redge).item())
            epoch_loss_edge.add((w1 * loss_edge + w2 * loss_redge).item())
            loss_meter_cls.add(loss_cls.item())
            epoch_loss_cls.add(loss_cls.item())
            loss_meter.add(loss.item())
            epoch_loss.add(loss.item())
            train_cm.add(softmax(score, dim=1).detach(), label.detach())
            dice.append(
                dice_coeff(input=(score_mask > 0.5).float(),
                           target=edge_mask[:, 0, :, :]).item())

            if (i + 1) % config.print_freq == 0:
                vis.plot_many({
                    'loss': loss_meter.value()[0],
                    'loss_edge': loss_meter_edge.value()[0],
                    'loss_cls': loss_meter_cls.value()[0]
                })

        train_se = [
            100. * train_cm.value()[0][0] /
            (train_cm.value()[0][0] + train_cm.value()[0][1]),
            100. * train_cm.value()[1][1] /
            (train_cm.value()[1][0] + train_cm.value()[1][1])
        ]
        train_dice = sum(dice) / len(dice)

        # *************************************** validate ***************************************
        model.eval()
        if (epoch + 1) % 1 == 0:
            Best_T, val_cm, val_spse, val_accuracy, AUC, val_dice = val(
                model, val_dataloader)

            # ------------------------------------ save model ------------------------------------
            if AUC > previous_auc and epoch + 1 > 5:  # after the first 5 epochs, save the model whenever the validation AUC improves
                if config.parallel:
                    if not os.path.exists(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name[:-4])):
                        os.makedirs(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name[:-4]))
                    model.module.save(
                        os.path.join('checkpoints', save_model_dir,
                                     save_model_name[:-4], save_model_name))
                else:
                    if not os.path.exists(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name[:-4])):
                        os.makedirs(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name[:-4]))
                    model.save(
                        os.path.join('checkpoints', save_model_dir,
                                     save_model_name[:-4], save_model_name))
                previous_auc = AUC
                save_epoch = epoch + 1

            # ---------------------------------- record and print ---------------------------------
            process_record['epoch_loss'].append(epoch_loss.value()[0])
            process_record['epoch_loss_edge'].append(
                epoch_loss_edge.value()[0])
            process_record['epoch_loss_cls'].append(epoch_loss_cls.value()[0])
            process_record['train_avg_se'].append(np.average(train_se))
            process_record['train_se_0'].append(train_se[0])
            process_record['train_se_1'].append(train_se[1])
            process_record['val_avg_se'].append(np.average(val_spse))
            process_record['val_se_0'].append(val_spse[0])
            process_record['val_se_1'].append(val_spse[1])
            process_record['AUC'].append(AUC)
            process_record['DICE'].append(val_dice)

            vis.plot_many({
                'epoch_loss': epoch_loss.value()[0],
                'epoch_loss_edge': epoch_loss_edge.value()[0],
                'epoch_loss_cls': epoch_loss_cls.value()[0],
                'train_avg_se': np.average(train_se),
                'train_se_0': train_se[0],
                'train_se_1': train_se[1],
                'val_avg_se': np.average(val_spse),
                'val_se_0': val_spse[0],
                'val_se_1': val_spse[1],
                'AUC': AUC,
                'train_dice': train_dice,
                'val_dice': val_dice
            })
            vis.log(
                f"epoch: [{epoch + 1}/{config.max_epoch}] ==============================================="
            )
            vis.log(
                f"lr: {optimizer.param_groups[0]['lr']}, loss: {round(loss_meter.value()[0], 5)}"
            )
            vis.log(
                f"train_avg_se: {round(np.average(train_se), 4)}, train_se_0: {round(train_se[0], 4)}, train_se_1: {round(train_se[1], 4)}"
            )
            vis.log(f"train_dice: {round(train_dice, 4)}")
            vis.log(
                f"val_avg_se: {round(sum(val_spse) / len(val_spse), 4)}, val_se_0: {round(val_spse[0], 4)}, val_se_1: {round(val_spse[1], 4)}"
            )
            vis.log(f"val_dice: {round(val_dice, 4)}")
            vis.log(f"AUC: {AUC}")
            vis.log(f'train_cm: {train_cm.value()}')
            vis.log(f'Best Threshold: {Best_T}')
            vis.log(f'val_cm: {val_cm}')
            print("lr:", optimizer.param_groups[0]['lr'], "loss:",
                  round(epoch_loss.value()[0], 5))
            print('train_avg_se:', round(np.average(train_se), 4),
                  'train_se_0:', round(train_se[0], 4), 'train_se_1:',
                  round(train_se[1], 4))
            print('train_dice:', train_dice)
            print('val_avg_se:', round(np.average(val_spse), 4), 'val_se_0:',
                  round(val_spse[0], 4), 'val_se_1:', round(val_spse[1], 4))
            print('val_dice:', val_dice)
            print('AUC:', AUC)
            print('train_cm:')
            print(train_cm.value())
            print('Best Threshold:', Best_T, 'val_cm:')
            print(val_cm)

            # ------------------------------------ save record ------------------------------------
            if os.path.exists(
                    os.path.join('checkpoints', save_model_dir,
                                 save_model_name[:-4])):
                write_json(file=os.path.join('checkpoints', save_model_dir,
                                             save_model_name[:-4],
                                             'process_record.json'),
                           content=process_record)
        # if (epoch+1) % 20 == 0:
        #     lr = lr * config.lr_decay
        #     for param_group in optimizer.param_groups:
        #         param_group['lr'] = lr

    vis.log(f"Best Epoch: {save_epoch}")
    print("Best Epoch:", save_epoch)
Example #11
def train(**kwargs):
    opt.parse(kwargs)
    vis = Visualizer(opt.env)

    model = models.KeypointModel(opt)
    if opt.model_path is not None:
        model.load(opt.model_path)

    model.cuda()
    dataset = Dataset(opt)
    dataloader = t.utils.data.DataLoader(dataset,
                                         opt.batch_size,
                                         num_workers=opt.num_workers,
                                         shuffle=True,
                                         drop_last=True)

    lr1, lr2 = opt.lr1, opt.lr2
    optimizer = model.get_optimizer(lr1, lr2)
    loss_meter = tnt.meter.AverageValueMeter()
    pre_loss = 1e100
    model.save()
    for epoch in range(opt.max_epoch):

        loss_meter.reset()
        start = time.time()

        for ii, (img, gt, weight) in tqdm(enumerate(dataloader)):
            optimizer.zero_grad()
            img = t.autograd.Variable(img).cuda()
            target = t.autograd.Variable(gt).cuda()
            weight = t.autograd.Variable(weight).cuda()
            outputs = model(img)
            loss, loss_list = l2_loss(outputs, target, weight)
            loss.backward()
            loss_meter.add(loss.item())
            optimizer.step()

            # visualization, recording, logging and printing
            if ii % opt.plot_every == 0 and ii > 0:
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                vis_plots = {'loss': loss_meter.value()[0], 'ii': ii}
                vis.plot_many(vis_plots)

                # show one randomly chosen image from the batch
                k = t.randperm(img.size(0))[0]
                show = img.data[k].cpu()
                raw = (show * 0.225 + 0.45).clamp(min=0, max=1)

                train_masked_img = mask_img(raw, outputs[-1].data[k][14])
                origin_masked_img = mask_img(raw, gt[k][14])

                vis.img('target', origin_masked_img)
                vis.img('train', train_masked_img)
                vis.img('label', gt[k][14])
                vis.img('predict', outputs[-1].data[k][14].clamp(max=1, min=0))
                paf_img = tool.vis_paf(raw, gt[k][15:])
                train_paf_img = tool.vis_paf(
                    raw, outputs[-1][k].data[15:].clamp(min=-1, max=1))
                vis.img('paf_train', train_paf_img)
                #fig = tool.show_paf(np.transpose(raw.cpu().numpy(),(1,2,0)),gt[k][15:].cpu().numpy().transpose((1,2,0))).get_figure()
                #paf_img = tool.fig2data(fig).astype(np.int32)
                #vis.img('paf',t.from_numpy(paf_img/255).float())
                vis.img('paf', paf_img)
        model.save(loss_meter.value()[0])
        vis.save([opt.env])
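
The visualization block above undoes the input normalization with show * 0.225 + 0.45 before displaying an image. A minimal sketch of that inverse transform (the mean 0.45 and std 0.225 are the values hard-coded above, assumed to apply to every channel):

import torch

def denormalize(img, mean=0.45, std=0.225):
    # invert transforms.Normalize-style scaling: img was (x - mean) / std
    return (img * std + mean).clamp(0, 1)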