# NOTE: assumed imports for this standalone excerpt; the snippets below are
# lifted from a larger module, so the exact sources of some names are not
# shown. Package-internal paths marked as assumptions are not verified.
import os
import random
from glob import glob

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm

from aicsimageio import AICSImage, imread  # assumed reader for the .ome.tif files
from aicsmlsegment.custom_loss import MultiAuxillaryElementNLLLoss

# Assumed package-internal helpers referenced below (paths not shown in the excerpt):
# from aicsmlsegment.utils import SUPPORTED_LOSSES, compute_iou, model_inference, save_checkpoint


def get_loss_criterion(config):
    """
    Returns the loss function based on provided configuration
    :param config: (dict) a top level configuration object containing the 'loss' key
    :return: an instance of the loss function
    """
    assert 'loss' in config, 'Could not find loss function configuration'
    loss_config = config['loss']
    name = loss_config['name']
    assert name in SUPPORTED_LOSSES, f'Invalid loss: {name}. Supported losses: {SUPPORTED_LOSSES}'

    # ignore_index = loss_config.get('ignore_index', None)

    # TODO: add more loss functions
    if name == 'Aux':
        return MultiAuxillaryElementNLLLoss(3, loss_config['loss_weight'],
                                            config['nclass'])
    # guard against silently returning None for a supported-but-unimplemented name
    raise NotImplementedError(f"loss '{name}' is not implemented yet")
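
# A minimal usage sketch with hypothetical values: the exact schema of the
# config -- in particular 'loss_weight' and 'nclass' -- is inferred from how
# get_loss_criterion reads it above, not taken from package documentation.
_example_config = {
    'loss': {'name': 'Aux', 'loss_weight': [1.0, 0.5, 0.5]},
    'nclass': [2, 2, 2],
}
# criterion = get_loss_criterion(_example_config)
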

# NOTE: this method was lifted out of a trainer class; "self" is expected to
# provide config, model, device, and logger attributes.
def train(self):

        ### load settings ###
        config = self.config  # TODO: fix this
        model = self.model

        # define the loss function
        # TODO: support more loss functions
        loss_config = config['loss']
        if loss_config['name'] == 'Aux':
            criterion = MultiAuxillaryElementNLLLoss(
                3, loss_config['loss_weight'], config['nclass'])
        else:
            raise NotImplementedError(
                f"loss '{loss_config['name']}' is not supported yet")

        # dataloader
        validation_config = config['validation']
        loader_config = config['loader']
        # lightweight attribute container (a lambda abused as a namespace);
        # types.SimpleNamespace() would be the idiomatic equivalent
        args_inference = lambda: None
        if validation_config['metric'] is not None:
            print('preparing the data ...')
            filenames = glob(loader_config['datafolder'] + '/*_GT.ome.tif')
            filenames.sort()
            total_num = len(filenames)
            LeaveOut = validation_config['leaveout']
            if len(LeaveOut) == 1:
                if LeaveOut[0] > 0 and LeaveOut[0] < 1:
                    # a single fraction in (0, 1): hold out that fraction at random
                    num_train = int(np.floor((1 - LeaveOut[0]) * total_num))
                    shuffled_idx = np.arange(total_num)
                    random.shuffle(shuffled_idx)
                    train_idx = shuffled_idx[:num_train]
                    valid_idx = shuffled_idx[num_train:]
                else:
                    # a single integer: hold out that one image by index
                    valid_idx = [int(LeaveOut[0])]
                    train_idx = list(
                        set(range(total_num)) - set(map(int, LeaveOut)))
            elif LeaveOut:
                # a list of integers: hold out those images by index
                valid_idx = list(map(int, LeaveOut))
                train_idx = list(set(range(total_num)) - set(valid_idx))
            else:
                raise ValueError("validation 'leaveout' must not be empty")

            valid_filenames = []
            train_filenames = []
            # strip the 11-character '_GT.ome.tif' suffix to get the common stem
            for fn in valid_idx:
                valid_filenames.append(filenames[fn][:-11])
            for fn in train_idx:
                train_filenames.append(filenames[fn][:-11])

            args_inference.size_in = config['size_in']
            args_inference.size_out = config['size_out']
            args_inference.OutputCh = validation_config['OutputCh']
            args_inference.nclass = config['nclass']

        else:
            # TODO: support training without a validation metric
            raise NotImplementedError('a validation metric is currently required')

        if loader_config['name'] == 'default':
            from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M0 as train_loader
        elif loader_config['name'] == 'focus':
            from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M0C as train_loader
        else:
            raise NotImplementedError(
                f"loader '{loader_config['name']}' is not supported yet")
        train_set_loader = DataLoader(
            train_loader(train_filenames, loader_config['PatchPerBuffer'],
                         config['size_in'], config['size_out']),
            num_workers=loader_config['NumWorkers'],
            batch_size=loader_config['batch_size'],
            shuffle=True)

        num_iterations = 0
        num_epoch = 0  #TODO: load num_epoch from checkpoint

        start_epoch = num_epoch
        for _ in range(start_epoch, config['epochs'] + 1):

            # sets the model in training mode
            model.train()

            # NOTE: re-creating the optimizer every epoch resets Adam's moment
            # estimates; hoist this above the loop to keep them across epochs
            optimizer = optim.Adam(model.parameters(),
                                   lr=config['learning_rate'],
                                   weight_decay=config['weight_decay'])

            # re-sample the training buffer every 'epoch_shuffle' epochs
            if num_epoch > 0 and num_epoch % loader_config['epoch_shuffle'] == 0:
                print('shuffling data')
                train_set_loader = DataLoader(
                    train_loader(train_filenames,
                                 loader_config['PatchPerBuffer'],
                                 config['size_in'], config['size_out']),
                    num_workers=loader_config['NumWorkers'],
                    batch_size=loader_config['batch_size'],
                    shuffle=True)

            # Training starts ...
            epoch_loss = []

            for i, current_batch in tqdm(enumerate(train_set_loader)):

                # Variable is a legacy no-op wrapper on PyTorch >= 0.4;
                # plain .cuda() tensors would suffice on modern versions
                inputs = Variable(current_batch[0].cuda())
                targets = current_batch[1]
                outputs = model(inputs)

                if len(targets) > 1:
                    for zidx in range(len(targets)):
                        targets[zidx] = Variable(targets[zidx].cuda())
                else:
                    targets = Variable(targets[0].cuda())

                optimizer.zero_grad()
                if len(current_batch) == 3:  # input + target + cmap
                    cmap = Variable(current_batch[2].cuda())
                    loss = criterion(outputs, targets, cmap)
                else:  # input + target
                    loss = criterion(outputs, targets)

                loss.backward()
                optimizer.step()

                epoch_loss.append(loss.item())  # .data is redundant on modern PyTorch
                num_iterations += 1

            average_training_loss = sum(epoch_loss) / len(epoch_loss)

            # validation ("validation_loss" accumulates per-channel IoU, so
            # higher is better, despite the name)
            if num_epoch % validation_config['validate_every_n_epoch'] == 0:
                validation_loss = np.zeros(
                    (len(validation_config['OutputCh']) // 2, ))
                model.eval()

                for img_idx, fn in enumerate(valid_filenames):

                    # target
                    label = np.squeeze(imread(fn + '_GT.ome.tif'))
                    label = np.expand_dims(label, axis=0)

                    # input image
                    input_img = np.squeeze(imread(fn + '.ome.tif'))
                    if len(input_img.shape) == 3:
                        # add channel dimension
                        input_img = np.expand_dims(input_img, axis=0)
                    elif len(input_img.shape) == 4:
                        # assume the number of channels < number of Z slices;
                        # make sure the channel dim comes first
                        if input_img.shape[0] > input_img.shape[1]:
                            input_img = np.transpose(input_img, (1, 0, 2, 3))

                    # cmap tensor
                    costmap = np.squeeze(imread(fn + '_CM.ome.tif'))

                    # output
                    outputs = model_inference(model, input_img,
                                              model.final_activation,
                                              args_inference)

                    assert len(validation_config['OutputCh']) // 2 == len(outputs)

                    for vi in range(len(outputs)):
                        if label.shape[0] == 1:  # the same label for all outputs
                            validation_loss[vi] += compute_iou(
                                outputs[vi][0, :, :, :] > 0.5,
                                label[0, :, :, :] ==
                                validation_config['OutputCh'][2 * vi + 1],
                                costmap)
                        else:
                            validation_loss[vi] += compute_iou(
                                outputs[vi][0, :, :, :] > 0.5,
                                label[vi, :, :, :] ==
                                validation_config['OutputCh'][2 * vi + 1],
                                costmap)

                average_validation_loss = validation_loss / len(valid_filenames)
                print(f'Epoch: {num_epoch}, Training Loss: {average_training_loss}, '
                      f'Validation IoU: {average_validation_loss}')
            else:
                print(f'Epoch: {num_epoch}, Training Loss: {average_training_loss}')

            if num_epoch % config['save_every_n_epoch'] == 0:
                save_checkpoint(
                    {
                        'epoch': num_epoch,
                        'num_iterations': num_iterations,
                        'model_state_dict': model.state_dict(),
                        #'best_val_score': self.best_val_score,
                        'optimizer_state_dict': optimizer.state_dict(),
                        'device': str(self.device),
                    },
                    checkpoint_dir=config['checkpoint_dir'],
                    logger=self.logger)
            num_epoch += 1
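
# The 'leaveout' handling above (and again in Example #3 below) supports three
# encodings. A self-contained sketch of the same rule as a helper; the function
# name and signature are illustrative, not part of aicsmlsegment:
def split_leave_out_sketch(total_num, leave_out, seed=None):
    """Return (train_idx, valid_idx) for the three 'leaveout' encodings:
    [f] with 0 < f < 1   -> hold out a random fraction f of the images,
    [k] with integer k   -> hold out the single image with index k,
    [k1, k2, ...]        -> hold out the listed image indices.
    """
    if seed is not None:
        random.seed(seed)
    if len(leave_out) == 1 and 0 < leave_out[0] < 1:
        num_train = int(np.floor((1 - leave_out[0]) * total_num))
        shuffled_idx = list(range(total_num))
        random.shuffle(shuffled_idx)
        return shuffled_idx[:num_train], shuffled_idx[num_train:]
    valid_idx = [int(v) for v in leave_out]
    train_idx = sorted(set(range(total_num)) - set(valid_idx))
    return train_idx, valid_idx

# e.g. split_leave_out_sketch(10, [0.2], seed=0) holds out 2 of 10 images at
# random; split_leave_out_sketch(10, [3, 7]) holds out images 3 and 7.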

# ---------------------------------------------------------------------------
# Example #3
# ---------------------------------------------------------------------------
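
# Both validation loops (in the method above and the function below) score
# predictions with compute_iou(prediction > 0.5, label == channel, costmap).
# The package's actual implementation is not shown in this excerpt; a plausible
# sketch, assuming the costmap marks valid voxels (> 0) to include in the
# overlap counts -- the helper name is hypothetical:
def masked_iou_sketch(pred, target, costmap):
    """IoU between two boolean volumes, restricted to voxels with costmap > 0."""
    valid = costmap > 0
    pred = np.logical_and(pred, valid)
    target = np.logical_and(target, valid)
    intersection = np.count_nonzero(np.logical_and(pred, target))
    union = np.count_nonzero(np.logical_or(pred, target))
    return intersection / union if union > 0 else 0.0

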
def train(args, model):

    model.train()

    # refuse to overwrite an existing log file (unless in test mode)
    if not args.TestMode and os.path.isfile(args.LoggerName):
        raise RuntimeError(f'logger file {args.LoggerName} already exists')
    with open(args.LoggerName, 'a') as text_file:
        print('Epoch,Training_Loss,Validation_Loss', file=text_file)

    # load the correct loss function
    if args.Loss == 'NLL_CM' and args.model == 'unet_2task':
        from aicsmlsegment.custom_loss import MultiTaskElementNLLLoss
        criterion = MultiTaskElementNLLLoss(args.LossWeight, args.nclass)
        print('use 2 task elementwise NLL loss')
    elif args.Loss == 'NLL_CM' and args.model in (
            'unet_ds', 'unet_xy', 'unet_deeper_xy', 'unet_xy_d6',
            'unet_xy_p3', 'unet_xy_p2'):
        from aicsmlsegment.custom_loss import MultiAuxillaryElementNLLLoss
        criterion = MultiAuxillaryElementNLLLoss(3, args.LossWeight,
                                                 args.nclass)
        print('use unet with deep supervision loss')
    elif args.Loss == 'NLL_CM' and args.model == 'unet_xy_multi_task':
        from aicsmlsegment.custom_loss import MultiTaskElementNLLLoss
        criterion = MultiTaskElementNLLLoss(args.LossWeight, args.nclass)
        print('use 2 task elementwise NLL loss')
    else:
        # guard against an undefined criterion further down
        raise NotImplementedError(
            f'loss {args.Loss} is not supported for model {args.model}')

    # prepare the training/validation filenames
    print('preparing the data ...')
    filenames = glob(args.DataPath + '/*_GT.ome.tif')
    filenames.sort()
    total_num = len(filenames)
    if len(args.LeaveOut) == 1:
        if args.LeaveOut[0] > 0 and args.LeaveOut[0] < 1:
            # a single fraction in (0, 1): hold out that fraction at random
            num_train = int(np.floor((1 - args.LeaveOut[0]) * total_num))
            shuffled_idx = np.arange(total_num)
            random.shuffle(shuffled_idx)
            train_idx = shuffled_idx[:num_train]
            valid_idx = shuffled_idx[num_train:]
        else:
            # a single integer: hold out that one image by index
            valid_idx = [int(args.LeaveOut[0])]
            train_idx = list(
                set(range(total_num)) - set(map(int, args.LeaveOut)))
    elif args.LeaveOut:
        # a list of integers: hold out those images by index
        valid_idx = list(map(int, args.LeaveOut))
        train_idx = list(set(range(total_num)) - set(valid_idx))
    else:
        raise ValueError('LeaveOut must not be empty')

    valid_filenames = []
    train_filenames = []
    # strip the 11-character '_GT.ome.tif' suffix to get the common stem
    for fn in valid_idx:
        valid_filenames.append(filenames[fn][:-11])
    for fn in train_idx:
        train_filenames.append(filenames[fn][:-11])

    # may need a different validation method
    #validation_set_loader = DataLoader(exp_Loader(validation_filenames), num_workers=1, batch_size=1, shuffle=False)

    if args.Augmentation == 'NOAUG_M':
        from aicsmlsegment.DataLoader3D.Universal_Loader import NOAUG_M as train_loader
        print('use no augmentation, with cost map')
    elif args.Augmentation == 'RR_FH_M':
        from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M as train_loader
        print('use flip + rotation augmentation, with cost map')
    elif args.Augmentation == 'RR_FH_M0':
        from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M0 as train_loader
        print('use flip + rotation augmentation, with cost map')
    elif args.Augmentation == 'RR_FH_M0C':
        from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M0C as train_loader
        print('use flip + rotation augmentation, with cost map, '
              'and also count valid pixels')
    else:
        # guard against an undefined train_loader further down
        raise NotImplementedError(
            f'augmentation {args.Augmentation} is not supported')

    # softmax for validation
    softmax = nn.Softmax(dim=1)
    softmax.cuda()

    for epoch in range(args.NumEpochs + 1):

        if epoch % args.EpochPerBuffer == 0:
            print('shuffling training data ... ...')
            random.shuffle(train_filenames)
            train_set_loader = DataLoader(train_loader(train_filenames,
                                                       args.PatchPerBuffer,
                                                       args.size_in,
                                                       args.size_out),
                                          num_workers=args.NumWorkers,
                                          batch_size=args.BatchSize,
                                          shuffle=True)
            print('training data is ready')

        # build the optimizer for this epoch (the learning rate may change
        # between training stages)
        optimizer = None
        if len(args.lr) == 1:  # single constant learning rate
            optimizer = optim.Adam(model.parameters(),
                                   lr=args.lr[0],
                                   weight_decay=args.WeightDecay)
        elif len(args.lr) > 1:
            # staged schedule: [stage_1, lr_1, stage_2, lr_2, ..., stage_k, lr_k, lr_final]
            assert len(args.lr) % 2 == 1
            num_training_stage = (len(args.lr) - 1) // 2
            for ts in range(num_training_stage):
                if epoch < args.lr[ts * 2]:
                    # still before this stage boundary: use its learning rate
                    optimizer = optim.Adam(model.parameters(),
                                           lr=args.lr[ts * 2 + 1],
                                           weight_decay=args.WeightDecay)
                    break
            else:
                # past every stage boundary: use the final learning rate
                optimizer = optim.Adam(model.parameters(),
                                       lr=args.lr[-1],
                                       weight_decay=args.WeightDecay)
        assert optimizer is not None, 'optimizer setup failed'

        # re-open the logger file
        text_file = open(args.LoggerName, 'a')

        # Training starts ...
        epoch_loss = []
        model.train()

        for step, current_batch in tqdm(enumerate(train_set_loader)):

            inputs = Variable(current_batch[0].cuda())
            targets = current_batch[1]
            #print(inputs.size())
            #print(targets[0].size())
            outputs = model(inputs)
            #print(len(outputs))
            #print(outputs[0].size())

            if len(targets) > 1:
                for zidx in range(len(targets)):
                    targets[zidx] = Variable(targets[zidx].cuda())
            else:
                targets = Variable(targets[0].cuda())

            optimizer.zero_grad()
            if len(current_batch) == 3:  # input + target + cmap
                cmap = Variable(current_batch[2].cuda())
                loss = criterion(outputs, targets, cmap)
            else:  # input + target
                loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

            epoch_loss.append(loss.data.item())

        # Validation starts ... ("validation_loss" accumulates per-channel IoU,
        # so higher is better, despite the name)
        validation_loss = np.zeros((len(args.OutputCh) // 2, ))
        model.eval()

        for img_idx, fn in enumerate(valid_filenames):

            # target
            label_reader = AICSImage(fn + '_GT.ome.tif')  #CZYX
            label = label_reader.data
            label = np.squeeze(label, axis=0)  # 4-D after squeeze

            # When the tif has only 1 channel, the loaded array may have
            # falsely swapped dimensions (ZCYX); we want CZYX. (This may also
            # happen on a different OS or with different package versions.)
            # ASSUMPTION: there are more z slices than channels
            if label.shape[1] < label.shape[0]:
                label = np.transpose(label, (1, 0, 2, 3))

            # input image
            input_reader = AICSImage(fn + '.ome.tif')  #CZYX  #TODO: check size
            input_img = input_reader.data
            input_img = np.squeeze(input_img, axis=0)
            if input_img.shape[1] < input_img.shape[0]:
                input_img = np.transpose(input_img, (1, 0, 2, 3))

            # cmap tensor
            costmap_reader = AICSImage(fn + '_CM.ome.tif')  # ZYX
            costmap = costmap_reader.data
            costmap = np.squeeze(costmap, axis=0)
            if costmap.shape[0] == 1:
                costmap = np.squeeze(costmap, axis=0)
            elif costmap.shape[1] == 1:
                costmap = np.squeeze(costmap, axis=1)

            # output
            outputs = model_inference(model, input_img, softmax, args)

            assert len(args.OutputCh) // 2 == len(outputs)

            for vi in range(len(outputs)):
                if label.shape[0] == 1:  # the same label for all output
                    validation_loss[vi] += compute_iou(
                        outputs[vi][0, :, :, :] > 0.5,
                        label[0, :, :, :] == args.OutputCh[2 * vi + 1],
                        costmap)
                else:
                    validation_loss[vi] += compute_iou(
                        outputs[vi][0, :, :, :] > 0.5,
                        label[vi, :, :, :] == args.OutputCh[2 * vi + 1],
                        costmap)

        # print loss
        average_training_loss = sum(epoch_loss) / len(epoch_loss)
        average_validation_loss = validation_loss / len(valid_filenames)
        print(
            f'Epoch: {epoch}, Training Loss: {average_training_loss}, Validation loss: {average_validation_loss}'
        )
        print(f'{epoch},{average_training_loss},{average_validation_loss}',
              file=text_file)
        text_file.close()

        # save the model
        if args.SaveEveryKEpoch > 0 and epoch % args.SaveEveryKEpoch == 0:
            filename = f'{args.model}-{epoch:03}-{args.model_tag}.pth'
            torch.save(model.state_dict(), args.ModelDir + os.sep + filename)
            print(f'saved at epoch {epoch}')
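
# The staged learning-rate schedule above encodes
# [stage_1, lr_1, stage_2, lr_2, ..., stage_k, lr_k, lr_final]: an epoch below
# boundary stage_i uses lr_i, and an epoch past the last boundary uses
# lr_final. A self-contained sketch of the same lookup; the helper name is
# illustrative, not part of aicsmlsegment:
def lr_for_epoch_sketch(epoch, lr_spec):
    """Return the learning rate for `epoch` under the staged schedule."""
    if len(lr_spec) == 1:  # single constant learning rate
        return lr_spec[0]
    assert len(lr_spec) % 2 == 1, 'expect (boundary, lr) pairs plus lr_final'
    num_stages = (len(lr_spec) - 1) // 2
    for ts in range(num_stages):
        if epoch < lr_spec[ts * 2]:  # still before this stage boundary
            return lr_spec[ts * 2 + 1]
    return lr_spec[-1]  # past every boundary: final learning rate

# e.g. with lr_spec = [50, 1e-3, 100, 5e-4, 1e-4]:
# lr_for_epoch_sketch(10, lr_spec)  -> 1e-3
# lr_for_epoch_sketch(60, lr_spec)  -> 5e-4
# lr_for_epoch_sketch(150, lr_spec) -> 1e-4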