Example #1
0
def _to_cuda_float(array):
    """Wrap a numpy array as a float32 tensor on the current CUDA device."""
    return torch.from_numpy(array).cuda().float()


def _collect_sequence_data(model, image, label, start_p, train_data):
    """Slide ``model`` along the first axis of one volume, storing tuples.

    Starting at slice ``start_p``, windows of ``SIZE_Z`` slices spaced
    ``Z_GAP`` apart are fed through ``model``; the hidden state ``ht``
    returned for one window is passed in with the next. For every window
    after the first, ``(patch, label, next_patch, next_label, ht)`` is
    pushed into ``train_data`` via ``store_data``, where ``ht`` is the
    state the model receives as input for ``patch``. When the slice span
    is not a multiple of ``Z_GAP``, one extra tuple is stored whose "next"
    window is zero-padded to ``SIZE_Z`` slices.

    Args:
        model: network called as ``model(x)`` or ``model(x, ht)`` and
            returning ``(pred, aux2, aux1, ht)``; run in eval mode here.
        image: numpy volume; ``image.shape[0]`` is used as the slice count.
        label: numpy volume indexed like ``image``.
        start_p: index of the first slice to use.
        train_data: buffer object exposing ``store_data(...)``.
    """
    end_p = image.shape[0]
    patch_n = (end_p - start_p) // Z_GAP - 2
    left_gap = (end_p - start_p) % Z_GAP

    if patch_n < 1:
        # Too few slices for even one window: there is no hidden state to
        # store, so collect nothing rather than reuse an unrelated `ht`.
        return

    model.eval()
    with torch.no_grad():
        for i in range(patch_n):
            if i == 0:
                # First window only produces the initial hidden state.
                image_patch = image[start_p:start_p + SIZE_Z]
                input_image = _to_cuda_float(
                    image_patch[np.newaxis, np.newaxis, :, :, :])
                _, _, _, ht = model(input_image)
                continue

            start_p += Z_GAP
            next_p = start_p + Z_GAP
            image_patch = image[start_p:start_p + SIZE_Z][np.newaxis, :, :, :]
            label_patch = label[start_p:start_p + SIZE_Z][np.newaxis, :, :, :]
            image_next = image[next_p:next_p + SIZE_Z][np.newaxis, :, :, :]
            label_next = label[next_p:next_p + SIZE_Z][np.newaxis, :, :, :]

            ht0 = ht.data.cpu().numpy()
            train_data.store_data(image_patch, label_patch, image_next,
                                  label_next, ht0[0])

            input_image = _to_cuda_float(image_patch[np.newaxis, :, :, :, :])
            _, _, _, ht = model(input_image, ht)

        if left_gap != 0:
            # Tail window: the following window runs past the end of the
            # volume, so copy what is left into a zero-filled buffer.
            start_p += Z_GAP
            image_patch = image[start_p:start_p + SIZE_Z][np.newaxis, :, :, :]
            label_patch = label[start_p:start_p + SIZE_Z][np.newaxis, :, :, :]
            image_next = np.zeros([SIZE_Z, SIZE, SIZE])
            label_next = np.zeros([SIZE_Z, SIZE, SIZE])
            image_next[0:left_gap + Z_GAP] = image[start_p + Z_GAP:]
            label_next[0:left_gap + Z_GAP] = label[start_p + Z_GAP:]
            image_next = image_next[np.newaxis, :, :, :]
            label_next = label_next[np.newaxis, :, :, :]

            ht0 = ht.data.cpu().numpy()
            train_data.store_data(image_patch, label_patch, image_next,
                                  label_next, ht0[0])


def _deep_supervised_loss(model, next_layer, diceloss, image, label,
                          image_next, label_next, ht=None):
    """Weighted Dice loss on the current window and the following window.

    ``model`` scores ``image`` (with ``ht`` when given, without otherwise);
    the hidden state it returns is fed to ``next_layer`` together with
    ``image_next``. Each side weights its main / aux1 / aux2 outputs by
    1.0 / 0.8 / 0.4.

    Returns:
        ``(loss, loss0, loss_next0)``: the total loss plus the main-output
        Dice losses of the current and next window (for logging).
    """
    if ht is None:
        pred, aux2, aux1, ht = model(image)
    else:
        pred, aux2, aux1, ht = model(image, ht)
    loss0 = diceloss(pred, label)
    loss1 = diceloss(aux1, label)
    loss2 = diceloss(aux2, label)

    pred_next, aux2, aux1, _ = next_layer(image_next, ht)
    loss_next0 = diceloss(pred_next, label_next)
    loss_next1 = diceloss(aux1, label_next)
    loss_next2 = diceloss(aux2, label_next)
    loss_next = loss_next0 + 0.8 * loss_next1 + 0.4 * loss_next2
    loss = loss0 + 0.8 * loss1 + 0.4 * loss2 + loss_next
    return loss, loss0, loss_next0


def train(model):
    """Train ``model`` and checkpoint it into ``./RUnet_model/``.

    Each epoch:
      1. loads one randomly chosen training volume and, with the model in
         eval mode, fills ``train_data`` with window/hidden-state tuples;
      2. runs ``STEP`` optimisation steps on tuples popped from
         ``train_data``. The loss covers the current window (``model``) and
         the next window (``next_layer``, a copy of ``model`` taken at the
         start of the epoch that is never optimised);
      3. on every 10th epoch, also makes one pass over ``dataloader`` with
         the same loss but no incoming hidden state;
      4. on every 25th epoch, saves ``model.state_dict()``.
    A final checkpoint is written after the last epoch.

    Relies on module-level names defined elsewhere in the file (LR,
    BATCH_SIZE, EPOCH, STEP, START_P, Z_GAP, SIZE_Z, SIZE,
    train_image_list, train_label_list, DiceLoss, get_head_data,
    data_set, MODEL).
    """
    start_time = time.time()

    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    diceloss = DiceLoss().cuda()
    dataloader = get_head_data(BATCH_SIZE)
    train_data = data_set(BATCH_SIZE)

    # Scores the *next* window. Its weights are copied from `model` every
    # epoch and are not in the optimizer, so don't accumulate gradients for
    # them; gradients still flow back into `model` through `ht`.
    next_layer = MODEL.RUnet_3d(in_channel=1).cuda()
    for param in next_layer.parameters():
        param.requires_grad = False

    save_dir = "./RUnet_model/"
    os.makedirs(save_dir, exist_ok=True)

    TRAIN_NUM = len(train_image_list)
    for epoch in range(EPOCH):
        # Step-wise learning-rate schedule.
        # NOTE(review): building a new Adam discards its moment estimates;
        # setting optimizer.param_groups[...]["lr"] would keep them.
        # Behaviour preserved — confirm which is intended.
        if epoch == 600:
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        if epoch == 1000:
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

        # Load the randomly drawn volume (img_num is also what is logged).
        img_num = random.randint(0, TRAIN_NUM - 1)
        IMAGE = np.load(train_image_list[img_num])
        LABEL = np.load(train_label_list[img_num])
        start_p = random.randint(0, START_P)

        # collect training tuples from this volume
        _collect_sequence_data(model, IMAGE, LABEL, start_p, train_data)

        # training on the collected tuples
        model.train()
        next_layer.load_state_dict(model.state_dict())
        next_layer.eval()
        for step in range(STEP):
            optimizer.zero_grad()
            b_image, b_label, image_next, label_next, ht = \
                train_data.pop_data()
            loss, loss0, loss_next0 = _deep_supervised_loss(
                model, next_layer, diceloss,
                _to_cuda_float(b_image), _to_cuda_float(b_label),
                _to_cuda_float(image_next), _to_cuda_float(label_next),
                ht=_to_cuda_float(ht))

            print("epoch:{},image_n:{},step:{},loss:{},loss_next:{}".format(
                epoch, img_num, step, loss0, loss_next0))

            loss.backward()
            optimizer.step()

        # periodic pass over the regular dataloader (no incoming state)
        if epoch % 10 == 0:
            for step, (b_image, b_label, image_next,
                       label_next) in enumerate(dataloader):
                optimizer.zero_grad()
                loss, loss0, loss_next0 = _deep_supervised_loss(
                    model, next_layer, diceloss,
                    b_image.float().cuda(), b_label.float().cuda(),
                    image_next.float().cuda(), label_next.float().cuda())

                print("epoch:{},step:{},loss:{},loss_next:{}".format(
                    epoch, step, loss0, loss_next0))
                loss.backward()
                optimizer.step()

        if epoch % 25 == 0:
            wholetime = int(time.time() - start_time)
            print("the whole time is {}h".format(wholetime / 3600))
            torch.save(model.state_dict(),
                       save_dir + str(epoch) + '_RUnet.pth')

    wholetime = int(time.time() - start_time)
    print("the trainning finished!, the whole time is {}h".format(wholetime /
                                                                  3600))
    torch.save(model.state_dict(), save_dir + str(epoch) + '_RUnet.pth')
Example #2
0
                image_patch = image_patch[np.newaxis, np.newaxis, :, :, :]

                input_image = Variable(
                    torch.from_numpy(image_patch)).cuda().float()
                pred, aux2, aux1, ht = model(input_image, ht)
                pred = pred.data.cpu().numpy()
                predict[start_p:] = pred[0, 0, 0:left_gap]

        dice = dice_coff(predict, LABEL)
        all_dice += dice
        print("the {}th image's dice cofficient is {}".format(i + 1, dice))

    wholetime = int(time.time() - start_time)
    print("the mean dice cofficient is {}".format(all_dice / count))
    print("the whole time is {}h".format(wholetime / 3600))


if __name__ == '__main__':
    # Entry point: `python <script> train` or `python <script> test`.
    model = MODEL.RUnet_3d(in_channel=1).cuda()
    # A missing argument is treated like an unknown phase instead of
    # crashing with IndexError on sys.argv[1].
    phase = sys.argv[1] if len(sys.argv) > 1 else None

    if phase == "train":
        print("model training!")
        train(model)

    elif phase == "test":
        print("model testing!")
        test(model)

    else:
        print("wrong input!")