def train(self):
        """Run the full training loop for this trainer instance.

        Driven entirely by ``self.config``: builds the loss, splits the
        data into train/validation file lists, (re)creates the optimizer
        each epoch, trains, runs IoU-based validation every
        ``validate_every_n_epoch`` epochs, and checkpoints every
        ``save_every_n_epoch`` epochs via ``save_checkpoint``.

        Side effects: moves tensors to CUDA, writes checkpoint files to
        ``config['checkpoint_dir']``, prints progress to stdout.
        """

        ### load settings ###
        config = self.config  #TODO, fix this
        model = self.model

        # define loss
        #TODO, add more loss
        loss_config = config['loss']
        if loss_config['name'] == 'Aux':
            # deep-supervision NLL loss with 3 auxiliary outputs
            criterion = MultiAuxillaryElementNLLLoss(
                3, loss_config['loss_weight'], config['nclass'])
        else:
            print('do not support other loss yet')
            quit()

        # dataloader
        validation_config = config['validation']
        loader_config = config['loader']
        # lambda abused as a cheap attribute namespace (argparse.Namespace
        # stand-in) to carry inference parameters into model_inference()
        args_inference = lambda: None
        if validation_config['metric'] is not None:
            print('prepare the data ... ...')
            # ground-truth files follow the '<stem>_GT.ome.tif' convention
            filenames = glob(loader_config['datafolder'] + '/*_GT.ome.tif')
            filenames.sort()
            total_num = len(filenames)
            LeaveOut = validation_config['leaveout']
            if len(LeaveOut) == 1:
                if LeaveOut[0] > 0 and LeaveOut[0] < 1:
                    # fractional leave-out: random train/validation split
                    num_train = int(np.floor((1 - LeaveOut[0]) * total_num))
                    shuffled_idx = np.arange(total_num)
                    random.shuffle(shuffled_idx)
                    train_idx = shuffled_idx[:num_train]
                    valid_idx = shuffled_idx[num_train:]
                else:
                    # single integer index: hold out exactly that file
                    valid_idx = [int(LeaveOut[0])]
                    train_idx = list(
                        set(range(total_num)) - set(map(int, LeaveOut)))
            elif LeaveOut:
                # explicit list of validation indices
                valid_idx = list(map(int, LeaveOut))
                train_idx = list(set(range(total_num)) - set(valid_idx))
            # NOTE(review): if 'leaveout' is empty, train_idx/valid_idx stay
            # unbound and the loops below raise NameError -- confirm upstream
            # config validation guarantees a non-empty list.

            valid_filenames = []
            train_filenames = []
            # strip the trailing '_GT.ome.tif' (11 chars) to get the stem
            for fi, fn in enumerate(valid_idx):
                valid_filenames.append(filenames[fn][:-11])
            for fi, fn in enumerate(train_idx):
                train_filenames.append(filenames[fn][:-11])

            args_inference.size_in = config['size_in']
            args_inference.size_out = config['size_out']
            args_inference.OutputCh = validation_config['OutputCh']
            args_inference.nclass = config['nclass']

        else:
            #TODO, update here
            print('need validation')
            quit()

        if loader_config['name'] == 'default':
            from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M0 as train_loader
            train_set_loader = DataLoader(
                train_loader(train_filenames, loader_config['PatchPerBuffer'],
                             config['size_in'], config['size_out']),
                num_workers=loader_config['NumWorkers'],
                batch_size=loader_config['batch_size'],
                shuffle=True)
        elif loader_config['name'] == 'focus':
            from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M0C as train_loader
            train_set_loader = DataLoader(
                train_loader(train_filenames, loader_config['PatchPerBuffer'],
                             config['size_in'], config['size_out']),
                num_workers=loader_config['NumWorkers'],
                batch_size=loader_config['batch_size'],
                shuffle=True)
        else:
            print('other loader not support yet')
            quit()

        num_iterations = 0
        num_epoch = 0  #TODO: load num_epoch from checkpoint

        start_epoch = num_epoch
        for _ in range(start_epoch, config['epochs'] + 1):

            # sets the model in training mode
            model.train()

            # NOTE(review): the optimizer is re-created every epoch, which
            # resets Adam's moment estimates each time -- confirm intended.
            optimizer = None
            optimizer = optim.Adam(model.parameters(),
                                   lr=config['learning_rate'],
                                   weight_decay=config['weight_decay'])

            # check if re-load on training data in needed
            if num_epoch > 0 and num_epoch % loader_config[
                    'epoch_shuffle'] == 0:
                print('shuffling data')
                train_set_loader = None
                train_set_loader = DataLoader(
                    train_loader(train_filenames,
                                 loader_config['PatchPerBuffer'],
                                 config['size_in'], config['size_out']),
                    num_workers=loader_config['NumWorkers'],
                    batch_size=loader_config['batch_size'],
                    shuffle=True)

            # Training starts ...
            epoch_loss = []

            for i, current_batch in tqdm(enumerate(train_set_loader)):

                inputs = Variable(current_batch[0].cuda())
                targets = current_batch[1]
                outputs = model(inputs)

                # multiple targets (deep supervision heads) vs. one target
                if len(targets) > 1:
                    for zidx in range(len(targets)):
                        targets[zidx] = Variable(targets[zidx].cuda())
                else:
                    targets = Variable(targets[0].cuda())

                optimizer.zero_grad()
                if len(current_batch) == 3:  # input + target + cmap
                    cmap = Variable(current_batch[2].cuda())
                    loss = criterion(outputs, targets, cmap)
                else:  # input + target
                    loss = criterion(outputs, targets)

                loss.backward()
                optimizer.step()

                epoch_loss.append(loss.data.item())
                num_iterations += 1

            average_training_loss = sum(epoch_loss) / len(epoch_loss)

            # validation
            if num_epoch % validation_config['validate_every_n_epoch'] == 0:
                # one IoU accumulator per (channel, class) pair in OutputCh
                validation_loss = np.zeros(
                    (len(validation_config['OutputCh']) // 2, ))
                model.eval()

                for img_idx, fn in enumerate(valid_filenames):

                    # target
                    label = np.squeeze(imread(fn + '_GT.ome.tif'))
                    label = np.expand_dims(label, axis=0)

                    # input image
                    input_img = np.squeeze(imread(fn + '.ome.tif'))
                    if len(input_img.shape) == 3:
                        # add channel dimension
                        input_img = np.expand_dims(input_img, axis=0)
                    elif len(input_img.shape) == 4:
                        # assume number of channel < number of Z, make sure channel dim comes first
                        if input_img.shape[0] > input_img.shape[1]:
                            input_img = np.transpose(input_img, (1, 0, 2, 3))

                    # cmap tensor
                    costmap = np.squeeze(imread(fn + '_CM.ome.tif'))

                    # output
                    outputs = model_inference(model, input_img,
                                              model.final_activation,
                                              args_inference)

                    assert len(
                        validation_config['OutputCh']) // 2 == len(outputs)

                    # OutputCh is read as flat pairs: [..., ch, class, ...]
                    for vi in range(len(outputs)):
                        if label.shape[
                                0] == 1:  # the same label for all output
                            validation_loss[vi] += compute_iou(
                                outputs[vi][0, :, :, :] > 0.5,
                                label[0, :, :, :] ==
                                validation_config['OutputCh'][2 * vi + 1],
                                costmap)
                        else:
                            validation_loss[vi] += compute_iou(
                                outputs[vi][0, :, :, :] > 0.5,
                                label[vi, :, :, :] ==
                                validation_config['OutputCh'][2 * vi + 1],
                                costmap)

                average_validation_loss = validation_loss / len(
                    valid_filenames)
                print(
                    f'Epoch: {num_epoch}, Training Loss: {average_training_loss}, Validation loss: {average_validation_loss}'
                )
            else:
                print(
                    f'Epoch: {num_epoch}, Training Loss: {average_training_loss}'
                )

            if num_epoch % config['save_every_n_epoch'] == 0:
                save_checkpoint(
                    {
                        'epoch': num_epoch,
                        'num_iterations': num_iterations,
                        'model_state_dict': model.state_dict(),
                        #'best_val_score': self.best_val_score,
                        'optimizer_state_dict': optimizer.state_dict(),
                        'device': str(self.device),
                    },
                    checkpoint_dir=config['checkpoint_dir'],
                    logger=self.logger)
            num_epoch += 1
# Ejemplo n.º 2
def train(args, model):
    """Train ``model`` per the argparse-style ``args`` namespace.

    Sets up a CSV logger file, selects loss (by ``args.Loss``/``args.model``)
    and augmentation loader (by ``args.Augmentation``), splits files from
    ``args.DataPath`` into train/validation via ``args.LeaveOut``, then for
    each epoch: rebuilds the data buffer every ``args.EpochPerBuffer``
    epochs, builds a (possibly staged-LR) Adam optimizer, trains, runs
    IoU validation on the held-out files, logs, and periodically saves
    the model state_dict to ``args.ModelDir``.
    """

    model.train()

    # check logger
    if not args.TestMode and os.path.isfile(args.LoggerName):
        print('logger file exists')
        quit()
    text_file = open(args.LoggerName, 'a')
    # CSV header for the per-epoch log
    print(f'Epoch,Training_Loss,Validation_Loss\n', file=text_file)
    text_file.close()

    # load the correct loss function
    # NOTE(review): if no branch matches (unexpected Loss/model combo),
    # 'criterion' stays unbound and the training loop raises NameError.
    if args.Loss == 'NLL_CM' and args.model == 'unet_2task':
        from aicsmlsegment.custom_loss import MultiTaskElementNLLLoss
        criterion = MultiTaskElementNLLLoss(args.LossWeight, args.nclass)
        print('use 2 task elementwise NLL loss')
    elif args.Loss == 'NLL_CM' and (args.model == 'unet_ds' or args.model == 'unet_xy' \
        or args.model == 'unet_deeper_xy' or args.model == 'unet_xy_d6' \
        or args.model == 'unet_xy_p3' or args.model == 'unet_xy_p2'):
        from aicsmlsegment.custom_loss import MultiAuxillaryElementNLLLoss
        criterion = MultiAuxillaryElementNLLLoss(3, args.LossWeight,
                                                 args.nclass)
        print('use unet with deep supervision loss')
    elif args.Loss == 'NLL_CM' and args.model == 'unet_xy_multi_task':
        from aicsmlsegment.custom_loss import MultiTaskElementNLLLoss
        criterion = MultiTaskElementNLLLoss(args.LossWeight, args.nclass)
        print('use 2 task elementwise NLL loss')

    # prepare the training/validattion filenames
    print('prepare the data ... ...')
    filenames = glob.glob(args.DataPath + '/*_GT.ome.tif')
    filenames.sort()
    total_num = len(filenames)
    if len(args.LeaveOut) == 1:
        if args.LeaveOut[0] > 0 and args.LeaveOut[0] < 1:
            # fractional leave-out: random train/validation split
            num_train = int(np.floor((1 - args.LeaveOut[0]) * total_num))
            shuffled_idx = np.arange(total_num)
            random.shuffle(shuffled_idx)
            train_idx = shuffled_idx[:num_train]
            valid_idx = shuffled_idx[num_train:]
        else:
            # single integer index: hold out exactly that file
            valid_idx = [int(args.LeaveOut[0])]
            train_idx = list(
                set(range(total_num)) - set(map(int, args.LeaveOut)))
    elif args.LeaveOut:
        # explicit list of validation indices
        valid_idx = list(map(int, args.LeaveOut))
        train_idx = list(set(range(total_num)) - set(valid_idx))

    valid_filenames = []
    train_filenames = []
    # strip the trailing '_GT.ome.tif' (11 chars) to get the stem
    for fi, fn in enumerate(valid_idx):
        valid_filenames.append(filenames[fn][:-11])
    for fi, fn in enumerate(train_idx):
        train_filenames.append(filenames[fn][:-11])

    # may need a different validation method
    #validation_set_loader = DataLoader(exp_Loader(validation_filenames), num_workers=1, batch_size=1, shuffle=False)

    # NOTE(review): if args.Augmentation matches none of these,
    # 'train_loader' stays unbound -- NameError in the epoch loop.
    if args.Augmentation == 'NOAUG_M':
        from aicsmlsegment.DataLoader3D.Universal_Loader import NOAUG_M as train_loader
        print('use no augmentation, with cost map')
    elif args.Augmentation == 'RR_FH_M':
        from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M as train_loader
        print('use flip + rotation augmentation, with cost map')
    elif args.Augmentation == 'RR_FH_M0':
        from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M0 as train_loader
        print('use flip + rotation augmentation, with cost map')
    elif args.Augmentation == 'RR_FH_M0C':
        from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M0C as train_loader
        print(
            'use flip + rotation augmentation, with cost map, and also count valid pixels'
        )

    # softmax for validation
    softmax = nn.Softmax(dim=1)
    softmax.cuda()

    for epoch in range(args.NumEpochs + 1):

        # rebuild the patch buffer every EpochPerBuffer epochs
        if epoch % args.EpochPerBuffer == 0:
            print('shuffling training data ... ...')
            random.shuffle(train_filenames)
            train_set_loader = DataLoader(train_loader(train_filenames,
                                                       args.PatchPerBuffer,
                                                       args.size_in,
                                                       args.size_out),
                                          num_workers=args.NumWorkers,
                                          batch_size=args.BatchSize,
                                          shuffle=True)
            print('training data is ready')

        # specific optimizer for this epoch
        optimizer = None
        if len(args.lr) == 1:  # single value
            optimizer = optim.Adam(model.parameters(),
                                   lr=args.lr[0],
                                   weight_decay=args.WeightDecay)
        elif len(
                args.lr
        ) > 1:  # [stage_1, lr_1, stage_2, lr_2, ..., stage_k, lr_k, lr_final]
            assert len(args.lr) % 2 == 1
            num_training_stage = (len(args.lr) - 1) // 2
            elsecase = True
            # pick the LR of the first stage whose boundary epoch is ahead
            for ts in range(num_training_stage):
                if epoch < args.lr[ts * 2]:
                    optimizer = optim.Adam(model.parameters(),
                                           lr=args.lr[ts * 2 + 1],
                                           weight_decay=args.WeightDecay)
                    elsecase = False
                    break
            if elsecase:
                # past all stages: use the final LR
                optimizer = optim.Adam(model.parameters(),
                                       lr=args.lr[-1],
                                       weight_decay=args.WeightDecay)
        assert optimizer is not None, f'optimzer setup fails'

        # re-open the logger file
        text_file = open(args.LoggerName, 'a')

        # Training starts ...
        epoch_loss = []
        model.train()

        for step, current_batch in tqdm(enumerate(train_set_loader)):

            inputs = Variable(current_batch[0].cuda())
            targets = current_batch[1]
            #print(inputs.size())
            #print(targets[0].size())
            outputs = model(inputs)
            #print(len(outputs))
            #print(outputs[0].size())

            # multiple targets (deep supervision heads) vs. one target
            if len(targets) > 1:
                for zidx in range(len(targets)):
                    targets[zidx] = Variable(targets[zidx].cuda())
            else:
                targets = Variable(targets[0].cuda())

            optimizer.zero_grad()
            if len(current_batch) == 3:  # input + target + cmap
                cmap = Variable(current_batch[2].cuda())
                loss = criterion(outputs, targets, cmap)
            else:  # input + target
                loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

            epoch_loss.append(loss.data.item())

        # Validation starts ...
        # one IoU accumulator per (channel, class) pair in OutputCh
        validation_loss = np.zeros((len(args.OutputCh) // 2, ))
        model.eval()

        for img_idx, fn in enumerate(valid_filenames):

            # target
            label_reader = AICSImage(fn + '_GT.ome.tif')  #CZYX
            label = label_reader.data
            label = np.squeeze(label, axis=0)  # 4-D after squeeze

            # when the tif has only 1 channel, the loaded array may have falsely swaped dimensions (ZCYX). we want CZYX
            # (This may also happen in different OS or different package versions)
            # ASSUMPTION: we have more z slices than the number of channels
            if label.shape[1] < label.shape[0]:
                label = np.transpose(label, (1, 0, 2, 3))

            # input image
            input_reader = AICSImage(fn + '.ome.tif')  #CZYX  #TODO: check size
            input_img = input_reader.data
            input_img = np.squeeze(input_img, axis=0)
            if input_img.shape[1] < input_img.shape[0]:
                input_img = np.transpose(input_img, (1, 0, 2, 3))

            # cmap tensor
            costmap_reader = AICSImage(fn + '_CM.ome.tif')  # ZYX
            costmap = costmap_reader.data
            costmap = np.squeeze(costmap, axis=0)
            # drop a singleton channel axis wherever it landed
            if costmap.shape[0] == 1:
                costmap = np.squeeze(costmap, axis=0)
            elif costmap.shape[1] == 1:
                costmap = np.squeeze(costmap, axis=1)

            # output
            outputs = model_inference(model, input_img, softmax, args)

            assert len(args.OutputCh) // 2 == len(outputs)

            # OutputCh is read as flat pairs: [..., ch, class, ...]
            for vi in range(len(outputs)):
                if label.shape[0] == 1:  # the same label for all output
                    validation_loss[vi] += compute_iou(
                        outputs[vi][0, :, :, :] > 0.5,
                        label[0, :, :, :] == args.OutputCh[2 * vi + 1],
                        costmap)
                else:
                    validation_loss[vi] += compute_iou(
                        outputs[vi][0, :, :, :] > 0.5,
                        label[vi, :, :, :] == args.OutputCh[2 * vi + 1],
                        costmap)

        # print loss
        average_training_loss = sum(epoch_loss) / len(epoch_loss)
        average_validation_loss = validation_loss / len(valid_filenames)
        print(
            f'Epoch: {epoch}, Training Loss: {average_training_loss}, Validation loss: {average_validation_loss}'
        )
        print(f'{epoch},{average_training_loss},{average_validation_loss}\n',
              file=text_file)
        text_file.close()

        # save the model
        if args.SaveEveryKEpoch > 0 and epoch % args.SaveEveryKEpoch == 0:
            filename = f'{args.model}-{epoch:03}-{args.model_tag}.pth'
            torch.save(model.state_dict(), args.ModelDir + os.sep + filename)
            print(f'save at epoch: {epoch})')
# Ejemplo n.º 3
def main():
    """Command-line entry point for segmentation inference.

    Loads a YAML/JSON config (``--config``), builds and restores the model,
    then runs inference either on a single file (``mode.name == 'file'``,
    optionally per-timepoint for timelapse data) or on every file in a
    folder (``mode.name == 'folder'``).  Each output channel is optionally
    resized back by 1/ResizeRatio, thresholded (``Threshold`` > 0 yields a
    uint8 0/255 mask, otherwise float32 in [0, 1]), and written as OME-TIFF
    to ``config['OutputDir']``.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True)
    args = parser.parse_args()

    config = load_config(args.config)

    # declare the model
    model = build_model(config)

    # load the trained model instance
    model_path = config['model_path']
    print(f'Loading model from {model_path}...')
    load_checkpoint(model_path, model)

    # extract the parameters for preparing the input image
    # lambda abused as a cheap attribute namespace (argparse stand-in)
    args_norm = lambda: None
    args_norm.Normalization = config['Normalization']

    # extract the parameters for running the model inference
    args_inference = lambda: None
    args_inference.size_in = config['size_in']
    args_inference.size_out = config['size_out']
    args_inference.OutputCh = config['OutputCh']
    args_inference.nclass = config['nclass']

    # run
    inf_config = config['mode']
    if inf_config['name'] == 'file':
        fn = inf_config['InputFile']
        data_reader = AICSImage(fn)
        img0 = data_reader.data

        if inf_config['timelapse']:
            assert img0.shape[0] > 1

            # process each timepoint independently
            for tt in range(img0.shape[0]):
                # Assume:  dimensions = TCZYX
                img = img0[tt, config['InputCh'], :, :, :].astype(float)
                img = input_normalization(img, args_norm)

                if len(config['ResizeRatio']) > 0:
                    img = resize(
                        img,
                        (1, config['ResizeRatio'][0], config['ResizeRatio'][1],
                         config['ResizeRatio'][2]),
                        method='cubic')
                    # re-normalize each channel to [0, 1] after resizing
                    for ch_idx in range(img.shape[0]):
                        struct_img = img[ch_idx, :, :, :]
                        struct_img = (struct_img - struct_img.min()) / (
                            struct_img.max() - struct_img.min())
                        img[ch_idx, :, :, :] = struct_img

                # apply the model
                output_img = model_inference(model, img,
                                             model.final_activation,
                                             args_inference)

                # extract the result and write the output
                if len(config['OutputCh']) == 2:
                    # single output channel
                    writer = omeTifWriter.OmeTifWriter(
                        config['OutputDir'] + os.sep +
                        pathlib.PurePosixPath(fn).stem + '_T_' + f'{tt:03}' +
                        '_struct_segmentation.tiff')
                    out = output_img[0]
                    out = (out - out.min()) / (out.max() - out.min())
                    if len(config['ResizeRatio']) > 0:
                        # undo the input resize
                        out = resize(out, (1.0, 1 / config['ResizeRatio'][0],
                                           1 / config['ResizeRatio'][1],
                                           1 / config['ResizeRatio'][2]),
                                     method='cubic')
                    out = out.astype(np.float32)
                    if config['Threshold'] > 0:
                        out = out > config['Threshold']
                        out = out.astype(np.uint8)
                        out[out > 0] = 255
                    writer.save(out)
                else:
                    # one file per output channel
                    for ch_idx in range(len(config['OutputCh']) // 2):
                        writer = omeTifWriter.OmeTifWriter(
                            config['OutputDir'] + os.sep +
                            pathlib.PurePosixPath(fn).stem + '_T_' +
                            f'{tt:03}' + '_seg_' +
                            str(config['OutputCh'][2 * ch_idx]) + '.tiff')
                        out = output_img[ch_idx]
                        out = (out - out.min()) / (out.max() - out.min())
                        if len(config['ResizeRatio']) > 0:
                            out = resize(out,
                                         (1.0, 1 / config['ResizeRatio'][0],
                                          1 / config['ResizeRatio'][1],
                                          1 / config['ResizeRatio'][2]),
                                         method='cubic')
                        out = out.astype(np.float32)
                        if config['Threshold'] > 0:
                            out = out > config['Threshold']
                            out = out.astype(np.uint8)
                            out[out > 0] = 255
                        writer.save(out)
        else:
            # single (non-timelapse) image: take timepoint 0
            img = img0[0, :, :, :, :].astype(float)
            print(f'processing one image of size {img.shape}')
            # heuristic: channel dim should be smaller than Z; swap if not
            if img.shape[1] < img.shape[0]:
                img = np.transpose(img, (1, 0, 2, 3))
            img = img[config['InputCh'], :, :, :]
            img = input_normalization(img, args_norm)

            if len(config['ResizeRatio']) > 0:
                img = resize(
                    img, (1, config['ResizeRatio'][0],
                          config['ResizeRatio'][1], config['ResizeRatio'][2]),
                    method='cubic')
                for ch_idx in range(img.shape[0]):
                    struct_img = img[
                        ch_idx, :, :, :]  # note that struct_img is only a view of img, so changes made on struct_img also affects img
                    struct_img = (struct_img - struct_img.min()) / (
                        struct_img.max() - struct_img.min())
                    img[ch_idx, :, :, :] = struct_img

            # apply the model
            output_img = model_inference(model, img, model.final_activation,
                                         args_inference)

            # extract the result and write the output
            if len(config['OutputCh']) == 2:
                out = output_img[0]
                out = (out - out.min()) / (out.max() - out.min())
                if len(config['ResizeRatio']) > 0:
                    out = resize(out, (1.0, 1 / config['ResizeRatio'][0],
                                       1 / config['ResizeRatio'][1],
                                       1 / config['ResizeRatio'][2]),
                                 method='cubic')
                out = out.astype(np.float32)
                print(out.shape)
                if config['Threshold'] > 0:
                    out = out > config['Threshold']
                    out = out.astype(np.uint8)
                    out[out > 0] = 255
                writer = omeTifWriter.OmeTifWriter(
                    config['OutputDir'] + os.sep +
                    pathlib.PurePosixPath(fn).stem +
                    '_struct_segmentation.tiff')
                writer.save(out)
            else:
                for ch_idx in range(len(config['OutputCh']) // 2):
                    out = output_img[ch_idx]
                    out = (out - out.min()) / (out.max() - out.min())
                    if len(config['ResizeRatio']) > 0:
                        out = resize(out, (1.0, 1 / config['ResizeRatio'][0],
                                           1 / config['ResizeRatio'][1],
                                           1 / config['ResizeRatio'][2]),
                                     method='cubic')
                    out = out.astype(np.float32)
                    if config['Threshold'] > 0:
                        out = out > config['Threshold']
                        out = out.astype(np.uint8)
                        out[out > 0] = 255
                    writer = omeTifWriter.OmeTifWriter(
                        config['OutputDir'] + os.sep +
                        pathlib.PurePosixPath(fn).stem + '_seg_' +
                        str(config['OutputCh'][2 * ch_idx]) + '.tiff')
                    writer.save(out)
            print(f'Image {fn} has been segmented')

    elif inf_config['name'] == 'folder':
        from glob import glob
        filenames = glob(inf_config['InputDir'] + '/*' +
                         inf_config['DataType'])
        filenames.sort()
        #print(filenames)

        for _, fn in enumerate(filenames):

            # load data
            data_reader = AICSImage(fn)
            img0 = data_reader.data
            img = img0[0, :, :, :, :].astype(float)
            # heuristic: channel dim should be smaller than Z; swap if not
            if img.shape[1] < img.shape[0]:
                img = np.transpose(img, (1, 0, 2, 3))
            img = img[config['InputCh'], :, :, :]
            img = input_normalization(img, args_norm)
            #img = image_normalization(img, config['Normalization'])

            if len(config['ResizeRatio']) > 0:
                img = resize(
                    img, (1, config['ResizeRatio'][0],
                          config['ResizeRatio'][1], config['ResizeRatio'][2]),
                    method='cubic')
                for ch_idx in range(img.shape[0]):
                    struct_img = img[
                        ch_idx, :, :, :]  # note that struct_img is only a view of img, so changes made on struct_img also affects img
                    struct_img = (struct_img - struct_img.min()) / (
                        struct_img.max() - struct_img.min())
                    img[ch_idx, :, :, :] = struct_img

            # apply the model
            output_img = model_inference(model, img, model.final_activation,
                                         args_inference)

            # extract the result and write the output
            if len(config['OutputCh']) == 2:
                writer = omeTifWriter.OmeTifWriter(
                    config['OutputDir'] + os.sep +
                    pathlib.PurePosixPath(fn).stem +
                    '_struct_segmentation.tiff')
                if config['Threshold'] < 0:
                    # no threshold: save normalized probabilities
                    out = output_img[0]
                    out = (out - out.min()) / (out.max() - out.min())
                    print(out.shape)
                    if len(config['ResizeRatio']) > 0:
                        out = resize(out, (1.0, 1 / config['ResizeRatio'][0],
                                           1 / config['ResizeRatio'][1],
                                           1 / config['ResizeRatio'][2]),
                                     method='cubic')
                    out = out.astype(np.float32)
                    out = (out - out.min()) / (out.max() - out.min())
                    writer.save(out)
                else:
                    # threshold + drop tiny specks, save a binary mask
                    out = remove_small_objects(
                        output_img[0] > config['Threshold'],
                        min_size=2,
                        connectivity=1)
                    out = out.astype(np.uint8)
                    out[out > 0] = 255
                    writer.save(out)
            else:
                for ch_idx in range(len(config['OutputCh']) // 2):
                    writer = omeTifWriter.OmeTifWriter(
                        config['OutputDir'] + os.sep +
                        pathlib.PurePosixPath(fn).stem + '_seg_' +
                        str(config['OutputCh'][2 * ch_idx]) + '.ome.tif')
                    if config['Threshold'] < 0:
                        out = output_img[ch_idx]
                        out = (out - out.min()) / (out.max() - out.min())
                        writer.save(out.astype(np.float32))
                    else:
                        out = output_img[ch_idx] > config['Threshold']
                        out = out.astype(np.uint8)
                        out[out > 0] = 255
                        writer.save(out)

            print(f'Image {fn} has been segmented')
# Ejemplo n.º 4
def evaluate(args, model):
    """Run segmentation inference with a trained ``model``.

    ``args.mode == 'eval'`` segments every matching file in
    ``args.InputDir``; ``args.mode == 'eval_file'`` segments a single
    ``args.InputFile`` (optionally per-timepoint when ``args.timelapse``).
    For each output channel: ``args.Threshold < 0`` saves float
    probabilities, otherwise a thresholded uint8 0/255 mask, written as
    OME-TIFF into ``args.OutputDir``.
    """

    model.eval()
    softmax = nn.Softmax(dim=1)
    softmax.cuda()

    # check validity of parameters
    assert args.nchannel == len(
        args.InputCh
    ), f'number of input channel does not match input channel indices'

    if args.mode == 'eval':

        filenames = glob.glob(args.InputDir + '/*' + args.DataType)
        filenames.sort()

        for fi, fn in enumerate(filenames):
            print(fn)
            # load data
            struct_img = load_single_image(args, fn, time_flag=False)

            print(struct_img.shape)

            # apply the model
            output_img = apply_on_image(model, struct_img, softmax, args)
            #output_img = model_inference(model, struct_img, softmax, args)

            #print(len(output_img))

            # OutputCh is read as flat pairs: [..., ch, class, ...]
            for ch_idx in range(len(args.OutputCh) // 2):
                write = omeTifWriter.OmeTifWriter(
                    args.OutputDir + pathlib.PurePosixPath(fn).stem + '_seg_' +
                    str(args.OutputCh[2 * ch_idx]) + '.ome.tif')
                if args.Threshold < 0:
                    write.save(output_img[ch_idx].astype(float))
                else:
                    out = output_img[ch_idx] > args.Threshold
                    out = out.astype(np.uint8)
                    out[out > 0] = 255
                    write.save(out)

            print(f'Image {fn} has been segmented')

    elif args.mode == 'eval_file':

        fn = args.InputFile
        print(fn)
        data_reader = AICSImage(fn)
        img0 = data_reader.data
        if args.timelapse:
            assert data_reader.shape[0] > 1

            # process each timepoint independently
            for tt in range(data_reader.shape[0]):
                # Assume:  TCZYX
                img = img0[tt, args.InputCh, :, :, :].astype(float)
                img = input_normalization(img, args)

                if len(args.ResizeRatio) > 0:
                    img = resize(img,
                                 (1, args.ResizeRatio[0], args.ResizeRatio[1],
                                  args.ResizeRatio[2]),
                                 method='cubic')
                    # re-normalize each channel to [0, 1] after resizing
                    for ch_idx in range(img.shape[0]):
                        struct_img = img[
                            ch_idx, :, :, :]  # note that struct_img is only a view of img, so changes made on struct_img also affects img
                        struct_img = (struct_img - struct_img.min()) / (
                            struct_img.max() - struct_img.min())
                        img[ch_idx, :, :, :] = struct_img

                # apply the model
                output_img = model_inference(model, img, softmax, args)

                for ch_idx in range(len(args.OutputCh) // 2):
                    writer = omeTifWriter.OmeTifWriter(
                        args.OutputDir + pathlib.PurePosixPath(fn).stem +
                        '_T_' + f'{tt:03}' + '_seg_' +
                        str(args.OutputCh[2 * ch_idx]) + '.ome.tif')
                    if args.Threshold < 0:
                        # float output, resized back by 1/ResizeRatio
                        out = output_img[ch_idx].astype(float)
                        out = resize(
                            out,
                            (1.0, 1 / args.ResizeRatio[0],
                             1 / args.ResizeRatio[1], 1 / args.ResizeRatio[2]),
                            method='cubic')
                        writer.save(out)
                    else:
                        # binary mask: nearest-neighbor resize keeps it binary
                        out = output_img[ch_idx] > args.Threshold
                        out = resize(
                            out,
                            (1.0, 1 / args.ResizeRatio[0],
                             1 / args.ResizeRatio[1], 1 / args.ResizeRatio[2]),
                            method='nearest')
                        out = out.astype(np.uint8)
                        out[out > 0] = 255
                        writer.save(out)
        else:
            # NOTE(review): only 4 indices here vs. 5 ('img0[0, :, :, :, :]')
            # in the sibling code paths -- confirm img0 rank for this branch.
            img = img0[0, :, :, :].astype(float)
            # heuristic: channel dim should be smaller than Z; swap if not
            if img.shape[1] < img.shape[0]:
                img = np.transpose(img, (1, 0, 2, 3))
            img = img[args.InputCh, :, :, :]
            img = input_normalization(img, args)

            if len(args.ResizeRatio) > 0:
                img = resize(img, (1, args.ResizeRatio[0], args.ResizeRatio[1],
                                   args.ResizeRatio[2]),
                             method='cubic')
                for ch_idx in range(img.shape[0]):
                    struct_img = img[
                        ch_idx, :, :, :]  # note that struct_img is only a view of img, so changes made on struct_img also affects img
                    struct_img = (struct_img - struct_img.min()) / (
                        struct_img.max() - struct_img.min())
                    img[ch_idx, :, :, :] = struct_img

            # apply the model
            output_img = model_inference(model, img, softmax, args)

            for ch_idx in range(len(args.OutputCh) // 2):
                writer = omeTifWriter.OmeTifWriter(
                    args.OutputDir + pathlib.PurePosixPath(fn).stem + '_seg_' +
                    str(args.OutputCh[2 * ch_idx]) + '.ome.tif')
                if args.Threshold < 0:
                    writer.save(output_img[ch_idx].astype(float))
                else:
                    out = output_img[ch_idx] > args.Threshold
                    out = out.astype(np.uint8)
                    out[out > 0] = 255
                    writer.save(out)

        print(f'Image {fn} has been segmented')