Example #1
    def model_compile_para(self):
        compile_para = dict()
        compile_para["optimizer"] = tf.keras.optimizers.Adam(
            learning_rate=self.learning_rate)
        compile_para["loss"] = {
            "regression": SmoothL1Loss(),
            "classification": FocalLoss()
        }
        return compile_para
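A hedged usage sketch (not part of the original snippet): the keys of the "loss" dict must match the names of the model's output heads, so a hypothetical caller could look roughly like this, assuming self.model is a Keras model with outputs named "regression" and "classification":

compile_para = self.model_compile_para()
self.model.compile(**compile_para)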
Example #2
def compute_loss(outputs, labels, loss_method='binary'):
    loss = 0.
    if loss_method == 'binary':
        labels = labels.unsqueeze(1)
        loss = F.binary_cross_entropy(torch.sigmoid(outputs), labels)
    elif loss_method == 'cross_entropy':
        loss = F.cross_entropy(outputs, labels)
    elif loss_method == 'focal_loss':
        loss = FocalLoss()(outputs, labels)
    elif loss_method == 'ghmc':
        loss = GHMC()(outputs, labels)
    return loss
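The FocalLoss referenced in this and most of the following examples is imported from each project's own code, so its exact signature varies (gamma/alpha/reduction, num_classes, class_count, ...). As a point of reference only, a minimal multi-class focal loss sketch in PyTorch that is compatible with calls such as FocalLoss()(outputs, labels) or FocalLoss(gamma=2, alpha=[4, 6, 1], reduction='sum') could look like the following; it is an illustrative assumption, not the implementation used by any of these repositories:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=None, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        # alpha may be a per-class weight list, or None for an unweighted loss
        self.alpha = None if alpha is None else torch.tensor(alpha, dtype=torch.float)
        self.reduction = reduction

    def forward(self, logits, targets):
        # logits: (N, C) raw scores, targets: (N,) class indices
        log_probs = F.log_softmax(logits, dim=-1)
        ce = F.nll_loss(log_probs, targets, reduction='none')            # -log(p_t)
        pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1).exp()  # p_t
        loss = (1.0 - pt) ** self.gamma * ce
        if self.alpha is not None:
            loss = self.alpha.to(logits.device)[targets] * loss
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss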
Example #3
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds,
                            output_attentions=output_attentions,
                            output_hidden_states=output_hidden_states,
                            return_dict=return_dict)

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                self.config.problem_type = 'single_label_classification'
            elif self.config.problem_type != 'single_label_classification':
                raise NotImplementedError(self.__doc__)

            if self.config.problem_type == 'single_label_classification':
                # loss_fct = DiceLoss()
                loss_fct = FocalLoss(gamma=2, alpha=[4, 6, 1], reduction='sum')
                loss = loss_fct(logits.view(-1, self.num_labels),
                                labels.view(-1))

        if not return_dict:
            output = (logits, ) + outputs[2:]
            return ((loss, ) + output) if loss is not None else output

        return SequenceClassifierOutput(loss=loss,
                                        logits=logits,
                                        hidden_states=outputs.hidden_states,
                                        attentions=outputs.attentions)
Example #4
def main():
    lr = 5e-4
    gamma = 0.2
    num_classes = 21
    epoch = 300
    batch_size = 1
    # data_path = '/mnt/storage/project/data/VOCdevkit/VOC2007'
    data_path = '~/datasets/VOC/VOCdevkit/VOC2007'

    # define data.
    data_set = LoadVocDataSets(data_path, 'trainval', AnnotationTransform(),
                               PreProcess(resize=(600, 600)))

    # define model
    model = RetinaNet(num_classes)

    # define criterion
    criterion = FocalLoss(num_classes)

    # define optimizer
    optimizer = optim.SGD(model.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=5e-4)

    # set iteration numbers.
    epoch_size = len(data_set) // batch_size
    max_iter = epoch_size * epoch

    train_loss = 0
    # start iteration
    for iteration in range(max_iter):
        if iteration % epoch_size == 0:
            # create batch iterator
            batch_iter = iter(
                DataLoader(data_set,
                           batch_size,
                           shuffle=True,
                           num_workers=6,
                           collate_fn=data_set.detection_collate))
        images, loc_targets, cls_targets = next(batch_iter)
        optimizer.zero_grad()
        loc_preds, cls_preds = model(images)
        loss = criterion(loc_preds, loc_targets, cls_preds, cls_targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        print('train_loss: %.3f ' % (loss.item()))
Example #5
    def compute_loss(self, logits, labels):
        if self.loss_type == "ce":
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, 2), labels.view(-1))
        elif self.loss_type == "focal":
            loss_fct = FocalLoss(gamma=self.args.focal_gamma, reduction="mean")
            loss = loss_fct(logits.view(-1, 2), labels.view(-1))
        elif self.loss_type == "dice":
            loss_fct = DiceLoss(with_logits=True,
                                smooth=self.args.dice_smooth,
                                ohem_ratio=self.args.dice_ohem,
                                alpha=self.args.dice_alpha,
                                square_denominator=self.args.dice_square,
                                reduction="mean")
            loss = loss_fct(logits.view(-1, self.num_classes), labels)
        else:
            raise ValueError

        return loss
Example #6
    def __init__(self, args, ckpt):
        super(LossFunction, self).__init__()
        ckpt.write_log('[INFO] Making loss...')

        self.nGPU = args.nGPU
        self.args = args
        self.loss = []
        for loss in args.loss.split('+'):
            weight, loss_type = loss.split('*')
            if loss_type == 'CrossEntropy':
                if args.if_labelsmooth:
                    loss_function = CrossEntropyLabelSmooth(
                        num_classes=args.num_classes)
                    ckpt.write_log('[INFO] Label Smoothing On.')
                else:
                    loss_function = nn.CrossEntropyLoss()
            elif loss_type == 'Triplet':
                loss_function = TripletLoss(args.margin)
            elif loss_type == 'GroupLoss':
                loss_function = GroupLoss(total_classes=args.num_classes,
                                          max_iter=args.T,
                                          num_anchors=args.num_anchors)
            elif loss_type == 'MSLoss':
                loss_function = MultiSimilarityLoss(margin=args.margin)
            elif loss_type == 'Focal':
                loss_function = FocalLoss(reduction='mean')
            elif loss_type == 'OSLoss':
                loss_function = OSM_CAA_Loss()
            elif loss_type == 'CenterLoss':
                loss_function = CenterLoss(num_classes=args.num_classes,
                                           feat_dim=args.feats)

            self.loss.append({
                'type': loss_type,
                'weight': float(weight),
                'function': loss_function
            })

        if len(self.loss) > 1:
            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})

        self.log = torch.Tensor()
Example #7
    def __init__(self, args, ckpt):
        super(Loss, self).__init__()
        print('[INFO] Making loss...')

        self.nGPU = args.nGPU
        self.args = args
        self.loss = []
        self.loss_module = nn.ModuleList()
        self.device = torch.device('cpu' if args.cpu else 'cuda')
        for loss in args.loss.split('+'):
            weight, loss_type = loss.split('*')
            if loss_type == 'CrossEntropy':
                loss_function = nn.CrossEntropyLoss()
            elif loss_type == 'Triplet':
                loss_function = TripletSemihardLoss(self.device, args.margin)
            elif loss_type == 'FocalLoss':
                loss_function = FocalLoss(args.num_classes)

            self.loss.append({  # a list whose elements are dicts holding the loss type, weight, and loss function
                'type': loss_type,
                'weight': float(weight),
                'function': loss_function
            })

        if len(self.loss) > 1:  # multiple losses are combined
            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})

        for l in self.loss:
            if l['function'] is not None:
                print('{:.3f} * {}'.format(l['weight'], l['type']))
                self.loss_module.append(l['function'])

        self.log = torch.Tensor()

        device = torch.device('cpu' if args.cpu else 'cuda')
        self.loss_module.to(device)

        if args.load != '': self.load(ckpt.dir, cpu=args.cpu)
        if not args.cpu and args.nGPU > 1:  # multi-GPU
            self.loss_module = nn.DataParallel(self.loss_module,
                                               range(args.nGPU))
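Examples #6 and #7 only show the constructor of the multi-loss wrapper. A hedged sketch of how such a wrapper might be evaluated follows; the real forward methods are not shown here and may treat individual loss types differently (e.g. triplet losses that consume embeddings rather than logits):

    def forward(self, outputs, labels):
        # hypothetical: accumulate weight * loss over every configured entry,
        # skipping the bookkeeping 'Total' entry whose function is None
        losses = []
        for l in self.loss:
            if l['function'] is None:
                continue
            losses.append(l['weight'] * l['function'](outputs, labels))
        return sum(losses)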
Example #8
def get_loss(config):
    """
    returns the loss function
    """

    loss = None

    if config['loss_config'] == 'multibox':
        loss = MultiBoxLoss(class_count=config['class_count'],
                            threshold=config['threshold'],
                            pos_neg_ratio=config['pos_neg_ratio'],
                            use_gpu=config['use_gpu'])

    elif config['loss_config'] == 'focal':
        loss = FocalLoss(class_count=config['class_count'],
                         threshold=config['threshold'],
                         alpha=config['focal_alpha'],
                         gamma=config['focal_gamma'],
                         use_gpu=config['use_gpu'])

    return loss
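A hypothetical call to get_loss; the config keys follow the snippet above, while the values are purely illustrative:

config = {
    'loss_config': 'focal',   # or 'multibox'
    'class_count': 21,
    'threshold': 0.5,
    'focal_alpha': 0.25,
    'focal_gamma': 2.0,
    'use_gpu': True,
}
criterion = get_loss(config)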
Example #9
                                  drop_last=True)
    val_dataloader = DataLoader(val_data,
                                batch_size=1,
                                shuffle=True,
                                num_workers=opts.n_workers,
                                pin_memory=True,
                                drop_last=False)
    print("Length of train dataloader = ", len(train_dataloader))
    print("Length of validation dataloader = ", len(val_dataloader))

    # define the model
    print("Loading model... ", opts.model, opts.model_depth)
    model, parameters = generate_model(opts)

    criterion = FocalLoss(
        alpha=[0.8, 0.4],
        gamma=2,
        criterion=nn.CrossEntropyLoss(reduction='none').cuda()).cuda()

    log_path = os.path.join(opts.result_path, opts.dataset)
    if not os.path.exists(log_path):
        os.makedirs(log_path)

    if opts.log == 1:
        if opts.resume_path1:
            begin_epoch = int(opts.resume_path1.split('/')[-1].split('_')[1])
            epoch_logger = Logger(os.path.join(
                log_path,
                '{}_train_clip{}model{}{}.log'.format(opts.dataset,
                                                      opts.sample_duration,
                                                      opts.model,
                                                      opts.model_depth)),
Example #10
def main(args, logger):
    writer = SummaryWriter(
        log_dir=os.path.join('logs', args.dataset, args.model_name, args.loss))

    train_loader, test_loader = load_data(args)
    if args.dataset == 'CIFAR10':
        num_classes = 10
    elif args.dataset == 'CIFAR100':
        num_classes = 100
    elif args.dataset == 'TINY_IMAGENET':
        num_classes = 200
    elif args.dataset == 'IMAGENET':
        num_classes = 1000

    print('Model name :: {}, Dataset :: {}, Num classes :: {}'.format(
        args.model_name, args.dataset, num_classes))
    if args.model_name == 'mixnet_s':
        model = mixnet_s(num_classes=num_classes, dataset=args.dataset)
        # model = mixnet_s(num_classes=num_classes)
    elif args.model_name == 'mixnet_m':
        model = mixnet_m(num_classes=num_classes, dataset=args.dataset)
    elif args.model_name == 'mixnet_l':
        model = mixnet_l(num_classes=num_classes, dataset=args.dataset)
    elif args.model_name == 'ghostnet':
        model = ghostnet(num_classes=num_classes)
    elif args.model_name == 'ghostmishnet':
        model = ghostmishnet(num_classes=num_classes)
    elif args.model_name == 'ghosthmishnet':
        model = ghosthmishnet(num_classes=num_classes)
    elif args.model_name == 'ghostsharkfinnet':
        model = ghostsharkfinnet(num_classes=num_classes)
    elif args.model_name == 'mobilenetv2':
        model = models.mobilenet_v2(num_classes=num_classes)
    elif args.model_name == 'mobilenetv3_s':
        model = mobilenetv3_small(num_classes=num_classes)
    elif args.model_name == 'mobilenetv3_l':
        model = mobilenetv3_large(num_classes=num_classes)
    else:
        raise NotImplementedError

    if args.pretrained_model:
        filename = 'best_model_' + str(args.dataset) + '_' + str(
            args.model_name) + '_ckpt.tar'
        print('filename :: ', filename)
        file_path = os.path.join('./checkpoint', filename)
        checkpoint = torch.load(file_path)

        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']
        best_acc1 = checkpoint['best_acc1']
        best_acc5 = checkpoint['best_acc5']
        model_parameters = checkpoint['parameters']
        print(
            'Load model, Parameters: {0}, Start_epoch: {1}, Acc1: {2}, Acc5: {3}'
            .format(model_parameters, start_epoch, best_acc1, best_acc5))
        logger.info(
            'Load model, Parameters: {0}, Start_epoch: {1}, Acc1: {2}, Acc5: {3}'
            .format(model_parameters, start_epoch, best_acc1, best_acc5))
    else:
        start_epoch = 1
        best_acc1 = 0.0
        best_acc5 = 0.0

    if args.cuda:
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model = model.cuda()

    print("Number of model parameters: ", get_model_parameters(model))
    logger.info("Number of model parameters: {0}".format(
        get_model_parameters(model)))

    if args.loss == 'ce':
        criterion = nn.CrossEntropyLoss()
    elif args.loss == 'focal':
        criterion = FocalLoss()
    else:
        raise NotImplementedError
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    # lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=0.001)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[30, 60], gamma=0.1)  #learning rate decay

    for epoch in range(start_epoch, args.epochs + 1):
        # adjust_learning_rate(optimizer, epoch, args)
        train(model, train_loader, optimizer, criterion, epoch, args, logger,
              writer)
        acc1, acc5 = eval(model, test_loader, criterion, args)
        lr_scheduler.step()

        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if is_best:
            best_acc5 = acc5

        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        filename = 'model_' + str(args.dataset) + '_' + str(
            args.model_name) + '_ckpt.tar'
        print('filename :: ', filename)

        parameters = get_model_parameters(model)

        if torch.cuda.device_count() > 1:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': args.model_name,
                    'state_dict': model.module.state_dict(),
                    'best_acc1': best_acc1,
                    'best_acc5': best_acc5,
                    'optimizer': optimizer.state_dict(),
                    'parameters': parameters,
                }, is_best, filename)
        else:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': args.model_name,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'best_acc5': best_acc5,
                    'optimizer': optimizer.state_dict(),
                    'parameters': parameters,
                }, is_best, filename)
        writer.add_scalar('Test/Acc1', acc1, epoch)
        writer.add_scalar('Test/Acc5', acc5, epoch)

        print(" Test best acc1:", best_acc1, " acc1: ", acc1, " acc5: ", acc5)
    writer.close()
Example #11
def train():
    args = parse_args()

    if args.use_tfboard:
        writer = SummaryWriter()

    # data loader
    print('load data')
    cfg.DATASET_NAME = args.dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    dataset = VOCDataset(transform=transform, train=True)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
                            collate_fn=dataset.collate_fn)

    iter_per_epoch = int(len(dataset) / args.batch_size)

    # load model
    print('load model')
    model = RetinaNet()
    model.load_state_dict(torch.load('./pretrained_model/model.pth'))
    model.freeze_bn()
    if args.use_GPU:
        model = model.cuda()
    if args.mGPUs:
        model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    model.train()
    # criterion
    criterion = FocalLoss()
    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=1e-5)

    print('start training')
    for epoch in range(args.epochs):
        train_data_iter = iter(dataloader)
        train_loss = 0
        fg, tp = 0, 0
        for step in range(iter_per_epoch):
            im_data, cls_targets, loc_targets, im_sizes = next(train_data_iter)
            if args.use_GPU:
                im_data = im_data.cuda()
                cls_targets = cls_targets.cuda()
                loc_targets = loc_targets.cuda()
            im_data = Variable(im_data)
            cls_targets = Variable(cls_targets)
            loc_targets = Variable(loc_targets)

            cls_preds, loc_preds = model(im_data)
            cls_loss, loc_loss = criterion(cls_preds, cls_targets, loc_preds, loc_targets)

            loss = cls_loss + loc_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

            # calculate classification acc
            cls_t = cls_targets.clone()
            cls_t = cls_t.view(-1, cfg.CLASS_NUM)
            cls_max, cls_argmax = torch.max(cls_t, dim=-1)
            fg_inds = torch.eq(cls_max, 1.)
            cls_p = cls_preds.clone()
            cls_p = cls_p.view(-1, cfg.CLASS_NUM)
            pred_info = torch.argmax(cls_p, dim=-1)
            tp += torch.sum(pred_info[fg_inds] == cls_argmax[fg_inds])
            fg += fg_inds.sum()

            if (step + 1) % args.display_interval == 0:
                train_loss /= args.display_interval
                print('[%d epoch | %d step]cls_loss: %.3f | loc_loss: %.3f | avg_loss: %.3f | cls_acc: %.3f' %
                    (epoch, step, cls_loss.item(), loc_loss.item(), train_loss, float(tp)/float(fg)))
                if args.use_tfboard:
                    n_iter = epoch * iter_per_epoch + step + 1
                    writer.add_scalar('losses/loss', train_loss, n_iter)
                    writer.add_scalar('losses/cls_loss', cls_loss.item(), n_iter)
                    writer.add_scalar('losses/loc_loss', loc_loss.item(), n_iter)
                    writer.add_scalar('acc/cls_acc', float(tp) / float(fg), n_iter)
                train_loss = 0
                fg, tp = 0, 0

        if not os.path.exists(args.output_dir):
            os.mkdir(args.output_dir)
        if (epoch+1) % args.save_interval == 0:
            print('saving model')
            save_name = os.path.join(args.output_dir, 'retinanet_epoch_{}.pth'.format(epoch + 1))
            torch.save({
                'model': model.state_dict(),
                'epoch': epoch,
            }, save_name)
Example #12
def train(train_config_file):
    """ Medical image segmentation training engine
    :param train_config_file: the input configuration file
    :return: None
    """
    assert os.path.isfile(train_config_file), 'Config not found: {}'.format(
        train_config_file)

    # load config file
    train_cfg = load_config(train_config_file)

    # clean the existing folder if training from scratch
    model_folder = os.path.join(train_cfg.general.save_dir,
                                train_cfg.general.model_scale)
    if os.path.isdir(model_folder):
        if train_cfg.general.resume_epoch < 0:
            shutil.rmtree(model_folder)
            os.makedirs(model_folder)
    else:
        os.makedirs(model_folder)

    # copy training and inference config files to the model folder
    shutil.copy(train_config_file, os.path.join(model_folder,
                                                'train_config.py'))
    infer_config_file = os.path.join(
        os.path.join(os.path.dirname(__file__), 'config', 'infer_config.py'))
    shutil.copy(infer_config_file,
                os.path.join(train_cfg.general.save_dir, 'infer_config.py'))

    # enable logging
    log_file = os.path.join(model_folder, 'train_log.txt')
    logger = setup_logger(log_file, 'seg3d')

    # control randomness during training
    np.random.seed(train_cfg.general.seed)
    torch.manual_seed(train_cfg.general.seed)
    if train_cfg.general.num_gpus > 0:
        torch.cuda.manual_seed(train_cfg.general.seed)

    # dataset
    train_dataset = SegmentationDataset(
        mode='train',
        im_list=train_cfg.general.train_im_list,
        num_classes=train_cfg.dataset.num_classes,
        spacing=train_cfg.dataset.spacing,
        crop_size=train_cfg.dataset.crop_size,
        sampling_method=train_cfg.dataset.sampling_method,
        random_translation=train_cfg.dataset.random_translation,
        random_scale=train_cfg.dataset.random_scale,
        interpolation=train_cfg.dataset.interpolation,
        crop_normalizers=train_cfg.dataset.crop_normalizers)
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=train_cfg.train.batchsize,
                                   num_workers=train_cfg.train.num_threads,
                                   pin_memory=True,
                                   shuffle=True)

    val_dataset = SegmentationDataset(
        mode='val',
        im_list=train_cfg.general.val_im_list,
        num_classes=train_cfg.dataset.num_classes,
        spacing=train_cfg.dataset.spacing,
        crop_size=train_cfg.dataset.crop_size,
        sampling_method=train_cfg.dataset.sampling_method,
        random_translation=train_cfg.dataset.random_translation,
        random_scale=train_cfg.dataset.random_scale,
        interpolation=train_cfg.dataset.interpolation,
        crop_normalizers=train_cfg.dataset.crop_normalizers)
    val_data_loader = DataLoader(val_dataset,
                                 batch_size=1,
                                 num_workers=1,
                                 shuffle=False)

    # define network
    net = GlobalLocalNetwork(train_dataset.num_modality(),
                             train_cfg.dataset.num_classes)
    net.apply(kaiming_weight_init)
    max_stride = net.max_stride()

    if train_cfg.general.num_gpus > 0:
        net = nn.parallel.DataParallel(net,
                                       device_ids=list(
                                           range(train_cfg.general.num_gpus)))
        net = net.cuda()

    assert np.all(np.array(train_cfg.dataset.crop_size) %
                  max_stride == 0), 'crop size not divisible by max stride'

    # training optimizer
    opt = optim.Adam(net.parameters(),
                     lr=train_cfg.train.lr,
                     betas=train_cfg.train.betas)

    # load checkpoint if resume epoch > 0
    if train_cfg.general.resume_epoch >= 0:
        last_save_epoch = load_checkpoint(train_cfg.general.resume_epoch, net,
                                          opt, model_folder)
    else:
        last_save_epoch = 0

    if train_cfg.loss.name == 'Focal':
        # the focal loss is the only supported loss here
        loss_func = FocalLoss(class_num=train_cfg.dataset.num_classes,
                              alpha=train_cfg.loss.obj_weight,
                              gamma=train_cfg.loss.focal_gamma,
                              use_gpu=train_cfg.general.num_gpus > 0)
    else:
        raise ValueError('Unknown loss function')

    writer = SummaryWriter(os.path.join(model_folder, 'tensorboard'))

    max_avg_dice = 0
    for epoch_idx in range(1, train_cfg.train.epochs + 1):
        train_one_epoch(net, train_cfg.loss.branch_weight, opt,
                        train_data_loader, train_cfg.dataset.down_sample_ratio,
                        loss_func, train_cfg.general.num_gpus,
                        epoch_idx + last_save_epoch, logger, writer,
                        train_cfg.train.print_freq,
                        train_cfg.debug.save_inputs,
                        os.path.join(model_folder, 'debug'))

        # evaluation
        if epoch_idx % train_cfg.train.save_epochs == 0:
            avg_dice = evaluate_one_epoch(
                net, val_data_loader, train_cfg.dataset.crop_size,
                train_cfg.dataset.down_sample_ratio,
                train_cfg.dataset.crop_normalizers[0], Metrics(),
                [idx for idx in range(1, train_cfg.dataset.num_classes)],
                train_cfg.loss.branch_type)

            if max_avg_dice < avg_dice:
                max_avg_dice = avg_dice
                save_checkpoint(net, opt, epoch_idx, train_cfg, max_stride, 1)
                msg = 'epoch: {}, best dice ratio: {}'

            else:
                msg = 'epoch: {},  dice ratio: {}'

            msg = msg.format(epoch_idx, avg_dice)
            logger.info(msg)
Example #13
def run_train():
    assert torch.cuda.is_available(), 'Error: CUDA not found!'
    start_epoch = 0  # start from epoch 0 or last epoch

    # Data
    print('Load ListDataset')
    transform = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    trainset = ListDataset(img_dir=config.img_dir,
                           list_filename=config.train_list_filename,
                           label_map_filename=config.label_map_filename,
                           train=True,
                           transform=transform,
                           input_size=config.img_res)
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=8,
        collate_fn=trainset.collate_fn)

    testset = ListDataset(img_dir=config.img_dir,
                          list_filename=config.test_list_filename,
                          label_map_filename=config.label_map_filename,
                          train=False,
                          transform=transform,
                          input_size=config.img_res)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=config.test_batch_size,
                                             shuffle=False,
                                             num_workers=8,
                                             collate_fn=testset.collate_fn)

    # Model
    net = RetinaNet()

    if os.path.exists(config.checkpoint_filename):
        print('Load saved checkpoint: {}'.format(config.checkpoint_filename))
        checkpoint = torch.load(config.checkpoint_filename)
        net.load_state_dict(checkpoint['net'])
        best_loss = checkpoint['loss']
        start_epoch = checkpoint['epoch']
    else:
        print('Load pretrained model: {}'.format(config.pretrained_filename))
        if not os.path.exists(config.pretrained_filename):
            import_pretrained_resnet()
        net.load_state_dict(torch.load(config.pretrained_filename))

    net = torch.nn.DataParallel(net,
                                device_ids=range(torch.cuda.device_count()))
    net.cuda()

    criterion = FocalLoss()
    optimizer = optim.SGD(net.parameters(),
                          lr=1e-3,
                          momentum=0.9,
                          weight_decay=1e-4)

    # Training
    def train(epoch):
        print('\nEpoch: %d' % epoch)
        net.train()
        net.module.freeze_bn()
        train_loss = 0

        total_batches = int(
            math.ceil(trainloader.dataset.num_samples /
                      trainloader.batch_size))

        for batch_idx, targets in enumerate(trainloader):
            inputs = targets[0]
            loc_targets = targets[1]
            cls_targets = targets[2]

            inputs = inputs.cuda()
            loc_targets = loc_targets.cuda()
            cls_targets = cls_targets.cuda()

            optimizer.zero_grad()
            loc_preds, cls_preds = net(inputs)
            loss = criterion(loc_preds, loc_targets, cls_preds, cls_targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.data
            print('[%d| %d/%d] loss: %.3f | avg: %.3f' %
                  (epoch, batch_idx, total_batches, loss.data, train_loss /
                   (batch_idx + 1)))

    # Test
    def test(epoch):
        print('\nTest')
        net.eval()
        test_loss = 0

        total_batches = int(
            math.ceil(testloader.dataset.num_samples / testloader.batch_size))

        for batch_idx, targets in enumerate(testloader):
            inputs = targets[0]
            loc_targets = targets[1]
            cls_targets = targets[2]

            inputs = inputs.cuda()
            loc_targets = loc_targets.cuda()
            cls_targets = cls_targets.cuda()

            loc_preds, cls_preds = net(inputs)
            loss = criterion(loc_preds, loc_targets, cls_preds, cls_targets)
            test_loss += loss.data
            print('[%d| %d/%d] loss: %.3f | avg: %.3f' %
                  (epoch, batch_idx, total_batches, loss.data, test_loss /
                   (batch_idx + 1)))

        # Save checkpoint
        global best_loss  # assumed to be initialized at module scope, e.g. best_loss = float('inf')
        test_loss /= len(testloader)
        if test_loss < best_loss:
            print('Save checkpoint: {}'.format(config.checkpoint_filename))
            state = {
                'net': net.module.state_dict(),
                'loss': test_loss,
                'epoch': epoch,
            }
            if not os.path.exists(os.path.dirname(config.checkpoint_filename)):
                os.makedirs(os.path.dirname(config.checkpoint_filename))
            torch.save(state, config.checkpoint_filename)
            best_loss = test_loss

    for epoch in range(start_epoch, start_epoch + 1000):
        train(epoch)
        test(epoch)
Example #14
    def compute_loss(self,
                     start_logits,
                     end_logits,
                     span_logits,
                     start_labels,
                     end_labels,
                     match_labels,
                     start_label_mask,
                     end_label_mask,
                     answerable_cls_logits=None,
                     answerable_cls_labels=None):
        batch_size, seq_len = start_logits.size()[0], start_logits.size()[1]
        start_float_label_mask = start_label_mask.view(-1).float()
        end_float_label_mask = end_label_mask.view(-1).float()
        match_label_row_mask = start_label_mask.bool().unsqueeze(-1).expand(
            -1, -1, seq_len)
        match_label_col_mask = end_label_mask.bool().unsqueeze(-2).expand(
            -1, seq_len, -1)
        match_label_mask = match_label_row_mask & match_label_col_mask
        # torch.triu -> returns the upper triangular part of a matrix or batch of matrices;
        # the other elements of the result tensor are set to 0.
        # a named entity's start position should be less than or equal to its end position.
        match_label_mask = torch.triu(match_label_mask,
                                      0)  # start should be less than or equal to end

        if self.args.span_loss_candidates == "all":
            # naive mask
            float_match_label_mask = match_label_mask.view(batch_size,
                                                           -1).float()
        else:
            # use only pred or golden start/end to compute match loss
            logits_size = start_logits.shape[-1]
            if logits_size == 1:
                start_preds, end_preds = start_logits > 0, end_logits > 0
                start_preds, end_preds = torch.squeeze(
                    start_preds, dim=-1), torch.squeeze(end_preds, dim=-1)
            elif logits_size == 2:
                start_preds, end_preds = torch.argmax(
                    start_logits, dim=-1), torch.argmax(end_logits, dim=-1)
            else:
                raise ValueError

            if self.args.span_loss_candidates == "gold":
                match_candidates = (
                    (start_labels.unsqueeze(-1).expand(-1, -1, seq_len) > 0)
                    & (end_labels.unsqueeze(-2).expand(-1, seq_len, -1) > 0))
            elif self.args.span_loss_candidates == "gold_random":
                gold_matrix = (
                    (start_labels.unsqueeze(-1).expand(-1, -1, seq_len) > 0)
                    & (end_labels.unsqueeze(-2).expand(-1, seq_len, -1) > 0))
                data_generator = torch.Generator()
                data_generator.manual_seed(self.args.seed)
                random_matrix = torch.empty(batch_size, seq_len,
                                            seq_len).uniform_(0, 1)
                random_matrix = torch.bernoulli(
                    random_matrix, generator=data_generator).long()
                random_matrix = random_matrix.cuda()
                match_candidates = torch.logical_or(gold_matrix, random_matrix)
            elif self.args.span_loss_candidates == "gold_pred":
                match_candidates = torch.logical_or(
                    (start_preds.unsqueeze(-1).expand(-1, -1, seq_len)
                     & end_preds.unsqueeze(-2).expand(-1, seq_len, -1)),
                    (start_labels.unsqueeze(-1).expand(-1, -1, seq_len)
                     & end_labels.unsqueeze(-2).expand(-1, seq_len, -1)))
            elif self.args.span_loss_candidates == "gold_pred_random":
                gold_and_pred = torch.logical_or(
                    (start_preds.unsqueeze(-1).expand(-1, -1, seq_len)
                     & end_preds.unsqueeze(-2).expand(-1, seq_len, -1)),
                    (start_labels.unsqueeze(-1).expand(-1, -1, seq_len)
                     & end_labels.unsqueeze(-2).expand(-1, seq_len, -1)))
                data_generator = torch.Generator()
                data_generator.manual_seed(self.args.seed)
                random_matrix = torch.empty(batch_size, seq_len,
                                            seq_len).uniform_(0, 1)
                random_matrix = torch.bernoulli(
                    random_matrix, generator=data_generator).long()
                random_matrix = random_matrix.cuda()
                match_candidates = torch.logical_or(gold_and_pred,
                                                    random_matrix)
            else:
                raise ValueError
            match_label_mask = match_label_mask & match_candidates
            float_match_label_mask = match_label_mask.view(batch_size,
                                                           -1).float()

        if self.loss_type == "bce":
            start_end_logits_size = start_logits.shape[-1]
            if start_end_logits_size == 1:
                loss_fct = BCEWithLogitsLoss(reduction="none")
                start_loss = loss_fct(start_logits.view(-1),
                                      start_labels.view(-1).float())
                start_loss = (start_loss * start_float_label_mask
                              ).sum() / start_float_label_mask.sum()
                end_loss = loss_fct(end_logits.view(-1),
                                    end_labels.view(-1).float())
                end_loss = (end_loss * end_float_label_mask
                            ).sum() / end_float_label_mask.sum()
            elif start_end_logits_size == 2:
                loss_fct = CrossEntropyLoss(reduction='none')
                start_loss = loss_fct(start_logits.view(-1, 2),
                                      start_labels.view(-1))
                start_loss = (start_loss * start_float_label_mask
                              ).sum() / start_float_label_mask.sum()
                end_loss = loss_fct(end_logits.view(-1, 2),
                                    end_labels.view(-1))
                end_loss = (end_loss * end_float_label_mask
                            ).sum() / end_float_label_mask.sum()
            else:
                raise ValueError

            if span_logits is not None:
                loss_fct = BCEWithLogitsLoss(reduction="mean")
                select_span_logits = torch.masked_select(
                    span_logits.view(-1),
                    match_label_mask.view(-1).bool())
                select_span_labels = torch.masked_select(
                    match_labels.view(-1),
                    match_label_mask.view(-1).bool())
                match_loss = loss_fct(select_span_logits.view(-1, 1),
                                      select_span_labels.float().view(-1, 1))
            else:
                match_loss = None

            if answerable_cls_logits is not None:
                loss_fct = BCEWithLogitsLoss(reduction="mean")
                answerable_loss = loss_fct(
                    answerable_cls_logits.view(-1, 1),
                    answerable_cls_labels.float().view(-1, 1))
            else:
                answerable_loss = None

        elif self.loss_type in ["dice", "adaptive_dice"]:
            # compute span loss
            loss_fct = DiceLoss(with_logits=True,
                                smooth=self.args.dice_smooth,
                                ohem_ratio=self.args.dice_ohem,
                                alpha=self.args.dice_alpha,
                                square_denominator=self.args.dice_square,
                                reduction="mean",
                                index_label_position=False)
            start_end_logits_size = start_logits.shape[-1]
            start_loss = loss_fct(
                start_logits.view(-1, start_end_logits_size),
                start_labels.view(-1, 1),
            )
            end_loss = loss_fct(
                end_logits.view(-1, start_end_logits_size),
                end_labels.view(-1, 1),
            )

            if span_logits is not None:
                select_span_logits = torch.masked_select(
                    span_logits.view(-1),
                    match_label_mask.view(-1).bool())
                select_span_labels = torch.masked_select(
                    match_labels.view(-1),
                    match_label_mask.view(-1).bool())
                match_loss = loss_fct(
                    select_span_logits.view(-1, 1),
                    select_span_labels.view(-1, 1),
                )
            else:
                match_loss = None

            if answerable_cls_logits is not None:
                answerable_loss = loss_fct(answerable_cls_logits.view(-1, 1),
                                           answerable_cls_labels.view(-1, 1))
            else:
                answerable_loss = None

        else:
            loss_fct = FocalLoss(gamma=self.args.focal_gamma, reduction="none")
            start_loss = loss_fct(
                FocalLoss.convert_binary_pred_to_two_dimension(
                    start_logits.view(-1)), start_labels.view(-1))
            start_loss = (start_loss * start_float_label_mask
                          ).sum() / start_float_label_mask.sum()
            end_loss = loss_fct(
                FocalLoss.convert_binary_pred_to_two_dimension(
                    end_logits.view(-1)), end_labels.view(-1))
            end_loss = (end_loss * end_float_label_mask
                        ).sum() / end_float_label_mask.sum()
            if answerable_cls_logits is not None:
                answerable_loss = loss_fct(
                    FocalLoss.convert_binary_pred_to_two_dimension(
                        answerable_cls_logits.view(-1)),
                    answerable_cls_labels.view(-1))
                answerable_loss = answerable_loss.mean()
            else:
                answerable_loss = None

            if span_logits is not None:
                match_loss = loss_fct(
                    FocalLoss.convert_binary_pred_to_two_dimension(
                        span_logits.view(-1)), match_labels.view(-1))
                match_loss = match_loss * float_match_label_mask.view(-1)
                match_loss = match_loss.sum() / (float_match_label_mask.sum() +
                                                 1e-10)
            else:
                match_loss = None

        if answerable_loss is not None:
            return start_loss, end_loss, match_loss, answerable_loss
        return start_loss, end_loss, match_loss
Example #15
def train():
    args = parse_args(base_dir, model_dir, total_epochs, batch_size, lr)

    trainset = VOCDataset(base_dir=args.base_dir,
                          split="train",
                          transform=transforms.Compose([
                              RandomScaleCrop(550, 512),
                              RandomHorizontalFlip(),
                              Normalize(mean=(0.485, 0.456, 0.406),
                                        std=(0.229, 0.224, 0.225)),
                              ToTensor()
                          ]))
    trainloader = DataLoader(trainset,
                             batch_size=args.batch,
                             shuffle=True,
                             num_workers=4)

    print("starting loading the net and model")
    # net = Res34Unet(3, 21)
    net = PAN(3, 21)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.cuda:
        args.gpus = [int(x) for x in args.gpus.split(",")]
        net = nn.DataParallel(net, device_ids=args.gpus)

    net.to(device)

    # loss and optimizer
    criterion = FocalLoss()
    # criterion = MultiLovaszLoss()
    optimizer = optim.SGD(net.parameters(), lr=float(args.lr), momentum=0.9)
    # scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=5,
    #                                 min_lr=0.00001)
    scheduler = StepLR(optimizer, step_size=40, gamma=0.1)

    start_epoch = 1
    # Resuming training
    if args.resume_from is not None:
        if not os.path.isfile(args.resume_from):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume_from))
        checkpoint = torch.load(args.resume_from)
        #args.start_epoch = checkpoint['epoch']
        if args.cuda:
            net.module.load_state_dict(checkpoint['model_state_dict'])
        else:
            net.load_state_dict(checkpoint['model_state_dict'])
        print("resuming training from {}, epoch:{}"\
         .format(args.resume_from, checkpoint['epoch']))
        start_epoch = checkpoint['epoch'] + 1

        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    print("finishing loading the net and model")

    print("start training")
    for epoch in range(start_epoch, args.epoch + start_epoch):
        scheduler.step()
        epoch_loss = 0.0
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs
            inputs, labels = data["image"].to(device), data["mask"].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            epoch_loss += loss.item()
            if i % 10 == 9:  # print every 10 mini-batches
                print("epoch %2d, [%5d / %5d], lr: %5g, loss: %.3f" %
                      (epoch, (i + 1) * args.batch, len(trainset),
                       scheduler.get_lr()[0], running_loss / 10))
                # print("epoch %2d, [%5d / %5d], loss: %.5f" %
                #     (epoch, (i + 1) * args.batch, len(trainset), running_loss / 10))
                running_loss = 0.0

        # scheduler.step(epoch_loss / math.ceil(len(trainset) / args.batch))

        # save model
        torch.save(
            {
                "epoch":
                epoch,
                "model_state_dict":
                net.module.state_dict() if args.cuda else net.state_dict(),
                "optimizer_state_dict":
                optimizer.state_dict(),
                # "lr": scheduler.get_lr()[0]
            },
            os.path.join(model_dir, "epoch_{}.pth".format(epoch)))
    print("Finished training")
Example #16
def train(args, tasks_archive, model):
    torch.backends.cudnn.benchmark = True

    if args.resume_ckp != '':
        logger.info('==> loading checkpoint: {}'.format(args.resume_ckp))
        checkpoint = torch.load(args.resume_ckp)

    model = nn.parallel.DataParallel(model)

    logger.info('  + model num_params: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    if config.use_gpu:
        model.cuda()  # required before the optimizer?
    #     cudnn.benchmark = True

    print(model)  # especially useful for debugging model structure.
    # summary(model, input_size=tuple([config.num_modality]+config.patch_size)) # takes some time; comment out during debugging. Outputs each layer's output shape.
    # for name, m in model.named_modules():
    #     logger.info('module name:{}'.format(name))
    #     print(m)

    # lr
    lr = config.base_lr
    if args.resume_ckp != '':
        optimizer = checkpoint['optimizer']
    else:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=lr,
                                     weight_decay=config.weight_decay)  #

    # loss
    dice_loss = MulticlassDiceLoss()
    ce_loss = nn.CrossEntropyLoss()
    focal_loss = FocalLoss(gamma=2)

    # prep data
    tasks = args.tasks  # list
    tb_loaders = list()  # train batch loader
    len_loader = list()
    for task in tasks:
        tb_loader = tb_load(task)
        tb_loader.enQueue(tasks_archive[task]['fold' + str(args.fold)],
                          config.patch_size)
        tb_loaders.append(tb_loader)
        len_loader.append(len(tb_loader))
    min_len_loader = np.min(len_loader)

    # init train values
    if args.resume_ckp != '':
        trLoss_queue = checkpoint['trLoss_queue']
        last_trLoss_ma = checkpoint['last_trLoss_ma']
    else:
        trLoss_queue = deque(
            maxlen=config.trLoss_win
        )  # queue storing the total training loss of the last N epochs (used for a moving average)
        last_trLoss_ma = None  # moving average from the previous window
    trLoss_queue_list = [
        deque(maxlen=config.trLoss_win) for i in range(len(tasks))
    ]
    last_trLoss_ma_list = [None] * len(tasks)
    trLoss_ma_list = [None] * len(tasks)

    if args.resume_epoch > 0:
        start_epoch = args.resume_epoch + 1
        iterations = args.resume_epoch * config.step_per_epoch + 1
    else:
        start_epoch = 1
        iterations = 1
    logger.info('start epoch: {}'.format(start_epoch))

    ## run train
    for epoch in range(start_epoch, config.max_epoch + 1):
        logger.info('    ----- training epoch {} -----'.format(epoch))
        epoch_st_time = time.time()
        model.train()
        loss_epoch = 0.0
        loss_epoch_list = [0] * len(tasks)
        num_batch_processed = 0  # growing
        num_batch_processed_list = [0] * len(tasks)

        for step in tqdm(range(config.step_per_epoch),
                         desc='{}: epoch{}'.format(args.trainMode, epoch)):
            config.step = iterations
            config.task_idx = (iterations - 1) % len(tasks)
            config.task = tasks[config.task_idx]
            # import ipdb; ipdb.set_trace()

            # tb show lr
            config.writer.add_scalar('data/lr', lr, iterations - 1)

            st_time = time.time()
            for idx in range(len(tasks)):
                tb_loaders[idx].check_process()
            # import ipdb; ipdb.set_trace()
            (batchImg, batchLabel, batchWeight,
             batchAugs) = tb_loaders[config.task_idx].gen_batch(
                 config.batch_size, config.patch_size)
            # logger.info('idx{}_{}, gen_batch time elapsed:{}'.format(config.task_idx, config.task, tinies.timer(st_time, time.time())))

            st_time = time.time()
            batchImg = torch.from_numpy(batchImg).float(
            )  # change all inputs to same torch tensor type
            batchLabel = torch.from_numpy(batchLabel).float()
            batchWeight = torch.from_numpy(batchWeight).float()

            if config.use_gpu:
                batchImg = batchImg.cuda()
                batchLabel = batchLabel.cuda()
                batchWeight = batchWeight.cuda()
            # logger.info('idx{}_{}, .cuda time elapsed:{}'.format(config.task_idx, config.task, tinies.timer(st_time, time.time())))

            optimizer.zero_grad()

            st_time = time.time()
            if config.trainMode in ["universal"]:
                output, share_map, para_map = model(batchImg)
            else:
                output = model(batchImg)
            # logger.info('idx{}_{}, model() time elapsed:{}'.format(config.task_idx, config.task, tinies.timer(st_time, time.time())))

            st_time = time.time()
            # tensorboard visualization of training
            for i in range(len(tasks)):
                if iterations > 200 and iterations % 1000 == i:
                    tb_images([
                        batchImg[0, 0, ...], batchLabel[0, ...],
                        torch.argmax(output[0, ...], dim=0)
                    ], [False, True, True], ['image', 'GT', 'PS'],
                              iterations,
                              tag='Train_idx{}_{}_batch{}_{}'.format(
                                  config.task_idx, config.task, 0,
                                  '_'.join(batchAugs[0])))

                    tb_images([
                        batchImg[config.batch_size - 1, 0, ...],
                        batchLabel[config.batch_size - 1, ...],
                        torch.argmax(output[config.batch_size - 1, ...], dim=0)
                    ], [False, True, True], ['image', 'GT', 'PS'],
                              iterations,
                              tag='Train_idx{}_{}_batch{}_{}_step{}'.format(
                                  config.task_idx, config.task,
                                  config.batch_size - 1,
                                  '_'.join(batchAugs[config.batch_size - 1]),
                                  iterations - 1))
                    if config.trainMode == "universal":
                        logger.info(
                            'share_map shape:{}, para_map shape:{}'.format(
                                str(share_map.shape), str(para_map.shape)))
                        tb_images([
                            para_map[0, :, 64, ...], share_map[0, :, 64, ...]
                        ], [False, False], ['last_para_map', 'last_share_map'],
                                  iterations,
                                  tag='Train_idx{}_{}_para_share_maps_channels'
                                  .format(config.task_idx, config.task))

            logger.info(
                '----- {}, train epoch {} time elapsed:{} -----'.format(
                    config.task, epoch, tinies.timer(epoch_st_time,
                                                     time.time())))

            st_time = time.time()

            output_softmax = F.softmax(output, dim=1)

            loss = lovasz_softmax(output_softmax, batchLabel,
                                  ignore=10) + focal_loss(output, batchLabel)

            loss.backward()
            optimizer.step()

            # logger.info('idx{}_{}, backward time elapsed:{}'.format(config.task_idx, config.task, tinies.timer(st_time, time.time())))

            # loss.data.item()
            config.writer.add_scalar('data/loss_step', loss.item(), iterations)
            config.writer.add_scalar(
                'data/loss_step_idx{}_{}'.format(config.task_idx, config.task),
                loss.item(), iterations)

            loss_epoch += loss.item()
            num_batch_processed += 1

            loss_epoch_list[config.task_idx] += loss.item()
            num_batch_processed_list[config.task_idx] += 1

            iterations += 1

        # import ipdb; ipdb.set_trace()
        if epoch % config.save_epoch == 0:
            ckp_path = os.path.join(
                config.log_dir,
                '{}_{}_epoch{}_{}.pth.tar'.format(args.trainMode,
                                                  '_'.join(args.tasks), epoch,
                                                  tinies.datestr()))
            torch.save(
                {
                    'epoch': epoch,
                    'model': model,
                    'model_state_dict': model.state_dict(),
                    'optimizer': optimizer,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,
                    'trLoss_queue': trLoss_queue,
                    'last_trLoss_ma': last_trLoss_ma
                }, ckp_path)

        loss_epoch /= num_batch_processed

        config.writer.add_scalar('data/loss_epoch', loss_epoch, iterations - 1)
        for idx in range(len(tasks)):
            task = tasks[idx]
            loss_epoch_list[idx] /= num_batch_processed_list[idx]
            config.writer.add_scalar(
                'data/loss_epoch_idx{}_{}'.format(idx, task),
                loss_epoch_list[idx], iterations - 1)
        # import ipdb; ipdb.set_trace()

        ### lr decay
        trLoss_queue.append(loss_epoch)
        trLoss_ma = np.asarray(trLoss_queue).mean()  # simple moving average (an exponential moving average could also be used)
        config.writer.add_scalar('data/trLoss_ma', trLoss_ma, iterations - 1)

        for idx in range(len(tasks)):
            task = tasks[idx]
            trLoss_queue_list[idx].append(loss_epoch_list[idx])
            trLoss_ma_list[idx] = np.asarray(trLoss_queue_list[idx]).mean()  # simple moving average (an exponential moving average could also be used)
            config.writer.add_scalar(
                'data/trLoss_ma_idx{}_{}'.format(idx, task),
                trLoss_ma_list[idx], iterations - 1)

        # import ipdb; ipdb.set_trace()
        #### online eval
        Eval_bool = False
        if epoch >= config.start_val_epoch and epoch % config.val_epoch == 0:
            Eval_bool = True
        elif lr < 1e-8:
            Eval_bool = True
            logger.info(
                'lr is reduced to {}. Will do the last evaluation for all samples!'
                .format(lr))

        # if epoch >= config.start_val_epoch and epoch % config.val_epoch == 0:
        if Eval_bool:
            eval(args, tasks_archive, model, epoch, iterations - 1)

        ## stop if lr is too low
        if lr < 1e-8:
            logger.info('lr is reduced to {}. Job Done!'.format(lr))
            break

        ###### lr decay based on current task
        if len(trLoss_queue) == trLoss_queue.maxlen:
            if last_trLoss_ma and last_trLoss_ma - trLoss_ma < 1e-4:  # 5e-3
                lr /= 2
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
            last_trLoss_ma = trLoss_ma

        ## save model when lr < 1e-8
        if lr < 1e-8:
            ckp_path = os.path.join(
                config.log_dir,
                '{}_{}_epoch{}_{}.pth.tar'.format(args.trainMode,
                                                  '_'.join(args.tasks), epoch,
                                                  tinies.datestr()))
            torch.save(
                {
                    'epoch': epoch,
                    'model': model,
                    'model_state_dict': model.state_dict(),
                    'optimizer': optimizer,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,
                    'trLoss_queue': trLoss_queue,
                    'last_trLoss_ma': last_trLoss_ma
                }, ckp_path)
Example #17
    def __init__(self, alpha=10, weight=None):
        super(CombinedLoss, self).__init__()
        self.alpha = alpha
        self.dice_loss = MultiDiceLoss(weight)
        self.focal_loss = FocalLoss(weight)
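The constructor above only wires up the two sub-losses; the forward pass is not shown. One plausible (assumed) combination, in which alpha weights the focal term against the dice term, would be:

    def forward(self, inputs, targets):
        # hypothetical combination of the two configured sub-losses
        dice = self.dice_loss(inputs, targets)
        focal = self.focal_loss(inputs, targets)
        return dice + self.alpha * focal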