def binary_cross_entropy(output, target, name=''):
    r'''
    Computes the binary cross entropy (aka logistic loss) between the
    ``output`` and ``target``.

    Args:
        output: the computed posterior probability for a variable to be 1
            from the network (typ. a ``sigmoid``)
        target: ground-truth label, 0 or 1
        name (str, optional): the name of the Function instance in the network

    Returns:
        :class:`~cntk.ops.functions.Function`
    '''
    # Imported locally so the module-level name keeps referring to this wrapper.
    from cntk.cntk_py import binary_cross_entropy

    # Promote both operands to a common dtype before handing them to the C++ op.
    common_dtype = get_data_type(output, target)
    sanitized_output = sanitize_input(output, common_dtype)
    sanitized_target = sanitize_input(target, common_dtype)
    return binary_cross_entropy(sanitized_output, sanitized_target, name)
def train(TrnPath, SavePath, savename, imgX, imgY, ans_target, trn_target, bs, epc, imglist, use_existing=False):
    """Train the U-Net model on the image files named in ``imglist``.

    Args:
        TrnPath: root directory containing ``raw`` (inputs) and ``ans``
            (labels) sub-trees, each split into Amode/Bmode/Zmode folders.
        SavePath: directory where ``.dnn`` checkpoints are written.
        savename: base name for the checkpoint files.
        imgX: image width in pixels.
        imgY: image height in pixels.
        ans_target: sub-folder name under each ``ans/<mode>/`` directory.
        trn_target: sub-folder name under each ``raw/<mode>/`` directory.
        bs: minibatch size.
        epc: number of training epochs.
        imglist: file names; the first capital letter (A/B/Z) selects the
            mode sub-tree. Shuffled in place each epoch.
        use_existing: currently unused — the checkpoint-restore code below
            is disabled (kept only inside a string literal).

    Returns:
        The CNTK ``Trainer`` instance after the final epoch.
    """
    # Get minibatches of training data and perform model training
    minibatch_size = bs
    num_epochs = epc

    # Dummy array used only to derive the (channels, height, width) input shape.
    # test = np.zeros((1, 512, 256))
    test = np.zeros((1, imgY, imgX))
    shape = test.shape
    print("test image shape:" + str(test.shape))

    x = C.input_variable(shape)
    y = C.input_variable(shape)
    z = cntk_unet.create_model(x)
    #dice_coef = cntk_unet.dice_coefficient(z, y)
    # ll = C.Logistic(z, y)
    # ce = C.CrossEntropy(z, y)
    '''
    checkpoint_file = "/home/ys/PycharmProjects/cntk-unet/cntk-unet.dnn"
    if use_existing:
        z.load_model(checkpoint_file)
    '''

    # Prepare model and trainer
    lr = learning_rate_schedule(0.00005, UnitType.sample)
    momentum = C.learners.momentum_as_time_constant_schedule(0)

    # loss and metric — note both are the same binary cross entropy, so the
    # reported "error" is just the loss value, not a separate metric.
    ce = binary_cross_entropy(z, y)
    pe = binary_cross_entropy(z, y)
    trainer = C.Trainer(z, (ce, pe), C.learners.adam(z.parameters, lr=lr, momentum=momentum))

    file_num = len(imglist)
    training_errors = []
    test_errors = []  # NOTE(review): never populated or read in this function.
    sw = time.time()  # wall-clock start, used for the per-epoch timing printout

    for e in range(0, num_epochs):
        pattern = "[A-Z]"
        num = 0  # NOTE(review): unused.
        random.shuffle(imglist)
        for i in range(0, file_num):
            # http://qiita.com/wanwanland/items/ce272419dde2f95cdabc
            # First capital letter in the file name selects the mode folder.
            match = re.search(pattern, imglist[i])
            # print("matchgroup:"+match.group())
            # NOTE(review): if the first capital is not A/B/Z, ansfile/trnfile
            # keep the previous iteration's values (or are undefined on the
            # very first file) — confirm imglist is guaranteed to match.
            if match.group() == "A":
                ansfile = TrnPath + r"/ans/Amode/" + ans_target + r"/" + imglist[i]
                trnfile = TrnPath + r"/raw/Amode/" + trn_target + r"/" + imglist[i]
            if match.group() == "B":
                ansfile = TrnPath + r"/ans/Bmode/" + ans_target + r"/" + imglist[i]
                trnfile = TrnPath + r"/raw/Bmode/" + trn_target + r"/" + imglist[i]
            if match.group() == "Z":
                ansfile = TrnPath + r"/ans/Zmode/" + ans_target + r"/" + imglist[i]
                trnfile = TrnPath + r"/raw/Zmode/" + trn_target + r"/" + imglist[i]
            # print(ansfile)
            # print(trnfile)
            if i % minibatch_size == 0:
                # Start a new minibatch. NOTE(review): this branch omits the
                # 4th Img2CntkImg argument that the branches below pass
                # (True/False) — presumably an augmentation flag; confirm the
                # default matches the intended behavior.
                training_y = np.array([Img2CntkImg(ansfile, imgX, imgY)])
                training_x = np.array([Img2CntkImg(trnfile, imgX, imgY)])
            elif i % minibatch_size > 0 and i % minibatch_size < minibatch_size - 1:
                # Middle of a minibatch: grow the batch arrays.
                # NOTE(review): np.append copies the whole array each call,
                # making batch assembly O(batch^2); collecting into a list and
                # stacking once would be linear.
                training_y = np.append(training_y, np.array([Img2CntkImg(ansfile, imgX, imgY, True)]), axis=0)
                training_x = np.append(training_x, np.array([Img2CntkImg(trnfile, imgX, imgY, False)]), axis=0)
            elif i % minibatch_size == minibatch_size - 1:
                # Last slot of the minibatch: append the final pair, then train.
                training_y = np.append(training_y, np.array([Img2CntkImg(ansfile, imgX, imgY, True)]), axis=0)
                training_x = np.append(training_x, np.array([Img2CntkImg(trnfile, imgX, imgY, False)]), axis=0)
                trainer.train_minibatch({x: training_x, y: training_y})
                # print("###################")
                # NOTE(review): comparing the FILE index ``i`` against
                # ``num_epochs - 1`` looks wrong — presumably
                # ``i == file_num - 1`` (end of epoch) was intended; confirm.
                if i == num_epochs - 1:
                    # Measure training error
                    training_errors.append(trainer.test_minibatch({x: training_x, y: training_y}))
                    print("Epoch:" + str(e) + " Error:" + str(np.mean(training_errors)) + " time:" + str(time.time() - sw))
        # print("epoch:" + str(e) + " error:"+str(np.mean(training_errors))+" time:" + str(time.time()-sw))
        # Periodic checkpoint every 50 epochs (including epoch 0).
        if e % 50 == 0:
            trainer.save_checkpoint(SavePath + r"/" + savename + r"_" + str(e) + ".dnn")
        # print("epoch:" + str(e))
        # print("time passed:" + str(time.time() - sw))
        # End-of-epoch error measured on the LAST minibatch only, then the
        # running error list is reset for the next epoch.
        training_errors.append(trainer.test_minibatch({x: training_x, y: training_y}))
        print("Epoch:" + str(e) + " Error:" + str(np.mean(training_errors)) + " time:" + str(time.time() - sw))
        training_errors = []
    return trainer