def _to_cuda_float(array):
    """Wrap a NumPy array as a float32 tensor on the default CUDA device."""
    return torch.from_numpy(array).cuda().float()


def _dice_deep_supervision(diceloss, pred, aux1, aux2, target):
    """Combine the Dice losses of the main output and the two auxiliary heads.

    Returns ``(total, main)`` where
    ``total = main + 0.8 * diceloss(aux1) + 0.4 * diceloss(aux2)`` and
    ``main`` is the Dice loss of *pred* alone (reported in the training log).
    """
    main = diceloss(pred, target)
    total = main + 0.8 * diceloss(aux1, target) + 0.4 * diceloss(aux2, target)
    return total, main


def _rnn_pair_loss(model, next_layer, diceloss, image, label,
                   image_next, label_next, ht=None):
    """Loss for one window plus the window that follows it.

    *model* predicts *image* (seeded with hidden state *ht* when one is
    given), and the hidden state it returns is fed to *next_layer* together
    with *image_next*, so the "next" loss back-propagates through that hidden
    state into *model*.

    Returns:
        ``(loss, loss0, loss_next0)``: ``loss`` is the sum of both
        deep-supervision losses; ``loss0`` / ``loss_next0`` are the main-head
        Dice losses of the two windows, used for logging.
    """
    if ht is None:
        pred, aux2, aux1, ht = model(image)
    else:
        pred, aux2, aux1, ht = model(image, ht)
    loss, loss0 = _dice_deep_supervision(diceloss, pred, aux1, aux2, label)
    pred_next, aux2_next, aux1_next, _ = next_layer(image_next, ht)
    loss_next, loss_next0 = _dice_deep_supervision(
        diceloss, pred_next, aux1_next, aux2_next, label_next)
    return loss + loss_next, loss0, loss_next0


def _collect_rnn_patches(model, train_data, image, label, start_p):
    """Run *model* down one volume and queue recurrent training tuples.

    Windows of ``SIZE_Z`` slices are taken along axis 0 (presumably the z
    axis, per ``Z_GAP``/``SIZE_Z``), starting at *start_p* and advancing by
    ``Z_GAP``. For every window after the first, this queues the following
    via ``train_data.store_data``:

    - the window itself and its label;
    - the window ``Z_GAP`` slices later and its label;
    - the first batch element of the hidden state the model produced for the
      preceding window.

    Runs with *model* in eval mode and without autograd.
    """
    end_p = image.shape[0]
    patch_n = (end_p - start_p) // Z_GAP - 2
    left_gap = (end_p - start_p) % Z_GAP
    ht = None

    model.eval()
    with torch.no_grad():
        for i in range(patch_n):
            if i > 0:
                start_p += Z_GAP
                next_p = start_p + Z_GAP
                train_data.store_data(
                    image[start_p:start_p + SIZE_Z][np.newaxis],
                    label[start_p:start_p + SIZE_Z][np.newaxis],
                    image[next_p:next_p + SIZE_Z][np.newaxis],
                    label[next_p:next_p + SIZE_Z][np.newaxis],
                    ht.detach().cpu().numpy()[0])
            window = _to_cuda_float(
                image[start_p:start_p + SIZE_Z][np.newaxis, np.newaxis])
            if i == 0:
                _, _, _, ht = model(window)
            else:
                _, _, _, ht = model(window, ht)

        # Without a hidden state from *this* volume (patch_n < 1) the tail
        # window is skipped rather than paired with an unrelated state.
        if left_gap != 0 and ht is not None:
            start_p += Z_GAP
            next_p = start_p + Z_GAP
            # The "next" window runs off the end of the volume: copy what is
            # left (the assignment assumes left_gap + Z_GAP slices remain)
            # and zero-pad the rest of the SIZE_Z window.
            image_next = np.zeros([SIZE_Z, SIZE, SIZE])
            label_next = np.zeros([SIZE_Z, SIZE, SIZE])
            image_next[0:left_gap + Z_GAP] = image[next_p:]
            label_next[0:left_gap + Z_GAP] = label[next_p:]
            train_data.store_data(
                image[start_p:start_p + SIZE_Z][np.newaxis],
                label[start_p:start_p + SIZE_Z][np.newaxis],
                image_next[np.newaxis],
                label_next[np.newaxis],
                ht.detach().cpu().numpy()[0])


def train(model):
    """Train the recurrent 3D U-Net *model*, checkpointing to ./RUnet_model/.

    Each of the ``EPOCH`` epochs:

    1. Picks a random volume from ``train_image_list``/``train_label_list``.
       It runs *model* along that volume without gradients and pushes
       (window, next window, hidden state) tuples into a ``data_set`` buffer.
    2. Runs ``STEP`` optimizer steps on tuples popped from that buffer. The
       loss covers the window itself plus the following window, which is
       predicted by a frozen per-epoch copy of *model* (``next_layer``).
    3. Every 10 epochs, also trains on every batch of ``get_head_data``.
    4. Every 25 epochs, saves ``<epoch>_RUnet.pth``; a final checkpoint is
       saved after the last epoch.

    The Adam learning rate starts at ``LR`` and is replaced by 1e-4 at epoch
    600 and 1e-5 at epoch 1000. Requires CUDA.
    """
    start_time = time.time()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    diceloss = DiceLoss().cuda()
    dataloader = get_head_data(BATCH_SIZE)
    train_data = data_set(BATCH_SIZE)
    next_layer = MODEL.RUnet_3d(in_channel=1).cuda()
    # next_layer only receives copies of model's weights and is never
    # optimized, so don't accumulate gradients on its parameters. Gradients
    # still flow through the hidden state it consumes back into model.
    for param in next_layer.parameters():
        param.requires_grad = False
    save_dir = "./RUnet_model/"
    os.makedirs(save_dir, exist_ok=True)
    train_num = len(train_image_list)

    for epoch in range(EPOCH):
        # NOTE(review): re-creating Adam also discards its moment estimates;
        # if only an LR decay is intended, update optimizer.param_groups.
        if epoch == 600:
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        if epoch == 1000:
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

        # Collect recurrent training tuples from one randomly chosen volume.
        img_num = random.randint(0, train_num - 1)
        image_volume = np.load(train_image_list[img_num])
        label_volume = np.load(train_label_list[img_num])
        start_p = random.randint(0, START_P)
        _collect_rnn_patches(model, train_data, image_volume, label_volume,
                             start_p)

        # Train on buffered (window, next window, hidden state) tuples.
        model.train()
        next_layer.load_state_dict(model.state_dict())
        next_layer.eval()
        for step in range(STEP):
            optimizer.zero_grad()
            b_image, b_label, b_image_next, b_label_next, b_ht = (
                train_data.pop_data())
            loss, loss0, loss_next0 = _rnn_pair_loss(
                model, next_layer, diceloss,
                _to_cuda_float(b_image), _to_cuda_float(b_label),
                _to_cuda_float(b_image_next), _to_cuda_float(b_label_next),
                ht=_to_cuda_float(b_ht))
            print("epoch:{},image_n:{},step:{},loss:{},loss_next:{}".format(
                epoch, img_num, step, loss0, loss_next0))
            loss.backward()
            optimizer.step()

        # Periodically also train on the head-of-volume data loader, where
        # the model starts without a hidden state.
        if epoch % 10 == 0:
            for step, (b_image, b_label, b_image_next,
                       b_label_next) in enumerate(dataloader):
                optimizer.zero_grad()
                loss, loss0, loss_next0 = _rnn_pair_loss(
                    model, next_layer, diceloss,
                    b_image.float().cuda(), b_label.float().cuda(),
                    b_image_next.float().cuda(), b_label_next.float().cuda())
                print("epoch:{},step:{},loss:{},loss_next:{}".format(
                    epoch, step, loss0, loss_next0))
                loss.backward()
                optimizer.step()

        if epoch % 25 == 0:
            wholetime = int(time.time() - start_time)
            print("the whole time is {}h".format(wholetime / 3600))
            torch.save(model.state_dict(),
                       save_dir + str(epoch) + '_RUnet.pth')

    wholetime = int(time.time() - start_time)
    print("the trainning finished!, the whole time is {}h".format(
        wholetime / 3600))
    torch.save(model.state_dict(), save_dir + str(epoch) + '_RUnet.pth')
image_patch = image_patch[np.newaxis, np.newaxis, :, :, :] input_image = Variable( torch.from_numpy(image_patch)).cuda().float() pred, aux2, aux1, ht = model(input_image, ht) pred = pred.data.cpu().numpy() predict[start_p:] = pred[0, 0, 0:left_gap] dice = dice_coff(predict, LABEL) all_dice += dice print("the {}th image's dice cofficient is {}".format(i + 1, dice)) wholetime = int(time.time() - start_time) print("the mean dice cofficient is {}".format(all_dice / count)) print("the whole time is {}h".format(wholetime / 3600)) if __name__ == '__main__': model = MODEL.RUnet_3d(in_channel=1).cuda() phase = sys.argv[1] if phase == "train": print("model training!") train(model) elif phase == "test": print("model testing!") test(model) else: print("wrong input!")