Example #1
def main(args):
    heads = {
        'hm': 1,  # 1-channel probability heat map.
        'wh': 3   # 3-channel (x, y, z) size regression.
    }
    model = get_large_hourglass_net(heads, n_stacks=1, debug=True)

    trainset = AbusNpyFormat(root=root)
    trainset_loader = DataLoader(trainset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=0)

    crit_hm = FocalLoss()
    crit_reg = RegL1Loss()
    crit_wh = crit_reg

    for batch_idx, (data_img, data_hm, data_wh,
                    _) in enumerate(trainset_loader):
        if use_cuda:
            data_img = data_img.cuda()
            data_hm = data_hm.cuda()
            data_wh = data_wh.cuda()
            model.to(device)

        output = model(data_img)

        wh_pred = torch.abs(output[-1]['wh'])
        hm_loss = crit_hm(output[-1]['hm'], data_hm)
        wh_loss = 100 * crit_wh(wh_pred, data_wh)

        print("hm_loss: %.3f, wh_loss: %.3f" \
                % (hm_loss.item(), wh_loss.item()))
        return
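The `FocalLoss` and `RegL1Loss` criteria used above are project classes that are not reproduced on this page. As a rough reference, a minimal sketch of what such CenterNet-style criteria commonly compute is shown below: a penalty-reduced pixel-wise focal loss on the heat map, and an L1 loss on the dense size map masked to labelled voxels. The actual implementations in the source repository may differ.

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Sketch: penalty-reduced pixel-wise focal loss (CenterNet-style)."""
    def __init__(self, alpha=2, beta=4):
        super().__init__()
        self.alpha, self.beta = alpha, beta

    def forward(self, pred, gt):
        pred = pred.clamp(1e-6, 1 - 1e-6)      # predicted heat map in (0, 1)
        pos = gt.eq(1).float()                 # peaks of the Gaussian targets
        neg = 1.0 - pos
        pos_loss = torch.log(pred) * (1 - pred) ** self.alpha * pos
        neg_loss = torch.log(1 - pred) * pred ** self.alpha * (1 - gt) ** self.beta * neg
        num_pos = pos.sum().clamp(min=1)
        return -(pos_loss.sum() + neg_loss.sum()) / num_pos

class RegL1Loss(nn.Module):
    """Sketch: L1 loss on the dense size map, averaged over labelled voxels."""
    def forward(self, pred, target):
        mask = (target > 0).float()
        return F.l1_loss(pred * mask, target * mask, reduction='sum') / mask.sum().clamp(min=1)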
Example #2
    def __init__(self, config, model, trainLoader, testLoader):
        self.model = model
        self.trainLoader = trainLoader
        self.testLoader = testLoader
        self.n_classes = config.n_classes
        self.use_cuda = config.use_cuda

        self.optimizer = optim.SGD(self.model.parameters(), lr=config.lr, momentum=0.9, weight_decay=1e-4)
        self.criterion = FocalLoss(num_classes=self.n_classes)
        self.lr_scheduler = lr_scheduler.MultiStepLR(self.optimizer, milestones=[6, 9], gamma=0.1)
        if self.use_cuda:
            self.model = self.model.cuda()
            self.criterion = self.criterion.cuda()

        self.n_epochs = config.n_epochs
        self.log_step = config.log_step
        self.out_path = config.out_path
        self.best_loss = float('inf')
Example #3
def get_loss(args):

    if args.loss == "bce":
        loss_func = nn.BCEWithLogitsLoss()
    elif args.loss == "focal":
        if "focal_gamma" in args.loss_params:
            loss_func = FocalLoss(gamma=args.loss_params["focal_gamma"])
        else:
            loss_func = FocalLoss()
    elif args.loss == "fbeta":
        if "fbeta" in args.loss_params:
            loss_func = FBetaLoss(beta=args.loss_params["fbeta"], soft=True)
        else:
            loss_func = FBetaLoss(soft=True)
    elif args.loss == "softmargin":
        loss_func = nn.SoftMarginLoss()
    else:
        raise ValueError(f"Invalid loss function specifier: {args.loss}")

    return loss_func
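A quick usage sketch for the selector above; the Namespace stands in for parsed command-line arguments, and only the `loss` and `loss_params` attributes are required. Swapping in `Namespace(loss="focal", loss_params={"focal_gamma": 2.0})` would select the focal branch instead.

import torch
from argparse import Namespace

args = Namespace(loss="bce", loss_params={})
criterion = get_loss(args)                     # returns nn.BCEWithLogitsLoss() here

logits = torch.randn(4, 1)                     # raw model outputs
targets = torch.randint(0, 2, (4, 1)).float()
loss = criterion(logits, targets)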
Example #4
def setup_train_chain(cfg, model):
    # setup loc_loss
    if cfg.model.loc_loss == 'SmoothL1':
        loc_loss = SmoothL1()
    else:
        raise ValueError('Unsupported `loc_loss`: {}.'.format(
            cfg.model.loc_loss))

    # setup conf_loss
    if cfg.model.conf_loss == 'FocalLoss':
        conf_loss = FocalLoss(cfg.model.focal_loss_alpha,
                              cfg.model.focal_loss_gamma)
    elif cfg.model.conf_loss == 'SoftmaxCrossEntropy':
        conf_loss = SoftmaxCrossEntropy()
    else:
        raise ValueError('Unsupported `conf_loss`: {}.'.format(
            cfg.model.conf_loss))

    train_chain = RetinaNetTrainChain(model, loc_loss, conf_loss,
                                      cfg.model.fg_thresh, cfg.model.bg_thresh)
    return train_chain
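`FocalLoss(cfg.model.focal_loss_alpha, cfg.model.focal_loss_gamma)` corresponds to the alpha-balanced focal loss from the RetinaNet paper, FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t). The class actually used by this train chain is not shown; the sketch below restates the formula for binary targets and is written with PyTorch tensors purely for illustration.

import torch
import torch.nn.functional as F

def focal_loss_with_logits(logits, targets, alpha=0.25, gamma=2.0):
    """Illustrative alpha-balanced focal loss on binary targets in {0, 1}."""
    p = torch.sigmoid(logits)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    p_t = p * targets + (1 - p) * (1 - targets)            # probability of the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * ce).mean()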
Example #5
    def __init__(self, config: Config):
        super(PolyNet, self).__init__()
        self.config = config
        self.increase = nn.Conv2d(3, 10, kernel_size=3, stride=1, padding=1)
        self.pre = Hourglass(4, 10, increase=10)  # downsample 5 times
        self.focalLoss = FocalLoss()
        # self.heatmapLoss = HeatmapLossMSE(weight=100, config=self.config)
        self.heatmapLoss = HeatmapLossSL1()
        # self.criterion = nn.BCELoss()  # standard BCEloss
        self.location = nn.Sequential(
            Conv(kernel_size=3, inp_dim=10, out_dim=30, relu=True),
            nn.Conv2d(in_channels=30, out_channels=1, kernel_size=1, stride=1),
        )

        self.center = nn.Sequential(
            Conv(kernel_size=3, inp_dim=10, out_dim=30, relu=True),
            nn.Conv2d(in_channels=30, out_channels=1, kernel_size=1, stride=1),
        )

        self.xyoffsets = nn.Sequential(  # 2 channels for x and y offsets
            Conv(kernel_size=3, inp_dim=10, out_dim=30, relu=True),
            nn.Conv2d(in_channels=30, out_channels=2, kernel_size=1, stride=1),
        )
        # self.conv=Conv(inp_dim=3, out_dim=1, kernel_size=1, stride=1, bn=False, relu=False)

        self.location2 = nn.Sequential(
            Conv(kernel_size=3, inp_dim=14, out_dim=30, relu=True),
            nn.Conv2d(in_channels=30, out_channels=1, kernel_size=1, stride=1),
        )

        self.center2 = nn.Sequential(
            Conv(kernel_size=3, inp_dim=14, out_dim=30, relu=True),
            nn.Conv2d(in_channels=30, out_channels=1, kernel_size=1, stride=1),
        )

        self.xyoffsets2 = nn.Sequential(  # 2 channels for x and y offsets
            Conv(kernel_size=3, inp_dim=14, out_dim=30, relu=True),
            nn.Conv2d(in_channels=30, out_channels=2, kernel_size=1, stride=1),
        )
Example #6
def train():
    print("local_rank:", args.local_rank)
    cudnn.benchmark = True
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.local_rank)
        torch.set_printoptions(precision=10)
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(
        backend='nccl',
        init_method='env://',
    )
    torch.manual_seed(0)

    if not args.eval_net:
        train_ds = dataset_desc.Dataset('train')
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_ds)
        train_loader = torch.utils.data.DataLoader(
            train_ds, batch_size=config.mini_batch_size, shuffle=False,
            drop_last=True, num_workers=4, sampler=train_sampler, pin_memory=True
        )

        val_ds = dataset_desc.Dataset('test')
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_ds)
        val_loader = torch.utils.data.DataLoader(
            val_ds, batch_size=config.val_mini_batch_size, shuffle=False,
            drop_last=False, num_workers=4, sampler=val_sampler
        )
    else:
        test_ds = dataset_desc.Dataset('test')
        test_loader = torch.utils.data.DataLoader(
            test_ds, batch_size=config.test_mini_batch_size, shuffle=False,
            num_workers=20
        )

    rndla_cfg = ConfigRandLA
    model = FFB6D(
        n_classes=config.n_objects, n_pts=config.n_sample_points, rndla_cfg=rndla_cfg,
        n_kps=config.n_keypoints
    )
    model = convert_syncbn_model(model)
    device = torch.device('cuda:{}'.format(args.local_rank))
    print('local_rank:', args.local_rank)
    model.to(device)
    optimizer = optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )
    opt_level = args.opt_level
    model, optimizer = amp.initialize(
        model, optimizer, opt_level=opt_level,
    )

    # default value
    it = -1  # initial value for `LambdaLR` and `BNMomentumScheduler`
    best_loss = 1e10
    start_epoch = 1

    # load status from checkpoint
    if args.checkpoint is not None:
        checkpoint_status = load_checkpoint(
            model, optimizer, filename=args.checkpoint[:-8]
        )
        if checkpoint_status is not None:
            it, start_epoch, best_loss = checkpoint_status
        if args.eval_net:
            assert checkpoint_status is not None, "Failed loading model."

    if not args.eval_net:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank,
            find_unused_parameters=True
        )
        clr_div = 6
        lr_scheduler = CyclicLR(
            optimizer, base_lr=1e-5, max_lr=1e-3,
            cycle_momentum=False,
            step_size_up=config.n_total_epoch * train_ds.minibatch_per_epoch // clr_div // args.gpus,
            step_size_down=config.n_total_epoch * train_ds.minibatch_per_epoch // clr_div // args.gpus,
            mode='triangular'
        )
    else:
        lr_scheduler = None

    bnm_lmbd = lambda it: max(
        args.bn_momentum * args.bn_decay ** (int(it * config.mini_batch_size / args.decay_step)),
        bnm_clip,
    )
    bnm_scheduler = pt_utils.BNMomentumScheduler(
        model, bn_lambda=bnm_lmbd, last_epoch=it
    )

    it = max(it, 0)  # initial value passed to `trainer.train`

    if args.eval_net:
        model_fn = model_fn_decorator(
            FocalLoss(gamma=2), OFLoss(),
            args.test,
        )
    else:
        model_fn = model_fn_decorator(
            FocalLoss(gamma=2).to(device), OFLoss().to(device),
            args.test,
        )

    checkpoint_fd = config.log_model_dir

    trainer = Trainer(
        model,
        model_fn,
        optimizer,
        checkpoint_name=os.path.join(checkpoint_fd, "FFB6D"),
        best_name=os.path.join(checkpoint_fd, "FFB6D_best"),
        lr_scheduler=lr_scheduler,
        bnm_scheduler=bnm_scheduler,
    )

    if args.eval_net:
        start = time.time()
        val_loss, res = trainer.eval_epoch(
            test_loader, is_test=True, test_pose=args.test_pose
        )
        end = time.time()
        print("\nUse time: ", end - start, 's')
    else:
        trainer.train(
            it, start_epoch, config.n_total_epoch, train_loader, None,
            val_loader, best_loss=best_loss,
            tot_iter=config.n_total_epoch * train_ds.minibatch_per_epoch // args.gpus,
            clr_div=clr_div
        )

        if start_epoch == config.n_total_epoch:
            _ = trainer.eval_epoch(val_loader)
Example #7
def train(args):
    print('Preparing...')
    validset = AbusNpyFormat(root, crx_valid=True, crx_fold_num=args.crx_valid, crx_partition='valid', augmentation=False, include_fp=True)
    trainset = AbusNpyFormat(root, crx_valid=True, crx_fold_num=args.crx_valid, crx_partition='train', augmentation=True, include_fp=True)
    trainset_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=0)
    validset_loader = DataLoader(validset, batch_size=1, shuffle=False, num_workers=0)

    crit_hm = FocalLoss()
    crit_wh = RegL1Loss()

    train_hist = {
        'train_loss':[],
        'valid_hm_loss':[],
        'valid_wh_loss':[],
        'valid_total_loss':[],
        'per_epoch_time':[]
    }

    heads = {
        'hm': 1,
        'wh': 3,
        'fp_hm': 1
    }
    model = get_large_hourglass_net(heads, n_stacks=1)

    init_ep = 0
    end_ep = args.max_epoch
    print('Resume training from the designated checkpoint.')
    path = pre_dir + 'hourglass_' + 'f{}_frz'.format(args.crx_valid)
    pretrained_dict = torch.load(path)
    model_dict = model.state_dict()

    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict) 
    # 3. load the new state dict
    model.load_state_dict(model_dict)

    if args.freeze:
        for param in model.pre.parameters():
            param.requires_grad = False
        for param in model.kps.parameters():
            param.requires_grad = False
        for param in model.cnvs.parameters():
            param.requires_grad = False
        for param in model.inters.parameters():
            param.requires_grad = False
        for param in model.inters_.parameters():
            param.requires_grad = False
        for param in model.cnvs_.parameters():
            param.requires_grad = False
        for param in model.wh.parameters():
            param.requires_grad = False
        for param in model.hm.parameters():
            param.requires_grad = False
        
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
    optim_sched = ExponentialLR(optimizer, 0.92, last_epoch=-1)
    model.to(device)
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    print('Preparation done.')
    print('******************')
    print('Training starts...')

    start_time = time.time()
    min_loss = 0

    first_ep = True
    for epoch in range(init_ep, end_ep):
        train_loss = 0
        valid_hm_loss = 0
        epoch_start_time = time.time()
        lambda_s = args.lambda_s # * (1.03**epoch)

        # Training
        model.train()
        optimizer.zero_grad()
        for batch_idx, (data_img, data_hm, data_wh, _) in enumerate(trainset_loader):
            if use_cuda:
                data_img = data_img.cuda()
                data_hm = data_hm.cuda()
                data_wh = data_wh.cuda()
            output = model(data_img)
            hm_loss = crit_hm(output[-1]['fp_hm'], data_hm)

            total_loss = hm_loss
            train_loss += hm_loss.item()
            with amp.scale_loss(total_loss, optimizer) as scaled_loss:
                scaled_loss.backward()

            if (first_ep and batch_idx < 10) or (batch_idx % 8 == 0) or (batch_idx == len(trainset_loader) - 1):
                print('Gradient applied at batch #', batch_idx)
                optimizer.step()
                optimizer.zero_grad()
            
            print("Epoch: [{:2d}] [{:3d}], hm_loss: {:.3f}"\
                .format((epoch + 1), (batch_idx + 1), hm_loss.item()))
        
        optim_sched.step()

        # Validation
        model.eval()
        with torch.no_grad():
            for batch_idx, (data_img, data_hm, data_wh, _) in enumerate(validset_loader):
                if use_cuda:
                    data_img = data_img.cuda()
                    data_hm = data_hm.cuda()
                    data_wh = data_wh.cuda()
                output = model(data_img)
                hm_loss = crit_hm(output[-1]['fp_hm'], data_hm)

                valid_hm_loss += hm_loss.item()

        valid_hm_loss = valid_hm_loss / len(validset)
        train_loss = train_loss / len(trainset)

        if epoch == 0 or valid_hm_loss < min_loss:
            min_loss = valid_hm_loss
            model.save(str(epoch))
        elif (epoch % 5) == 4:
            model.save(str(epoch))
        model.save('latest')

        train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
        train_hist['valid_hm_loss'].append(valid_hm_loss)
        train_hist['train_loss'].append(train_loss)
        plt.figure()
        plt.plot(train_hist['train_loss'], color='k')
        plt.plot(train_hist['valid_total_loss'], color='r')
        plt.plot(train_hist['valid_hm_loss'], color='b')
        plt.ylabel('Loss')
        plt.xlabel('Epoch')
        plt.savefig('loss_fold{}.png'.format(args.crx_valid))
        plt.close()

        print("Epoch: [{:d}], valid_hm_loss: {:.3f}".format((epoch + 1), valid_hm_loss))
        print('Epoch exec time: {} min'.format((time.time() - epoch_start_time)/60))
        first_ep = False

    print("Training finished.")
    print("Total time cost: {} min.".format((time.time() - start_time)/60))
Example #8
def train(args):
    checkpoint_dir = 'checkpoints/{}'.format(args.exp_name)
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)
    logger = setup_logger("CenterNet_ABUS", checkpoint_dir, distributed_rank=0)
    logger.info(args)

    logger.info('Preparing...')
    validset = AbusNpyFormat(testing_mode=0,
                             root=root,
                             crx_valid=True,
                             crx_fold_num=args.crx_valid,
                             crx_partition='valid',
                             augmentation=False)
    trainset = AbusNpyFormat(testing_mode=0,
                             root=root,
                             crx_valid=True,
                             crx_fold_num=args.crx_valid,
                             crx_partition='train',
                             augmentation=True)
    trainset_loader = DataLoader(trainset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=6)
    validset_loader = DataLoader(validset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=6)

    crit_hm = FocalLoss()
    crit_wh = RegL1Loss()

    train_hist = {
        'train_loss': [],
        'valid_hm_loss': [],
        'valid_wh_loss': [],
        'valid_total_loss': [],
        'per_epoch_time': []
    }

    heads = {'hm': 1, 'wh': 3}
    model = get_large_hourglass_net(heads, n_stacks=1)
    model = model.to(device)
    checkpointer = SimpleCheckpointer(checkpoint_dir, model)
    if args.resume:
        init_ep = 0
        logger.info('Resume training from the designated checkpoint.')
        checkpointer.load(str(args.resume_ep))
    else:
        init_ep = 0
    end_ep = args.max_epoch

    if args.freeze:
        logger.info('Paritially freeze layers.')
        for param in model.pre.parameters():
            param.requires_grad = False
        for param in model.kps.parameters():
            param.requires_grad = False
        for param in model.cnvs.parameters():
            param.requires_grad = False
        for param in model.inters.parameters():
            param.requires_grad = False
        for param in model.inters_.parameters():
            param.requires_grad = False
        for param in model.cnvs_.parameters():
            param.requires_grad = False
        for param in model.hm.parameters():
            param.requires_grad = False
        crit_wh = RegL2Loss()

    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=args.lr)
    optim_sched = ExponentialLR(optimizer, 0.95, last_epoch=-1)

    #model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    logger.info('Preparation done.')
    logger.info('******************')
    logger.info('Training starts...')

    start_time = time.time()
    min_loss = 0

    checkpointer.save('initial')
    first_ep = True

    for epoch in range(init_ep, end_ep):
        epoch_start_time = time.time()
        train_loss = 0
        current_loss = 0
        valid_hm_loss = 0
        valid_wh_loss = 0
        lambda_s = args.lambda_s  # * (1.03**epoch)

        # Training
        model.train()
        optimizer.zero_grad()
        for batch_idx, (data_img, data_hm, data_wh,
                        _) in enumerate(trainset_loader):
            if use_cuda:
                data_img = data_img.cuda()
                data_hm = data_hm.cuda()
                data_wh = data_wh.cuda()
            output = model(data_img)
            hm_loss = crit_hm(output[-1]['hm'], data_hm)
            wh_loss = crit_wh(output[-1]['wh'], data_wh)

            total_loss = hm_loss + lambda_s * wh_loss
            train_loss += (hm_loss.item() + args.lambda_s * wh_loss.item())
            total_loss.backward()
            if (first_ep and batch_idx < 10) or (batch_idx % 16 == 0) or (
                    batch_idx == len(trainset_loader) - 1):
                logger.info(
                    'Gradient applied at batch #{}  '.format(batch_idx))
                optimizer.step()
                optimizer.zero_grad()

            print("Epoch: [{:2d}] [{:3d}], hm_loss: {:.3f}, wh_loss: {:.3f}, total_loss: {:.3f}"\
                .format((epoch + 1), (batch_idx + 1), hm_loss.item(), wh_loss.item(), total_loss.item()))

        optim_sched.step()

        # Validation
        model.eval()
        with torch.no_grad():
            for batch_idx, (data_img, data_hm, data_wh,
                            _) in enumerate(validset_loader):
                if use_cuda:
                    data_img = data_img.cuda()
                    data_hm = data_hm.cuda()
                    data_wh = data_wh.cuda()
                output = model(data_img)
                hm_loss = crit_hm(output[-1]['hm'], data_hm)
                wh_loss = crit_wh(output[-1]['wh'], data_wh)

                valid_hm_loss += hm_loss.item()
                valid_wh_loss += wh_loss.item()

        valid_hm_loss = valid_hm_loss / len(validset)
        valid_wh_loss = valid_wh_loss / len(validset)
        train_loss = train_loss / len(trainset)
        current_loss = valid_hm_loss + args.lambda_s * valid_wh_loss

        save_id = (args.resume_ep + '_' +
                   str(epoch)) if args.resume else str(epoch)
        if epoch == 0 or current_loss < min_loss:
            min_loss = current_loss
            checkpointer.save(save_id)
        elif (epoch % 5) == 4:
            checkpointer.save(save_id)
        checkpointer.save('latest')

        train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
        train_hist['valid_hm_loss'].append(valid_hm_loss)
        train_hist['valid_wh_loss'].append(args.lambda_s * valid_wh_loss)
        train_hist['valid_total_loss'].append(current_loss)
        train_hist['train_loss'].append(train_loss)
        plt.figure()
        plt.plot(train_hist['train_loss'], color='k')
        plt.plot(train_hist['valid_total_loss'], color='r')
        plt.plot(train_hist['valid_hm_loss'], color='b')
        plt.plot(train_hist['valid_wh_loss'], color='g')
        plt.ylabel('Loss')
        plt.xlabel('Epoch')
        plt.savefig('loss_fold{}.png'.format(args.crx_valid))
        plt.close()
        np.save('train_hist_{}.npy'.format(args.exp_name), train_hist)
        logger.info(
            "Epoch: [{:d}], valid_hm_loss: {:.3f}, valid_wh_loss: {:.3f}".
            format((epoch + 1), valid_hm_loss, args.lambda_s * valid_wh_loss))
        logger.info('Epoch exec time: {} min'.format(
            (time.time() - epoch_start_time) / 60))
        first_ep = False

    logger.info("Training finished.")
    logger.info("Total time cost: {} min.".format(
        (time.time() - start_time) / 60))
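The `SimpleCheckpointer` used above is project-specific and not shown on this page. A minimal sketch with the same `save(tag)` / `load(tag)` call pattern, assuming it only needs to persist model weights into `checkpoint_dir`, might look like:

import os
import torch

class SimpleCheckpointer:
    """Sketch: save/load a model state dict under a tag name."""
    def __init__(self, checkpoint_dir, model):
        self.checkpoint_dir = checkpoint_dir
        self.model = model

    def save(self, tag):
        path = os.path.join(self.checkpoint_dir, '{}.pth'.format(tag))
        torch.save(self.model.state_dict(), path)

    def load(self, tag):
        path = os.path.join(self.checkpoint_dir, '{}.pth'.format(tag))
        self.model.load_state_dict(torch.load(path, map_location='cpu'))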
Example #9
class Solver(object):
    def __init__(self, config, model, trainLoader, testLoader):
        self.model = model
        self.trainLoader = trainLoader
        self.testLoader = testLoader
        self.n_classes = config.n_classes
        self.use_cuda = config.use_cuda

        self.optimizer = optim.SGD(self.model.parameters(), lr=config.lr, momentum=0.9, weight_decay=1e-4)
        self.criterion = FocalLoss(num_classes=self.n_classes)
        self.lr_scheduler = lr_scheduler.MultiStepLR(self.optimizer, milestones=[6, 9], gamma=0.1)
        if self.use_cuda:
            self.model = self.model.cuda()
            self.criterion = self.criterion.cuda()

        self.n_epochs = config.n_epochs
        self.log_step = config.log_step
        self.out_path = config.out_path
        self.best_loss = float('inf')

    def train(self, epoch):
        print('\nEpoch: %d' % epoch)
        self.model.train()
        self.lr_scheduler.step()  # note: on PyTorch >= 1.1, step the scheduler after the epoch's optimizer updates
        train_loss = 0
        for batch_idx, (inputs, loc_targets, cls_targets) in enumerate(self.trainLoader):
            if self.use_cuda:
                inputs = Variable(inputs).cuda()
                loc_targets = Variable(loc_targets).cuda()
                cls_targets = Variable(cls_targets).cuda()

            self.optimizer.zero_grad()
            loc_preds, cls_preds = self.model(inputs)
            loss = self.criterion(loc_preds, loc_targets, cls_preds, cls_targets, change_alpha=True)
            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            print('train_loss: %.3f | avg_loss: %.3f [%d/%d]'
                  % (loss.item(), train_loss / (batch_idx + 1), batch_idx + 1, len(self.trainLoader)))

    def test(self, epoch):
        print('\nTest')
        self.model.eval()
        test_loss = 0
        for batch_idx, (inputs, loc_targets, cls_targets) in enumerate(self.testLoader):
            if self.use_cuda:
                inputs = Variable(inputs).cuda()
                loc_targets = Variable(loc_targets).cuda()
                cls_targets = Variable(cls_targets).cuda()

            loc_preds, cls_preds = self.model(inputs)
            loss = self.criterion(loc_preds, loc_targets, cls_preds, cls_targets, change_alpha=False)
            test_loss += loss.item()
            print('test_loss: %.3f | avg_loss: %.3f [%d/%d]'
                  % (loss.item(), test_loss / (batch_idx + 1), batch_idx + 1, len(self.testLoader)))

        # Save checkpoint
        test_loss /= len(self.testLoader)
        if test_loss < self.best_loss:
            print('Saving..')
            state = {
                'net': self.model.state_dict(),
                'loss': test_loss,
                'epoch': epoch,
            }
            if not os.path.isdir(os.path.dirname(config.checkpoint)):
                os.mkdir(os.path.dirname(config.checkpoint))
            torch.save(state, config.checkpoint)
            self.best_loss = test_loss
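Unlike the heat-map examples, the `FocalLoss(num_classes=...)` in this solver is called as a joint criterion over anchor boxes: `criterion(loc_preds, loc_targets, cls_preds, cls_targets, ...)`. That class is not shown here; the sketch below illustrates such a RetinaNet-style criterion (Smooth-L1 on positive-anchor boxes plus an alpha-balanced focal term on anchor classification), with the `change_alpha` flag and the exact target conventions treated as repository-specific assumptions.

import torch
import torch.nn as nn
import torch.nn.functional as F

class JointDetectionLoss(nn.Module):
    """Sketch: box regression + focal classification over anchors.
    Assumes cls_targets uses 0 for background and 1..num_classes for objects."""
    def __init__(self, num_classes, alpha=0.25, gamma=2.0):
        super().__init__()
        self.num_classes = num_classes
        self.alpha, self.gamma = alpha, gamma

    def forward(self, loc_preds, loc_targets, cls_preds, cls_targets):
        pos = cls_targets > 0                              # positive anchors
        num_pos = pos.sum().clamp(min=1).float()

        # Smooth-L1 box regression, positive anchors only.
        loc_loss = F.smooth_l1_loss(loc_preds[pos], loc_targets[pos], reduction='sum')

        # Focal loss on one-hot class targets (background row is all zeros).
        one_hot = F.one_hot(cls_targets.clamp(min=0), self.num_classes + 1)[..., 1:].float()
        p = torch.sigmoid(cls_preds)
        ce = F.binary_cross_entropy_with_logits(cls_preds, one_hot, reduction='none')
        p_t = p * one_hot + (1 - p) * (1 - one_hot)
        alpha_t = self.alpha * one_hot + (1 - self.alpha) * (1 - one_hot)
        cls_loss = (alpha_t * (1 - p_t) ** self.gamma * ce).sum()

        return (loc_loss + cls_loss) / num_pos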
Example #10
        model.load_state_dict(torch.load(args.snapshot), strict=True)
    if args.cuda:
        model.cuda()
    print(model)

    # optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=args.learning_rate,
                           weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.2)

    # BCE(Focal) loss applied to each pixel individually
    model.hm[2].bias.data.uniform_(-4.595,
                                   -4.595)  # bias towards negative class
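    # Note: -4.595 ≈ -log((1 - pi) / pi) for a foreground prior pi ≈ 0.01, the
    # RetinaNet-style initialization that keeps the initial predictions strongly
    # biased towards the dominant negative class.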
    if args.loss_type == 'focal':
        criterion_1 = FocalLoss(gamma=2.0, alpha=0.25, size_average=True)
    elif args.loss_type == 'bce':
        ## plain BCE
        criterion_1 = torch.nn.BCEWithLogitsLoss()
    elif args.loss_type == 'wbce':
        ## BCE with a positive-class weight
        criterion_1 = torch.nn.BCEWithLogitsLoss(
            pos_weight=torch.tensor([9.6]).cuda())
    else:
        raise ValueError('Unsupported loss_type: {}'.format(args.loss_type))
    criterion_2 = IoULoss()
    criterion_reg = RegL1Loss()

    # set up figures and axes
    fig1, ax1 = plt.subplots()
    plt.grid(True)
    ax1.plot([], 'r', label='Training segmentation loss')
    ax1.plot([], 'g', label='Training VAF loss')