Example #1
def main():
    global best_prec1, use_cuda
    best_prec1 = 0

    seed_torch(seed=args.seed)
    use_cuda = torch.cuda.is_available()

    # create model
    model = ResNet.Multimodal_ResNet(num_class=args.num_classes,
                                     mcbp=args.mcbp,
                                     pretrained=True)

    if use_cuda:
        model = model.cuda()
        # for training on multiple GPUs.
        # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
        # model = torch.nn.DataParallel(model).cuda()
    # cudnn.benchmark = True
    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # define loss function (criterion) and optimizer
    # criterion = FocalLoss(class_num=args.num_classes, alpha=None, gamma=2, size_average=True)
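    # CrossEntropyLoss applies log-softmax internally, so the model should
    # output raw (unnormalized) logits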
    criterion = nn.CrossEntropyLoss(reduction='mean')
    if args.weight_decay_fc > 0:
        reg_loss = Regularization(args.weight_decay_fc, p=args.p).to(device)
    else:
        reg_loss = 0
        print("no regularization")
    # compute by compute_mean_std
    # normalize = transforms.Normalize([0.0573, 0.0573, 0.0573], [0.1102, 0.1102, 0.1102])  # nrrd
    train_names, val_names = k_fold_pre(MODEL_DIR + "data_fold.txt",
                                        image_list_file=DATA_IMAGE_LIST,
                                        fold=args.fold)

    k = args.fold_index
    if args.resume:
        if os.path.isfile(args.resume + 'checkpoint' + str(k) + '.pth.tar'):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume + 'checkpoint' + str(k) +
                                    '.pth.tar')
            args.start_epoch = checkpoint['epoch']
            best_prec = checkpoint['best_prec1']
            print('best_prec1:', best_prec)
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            print("=> use initial checkpoint")
            checkpoint = torch.load(MODEL_DIR +
                                    "%s/checkpoint_init.pth.tar" % args.name)
            model.load_state_dict(checkpoint['state_dict'])
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))
        return 0
    result_filename = MODEL_DIR + "{}/test_result.txt".format(args.name)
    result_file = open(result_filename, 'a')  # 'a' appends to the file
    for v in range(2):
        if v == 0:
            filename = MODEL_DIR + "{}/fold_{}_result_train.txt".format(
                args.name, str(k))
            if os.path.exists(filename):
                continue
            val_dataset = DataSet(data_dir=DATA_DIR,
                                  image_list_file=DATA_IMAGE_LIST,
                                  fold=train_names[k],
                                  transform=True,
                                  fold_num=k,
                                  filepath=args,
                                  mode='train')
        else:
            filename = MODEL_DIR + "{}/fold_{}_result_test.txt".format(
                args.name, str(k))
            if os.path.exists(filename):
                continue
            val_dataset = DataSet(data_dir=DATA_DIR,
                                  image_list_file=DATA_IMAGE_LIST,
                                  fold=val_names[k],
                                  transform=True,
                                  fold_num=k,
                                  filepath=args,
                                  mode='test')

        kwargs = {'num_workers': 8, 'pin_memory': True}
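        # pin_memory=True speeds up host-to-GPU copies; workers load batches in parallel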
        val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                                 batch_size=args.batch_size,
                                                 **kwargs)
        epoch = args.start_epoch - 1
        # evaluate on validation set
        val_losses, val_acc, val_auc, output_val, label_val = validate(
            val_loader, model, criterion, reg_loss, epoch, k)
        file = open(filename, 'w')  # 'w' overwrites any existing result file
        F_s = nn.Softmax(dim=0)
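        # dim=0 is correct here: output_val[i] is a 1-D tensor of per-class scores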
        for i in range(output_val.size()[0]):
            output = F_s(output_val[i])
            out_write = str(output.cpu().numpy()[1]) + ' ' + str(
                int(label_val[i][1])) + '\n'
            out_write = out_write.replace('[', '').replace(']', '')
            file.write(out_write)
        file.close()

        result_file.write(
            str(k) + ' ' + str(val_acc.avg) + ' ' + str(val_auc) + '\n')

    # close the summary file and fit the patient-level model only after both
    # the train and test slice results have been written
    result_file.close()

    classification_LinearRegression(
        path=MODEL_DIR + "%s/" % args.name,
        train_patient_file="fold_{}_image_names_train.txt".format(k),
        train_slice_result_file="fold_{}_result_train.txt".format(k),
        test_patient_file="fold_{}_image_names_test.txt".format(k),
        test_slice_result_file="fold_{}_result_test.txt".format(k),
        fold=k,
        times_1=args.times[0],
        times_0=args.times[1])
    print('Tests have finished')
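
Both examples call a seed_torch helper that is not shown in this listing. A minimal sketch of what such a helper typically looks like, assuming it seeds Python, NumPy, and PyTorch (CPU and CUDA) and forces cuDNN into deterministic mode:

import os
import random

import numpy as np
import torch


def seed_torch(seed=1):
    # seed every RNG the training loop may touch so runs are reproducible
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op on CPU-only machines
    # trade speed for determinism in cuDNN convolutions
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
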
Example #2
def main():
    global best_prec_all, use_cuda, writer

    if args.tensorboard:
        # configure(MODEL_DIR + "%s" % args.name)
        writer = SummaryWriter(MODEL_DIR + "%s" % args.name)
    use_cuda = args.use_cuda and torch.cuda.is_available()
    if args.seed > 0:
        seed_torch(args.seed)  # fix the random seed for reproducibility
    # create model
    model = ResNet.Multimodal_ResNet(num_class=args.num_classes,
                                     mcbp=args.mcbp,
                                     pretrained=args.pretrained)

    # input_random = torch.rand(32, 3, 100, 100)
    # if args.tensorboard:
    #     writer.add_graph(model, (input_random, input_random, input_random), True)
    if os.path.exists(MODEL_DIR + "%s/checkpoint_init.pth.tar" % args.name):
        checkpoint = torch.load(MODEL_DIR +
                                "%s/checkpoint_init.pth.tar" % args.name)
        model.load_state_dict(checkpoint['state_dict'])
    else:
        torch.save({'state_dict': model.state_dict()},
                   MODEL_DIR + "%s/checkpoint_init.pth.tar" % args.name)

    if use_cuda:
        model = model.cuda()
        # for training on multiple GPUs.
        # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
        # model = torch.nn.DataParallel(model).cuda()
    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # define optimizer
    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    nesterov=True,
                                    weight_decay=args.weight_decay)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     betas=(0.9, 0.99))
    else:
        print('Please choose a valid optimizer (SGD or Adam).')
        return 0

    # prepare the k-fold data split
    train_names, val_names = k_fold_pre(MODEL_DIR + "data_fold.txt",
                                        image_list_file=DATA_IMAGE_LIST,
                                        fold=args.fold)
    output, label, best_acc = [], [], []
    best_prec_all = 0  # best accuracy across all folds
    fileaccauc_name = MODEL_DIR + "{}/fold_acc_auc.txt".format(args.name)
    fileaccauc = open(fileaccauc_name, 'a')
    for k in range(args.fold_index, args.fold_index + 1):  # single fold; use range(args.fold) for all
        best_prec = 0  # best accuracy for fold k
        # load the data for fold k
        train_dataset = DataSet_Mini(data_dir=DATA_DIR,
                                     image_list_file=DATA_IMAGE_LIST,
                                     fold=train_names[k],
                                     transform=True)  # normalize
        val_dataset = DataSet_Mini(data_dir=DATA_DIR,
                                   image_list_file=DATA_IMAGE_LIST,
                                   fold=val_names[k],
                                   transform=True)
        kwargs = {'num_workers': args.num_workers, 'pin_memory': True}
        train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   **kwargs)  # drop_last=True,
        val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 **kwargs)

        # optionally resume from a checkpoint
        if args.resume:
            if os.path.isfile(args.resume + 'checkpoint' + str(k) +
                              '.pth.tar'):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume + 'checkpoint' + str(k) +
                                        '.pth.tar')
                checkpoint_initial = torch.load(
                    MODEL_DIR + "%s/checkpoint_init.pth.tar" % args.name)
                model.load_state_dict(checkpoint_initial['state_dict'])
                # load the initial (pretrained MCBP transfer) weights first,
                # then overwrite them with the resumed checkpoint
                args.start_epoch = checkpoint['epoch']
                best_prec = checkpoint['best_prec1']
                print('best_prec1:', best_prec)
                model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))

            else:
                print("=> no checkpoint found at '{}'".format(args.resume))
                print("=> use initial checkpoint")
                checkpoint = torch.load(MODEL_DIR +
                                        "%s/checkpoint_init.pth.tar" %
                                        args.name)
                model.load_state_dict(checkpoint['state_dict'])
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            return 0
        # define loss function
        criterion = loss_function(weight_decay_fc=args.weight_decay_fc,
                                  p=args.p)
        epoch_is_best = 0
        filename = MODEL_DIR + "{}/fold_avg_auc.txt".format(args.name)
        file = open(filename, 'a')  # 'a' appends; one AUROC per epoch, one line per fold

        for epoch in range(args.start_epoch, args.epochs):

            adjust_learning_rate(optimizer, epoch)

            # train for one epoch
            train_losses, train_acc = train(train_loader, model, criterion,
                                            optimizer, epoch, k)

            # for name, layer in model.named_parameters():
            #     writer.add_histogram('fold' + str(k) + '/' + name + '_grad', layer.grad.cpu().data.numpy(), epoch)
            #     writer.add_histogram('fold' + str(k) + '/' + name + '_data', layer.cpu().data.numpy(), epoch)
            # evaluate on validation set
            val_losses, val_acc, prec1, output_val, label_val, AUROC = validate(
                val_loader, model, criterion, epoch, k)
            print('Accuracy {val_acc.avg:.4f}\t AUC {auc:.4f}'.format(
                val_acc=val_acc, auc=AUROC))
            # the validation set tracks training progress, so the much larger
            # training set is not re-evaluated; validate once per epoch
            if args.tensorboard:
                # x = model.conv1.weight.data
                # x = vutils.make_grid(x, normalize=True, scale_each=True)
                # writer.add_image('data' + str(k) + '/weight0', x, epoch)  # Tensor
                writer.add_scalars('data' + str(k) + '/loss', {
                    'train_loss': train_losses.avg,
                    'val_loss': val_losses.avg
                }, epoch)
                writer.add_scalars('data' + str(k) + '/Accuracy', {
                    'train_acc': train_acc.avg,
                    'val_acc': val_acc.avg
                }, epoch)
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec
            if is_best:
                epoch_is_best = epoch
                best_prec = max(prec1, best_prec)  # best accuracy within this fold
            best_prec_all = max(prec1, best_prec_all)  # best accuracy across all folds
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec,
                }, is_best, epoch, k)
            out_write = str(AUROC) + '\t'
            file.write(out_write)
        file.write('\n')
        file.close()

        # collect this fold's best accuracy and its final outputs/labels
        best_acc.append([best_prec, epoch_is_best])
        output.append(output_val)
        label.append(label_val)

        acc_auc_out_write = str(train_acc.avg) + ' ' + str(
            val_acc.avg) + ' ' + str(AUROC) + '\n'
        fileaccauc.write(acc_auc_out_write)
        if args.tensorboard:
            writer.close()
        print('fold_num: [{}]\t Best accuracy {} \t epoch {}'.format(
            k, best_prec, epoch_is_best))
    print('Best accuracy of all fold: ', best_prec_all)
    fileaccauc.close()
    state = {'output': output, 'label': label, 'best_acc': best_acc}
    torch.save(state, MODEL_DIR + "%s/output_label.pth.tar" % (args.name))
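
The saved output_label.pth.tar bundles each fold's validation outputs and labels. A minimal sketch of reading it back and computing a per-fold AUC, assuming each output entry is an (N, num_classes) logits tensor and each label entry is one-hot as in the code above (MODEL_DIR and args.name reused from the script; scikit-learn assumed available):

import torch
from sklearn.metrics import roc_auc_score

state = torch.load(MODEL_DIR + "%s/output_label.pth.tar" % args.name)
for fold, (out, lab) in enumerate(zip(state['output'], state['label'])):
    # probability of the positive class (index 1) for every validation sample
    probs = torch.softmax(out, dim=1)[:, 1].cpu().numpy()
    truth = lab[:, 1].cpu().numpy()
    print('fold {} AUC: {:.4f}'.format(fold, roc_auc_score(truth, probs)))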