Code example #1
0
File: inference.py  Project: TienYiChi/CenterNet_ABUS
def main(args):
    """Run CenterNet inference on an ABUS validation fold and save boxes.

    Loads the checkpoint for ``args.epoch``, decodes the model's heatmap
    ('hm') and size-regression ('wh') outputs into top-k candidate boxes for
    every validation volume, and saves each result as a .npy file.

    Args:
        args: parsed CLI namespace; uses args.epoch (checkpoint epoch to
            load) and args.fold_num (cross-validation fold index).
    """
    # Fixed volume extent used by the box decoder — TODO confirm axis order.
    size = (160, 40, 160)

    heads = {
        'hm': 1,  # 1 channel probability heat map.
        'wh': 3   # 3 channel x,y,z size regression.
    }
    model = get_large_hourglass_net(heads, n_stacks=1, debug=False)
    model.load(chkpts_dir, args.epoch)
    model = model.to(device)
    model.eval()

    trainset = AbusNpyFormat(root=root, crx_valid=True,
                             crx_fold_num=args.fold_num,
                             crx_partition='valid')
    trainset_loader = DataLoader(trainset, batch_size=1, shuffle=False,
                                 num_workers=0)

    start_time = time.time()
    with torch.no_grad():
        for batch_idx, (data_img, hm_gt, box_gt, extra_info) in enumerate(trainset_loader):
            f_name = trainset.getName(extra_info[1])
            data_img = data_img.to(device)
            output = model(data_img)
            print('***************************')
            print('Processing: ', f_name)
            # Size regression is a magnitude; force it non-negative.
            wh_pred = torch.abs(output[-1]['wh'])
            hm_pred = output[-1]['hm']
            # Decode the 50 highest-scoring heatmap peaks into boxes.
            boxes = _get_topk([], hm_pred, size, wh_pred, topk=50)

            boxes = np.array(boxes, dtype=float)
            np.save(os.path.join(npy_dir, f_name), boxes)

    print("Inference finished.")
    # len(trainset) is the idiomatic form of trainset.__len__().
    print("Average time cost: {:.2f} sec.".format(
        (time.time() - start_time) / len(trainset)))
Code example #2
0
def main(args):
    """Smoke-test one training step: compute losses on a single batch.

    Builds the hourglass network in debug mode, runs one forward pass on
    the first training batch, and prints the focal (heatmap) and scaled L1
    (size-regression) losses.

    Args:
        args: parsed CLI namespace; uses args.batch_size.
    """
    heads = {
        'hm': 1,  # 1 channel probability heat map.
        'wh': 3   # 3 channel x,y,z size regression.
    }
    model = get_large_hourglass_net(heads, n_stacks=1, debug=True)
    # Move the model once, before the loop — the original re-issued
    # model.to(device) on every iteration (loop-invariant work).
    if use_cuda:
        model.to(device)

    trainset = AbusNpyFormat(root=root)
    trainset_loader = DataLoader(trainset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=0)

    crit_hm = FocalLoss()
    crit_reg = RegL1Loss()
    crit_wh = crit_reg  # Size regression reuses the same L1 criterion.

    for batch_idx, (data_img, data_hm, data_wh,
                    _) in enumerate(trainset_loader):
        if use_cuda:
            data_img = data_img.cuda()
            data_hm = data_hm.cuda()
            data_wh = data_wh.cuda()

        output = model(data_img)

        # Size regression is a magnitude; force it non-negative.
        wh_pred = torch.abs(output[-1]['wh'])
        hm_loss = crit_hm(output[-1]['hm'], data_hm)
        wh_loss = 100 * crit_wh(wh_pred, data_wh)

        print("hm_loss: %.3f, wh_loss: %.3f" \
                % (hm_loss.item(), wh_loss.item()))
        # Only the first batch is needed for this sanity check.
        return
Code example #3
0
def main(args):
    """Visualize one sample's ground-truth heatmap and size regression.

    Loads sample ``args.index`` from the full (non-cross-validated)
    dataset and writes slice renderings of the heatmap and of each
    size-regression channel under test/<index>/.

    Args:
        args: parsed CLI namespace; uses args.index.
    """
    all_data = AbusNpyFormat(root,
                             crx_valid=False,
                             augmentation=False,
                             include_fp=False)
    # Indexing / len() are the idiomatic forms of __getitem__ / __len__.
    data, hm, wh, label = all_data[args.index]
    print('Dataset size:', len(all_data))
    print('Shape of data:', data.shape, hm.shape, wh.shape)

    data = data.detach().numpy()
    base_dir = os.path.join(os.path.dirname(__file__), 'test',
                            str(args.index))
    # Heatmap slices.
    draw_slice(data[0], hm[0], os.path.join(base_dir, 'hm'), label=label[0])
    # Size-regression channels; channel order per original: wh[2] -> x,
    # wh[1] -> y, wh[0] -> z — TODO confirm against dataset layout.
    for sub_dir, channel in (('wh_x', 2), ('wh_y', 1), ('wh_z', 0)):
        draw_slice(data[0], wh[channel], os.path.join(base_dir, sub_dir),
                   label=label[0])
    return
Code example #4
0
def main(args):
    """Debug-run the model: optionally freeze layers, then inspect shapes.

    When ``args.freeze`` is set, prints module names and parameter-iterator
    types, then disables gradients on every sub-module except the 'wh'
    head.  Afterwards runs one forward pass on the first training batch and
    prints output tensor shapes next to the ground truth.

    Args:
        args: parsed CLI namespace; uses args.freeze and args.batch_size.
    """
    heads = {
        'hm': 1,  # 1 channel probability heat map.
        'wh': 3   # 3 channel x,y,z size regression.
    }
    model = get_large_hourglass_net(heads, n_stacks=1, debug=True)
    model = model.to(device)
    print(args.freeze)
    if args.freeze:
        for name, module in model._modules.items():
            print(name)
        print(type(model.parameters()))
        print(type(filter(lambda p: p.requires_grad, model.parameters())))
        # One loop replaces seven copy-pasted freeze loops; model.wh is
        # deliberately left trainable, matching the original.
        for module in (model.pre, model.kps, model.cnvs, model.inters,
                       model.inters_, model.cnvs_, model.hm):
            for param in module.parameters():
                param.requires_grad = False

    trainset = AbusNpyFormat(root=root, crx_valid=True, crx_fold_num=0,
                             crx_partition='train', augmentation=True)
    trainset_loader = DataLoader(trainset, batch_size=args.batch_size,
                                 shuffle=True, num_workers=0)
    for batch_idx, (data_img, hm_gt, box_gt, _) in enumerate(trainset_loader):
        data_img = data_img.to(device)
        print('Batch number:', len(trainset_loader))
        output = model(data_img)
        print('Output length:', len(output))
        print('HM tensor:', output[-1]['hm'].shape)
        print('Box tensor:', output[-1]['wh'].shape)
        print('GT HM tensor:', hm_gt.shape)
        print('GT Box tensor:', box_gt.shape)
        # Only the first batch is needed for this shape check.
        return
Code example #5
0
def train(args):
    """Fine-tune the false-positive heatmap head from a frozen checkpoint.

    Loads a pretrained hourglass checkpoint, optionally freezes the
    backbone plus the original 'hm'/'wh' heads, and trains only the
    'fp_hm' head with focal loss, using apex mixed precision and gradient
    accumulation.  Validates each epoch, checkpoints on improvement, and
    plots the loss history.

    Args:
        args: parsed CLI namespace; uses args.crx_valid, args.batch_size,
            args.max_epoch, args.lr, args.lambda_s, args.freeze.
    """
    print('Preparing...')
    validset = AbusNpyFormat(root, crx_valid=True, crx_fold_num=args.crx_valid, crx_partition='valid', augmentation=False, include_fp=True)
    trainset = AbusNpyFormat(root, crx_valid=True, crx_fold_num=args.crx_valid, crx_partition='train', augmentation=True, include_fp=True)
    trainset_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=0)
    validset_loader = DataLoader(validset, batch_size=1, shuffle=False, num_workers=0)

    crit_hm = FocalLoss()
    crit_wh = RegL1Loss()  # NOTE(review): unused — only fp_hm is trained here.

    train_hist = {
        'train_loss':[],
        'valid_hm_loss':[],
        'valid_wh_loss':[],
        'valid_total_loss':[],
        'per_epoch_time':[]
    }

    heads = {
        'hm': 1,
        'wh': 3,
        'fp_hm': 1  # Extra head for false-positive reduction.
    }
    model = get_large_hourglass_net(heads, n_stacks=1)

    init_ep = 0
    end_ep = args.max_epoch
    print('Resume training from the designated checkpoint.')
    path = pre_dir + 'hourglass_' + 'f{}_frz'.format(args.crx_valid)
    pretrained_dict = torch.load(path)
    model_dict = model.state_dict()

    # 1. filter out unnecessary keys (the new fp_hm head is absent from the
    #    checkpoint and must keep its fresh initialization)
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)

    if args.freeze:
        # Freeze the backbone and both original heads so only fp_hm learns.
        for module in (model.pre, model.kps, model.cnvs, model.inters,
                       model.inters_, model.cnvs_, model.wh, model.hm):
            for param in module.parameters():
                param.requires_grad = False

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
    optim_sched = ExponentialLR(optimizer, 0.92, last_epoch=-1)
    model.to(device)
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    print('Preparation done.')
    print('******************')
    print('Training starts...')

    start_time = time.time()
    min_loss = 0  # Overwritten on epoch 0 via the `epoch == 0` guard below.

    first_ep = True
    for epoch in range(init_ep, end_ep):
        train_loss = 0
        valid_hm_loss = 0
        epoch_start_time = time.time()
        lambda_s = args.lambda_s # * (1.03**epoch)

        # Training
        model.train()
        optimizer.zero_grad()
        for batch_idx, (data_img, data_hm, data_wh, _) in enumerate(trainset_loader):
            if use_cuda:
                data_img = data_img.cuda()
                data_hm = data_hm.cuda()
                data_wh = data_wh.cuda()
            output = model(data_img)
            hm_loss = crit_hm(output[-1]['fp_hm'], data_hm)

            total_loss = hm_loss
            train_loss += hm_loss.item()
            with amp.scale_loss(total_loss, optimizer) as scaled_loss:
                scaled_loss.backward()

            # Gradient accumulation: step every 8th batch (every batch for
            # the first 10 of epoch 0, and always on the last batch).
            # BUG FIX: original used `(batch_idx % 8) is 0`; identity
            # comparison on an int is implementation-dependent — use `==`.
            if (first_ep and batch_idx < 10) or ((batch_idx % 8) == 0) or (batch_idx == len(trainset_loader) - 1):
                print('Gradient applied at batch #', batch_idx)
                optimizer.step()
                optimizer.zero_grad()

            print("Epoch: [{:2d}] [{:3d}], hm_loss: {:.3f}"\
                .format((epoch + 1), (batch_idx + 1), hm_loss.item()))

        optim_sched.step()

        # Validation
        model.eval()
        with torch.no_grad():
            for batch_idx, (data_img, data_hm, data_wh, _) in enumerate(validset_loader):
                if use_cuda:
                    data_img = data_img.cuda()
                    data_hm = data_hm.cuda()
                    data_wh = data_wh.cuda()
                output = model(data_img)
                hm_loss = crit_hm(output[-1]['fp_hm'], data_hm)

                valid_hm_loss += hm_loss.item()

        # len(...) is the idiomatic form of .__len__().
        valid_hm_loss = valid_hm_loss / len(validset)
        train_loss = train_loss / len(trainset)

        # Checkpoint on improvement, every 5th epoch, and always 'latest'.
        if epoch == 0 or valid_hm_loss < min_loss:
            min_loss = valid_hm_loss
            model.save(str(epoch))
        elif (epoch % 5) == 4:
            model.save(str(epoch))
        model.save('latest')

        train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
        train_hist['valid_hm_loss'].append(valid_hm_loss)
        train_hist['train_loss'].append(train_loss)
        plt.figure()
        plt.plot(train_hist['train_loss'], color='k')
        # NOTE(review): 'valid_total_loss' is never appended to in this
        # function, so this series plots empty.
        plt.plot(train_hist['valid_total_loss'], color='r')
        plt.plot(train_hist['valid_hm_loss'], color='b')
        plt.ylabel('Loss')
        plt.xlabel('Epoch')
        plt.savefig('loss_fold{}.png'.format(args.crx_valid))
        plt.close()

        print("Epoch: [{:d}], valid_hm_loss: {:.3f}".format((epoch + 1), valid_hm_loss))
        print('Epoch exec time: {} min'.format((time.time() - epoch_start_time)/60))
        first_ep = False

    print("Training finished.")
    print("Total time cost: {} min.".format((time.time() - start_time)/60))
Code example #6
0
File: trainer.py  Project: brucetusec/CenterNet_ABUS
def train(args):
    """Train the CenterNet hourglass detector on one cross-validation fold.

    Trains the heatmap ('hm') and size-regression ('wh') heads with focal
    and regression losses, using gradient accumulation, per-epoch
    validation, checkpointing, loss-curve plotting, and history dumps.

    Args:
        args: parsed CLI namespace; uses args.exp_name, args.crx_valid,
            args.batch_size, args.resume, args.resume_ep, args.max_epoch,
            args.freeze, args.lr, args.lambda_s.
    """
    checkpoint_dir = 'checkpoints/{}'.format(args.exp_name)
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)
    logger = setup_logger("CenterNet_ABUS", checkpoint_dir, distributed_rank=0)
    logger.info(args)

    logger.info('Preparing...')
    validset = AbusNpyFormat(testing_mode=0,
                             root=root,
                             crx_valid=True,
                             crx_fold_num=args.crx_valid,
                             crx_partition='valid',
                             augmentation=False)
    trainset = AbusNpyFormat(testing_mode=0,
                             root=root,
                             crx_valid=True,
                             crx_fold_num=args.crx_valid,
                             crx_partition='train',
                             augmentation=True)
    trainset_loader = DataLoader(trainset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=6)
    validset_loader = DataLoader(validset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=6)

    crit_hm = FocalLoss()
    crit_wh = RegL1Loss()

    train_hist = {
        'train_loss': [],
        'valid_hm_loss': [],
        'valid_wh_loss': [],
        'valid_total_loss': [],
        'per_epoch_time': []
    }

    heads = {'hm': 1, 'wh': 3}
    model = get_large_hourglass_net(heads, n_stacks=1)
    model = model.to(device)
    checkpointer = SimpleCheckpointer(checkpoint_dir, model)
    if args.resume:
        init_ep = 0
        logger.info('Resume training from the designated checkpoint.')
        checkpointer.load(str(args.resume_ep))
    else:
        init_ep = 0
    end_ep = args.max_epoch

    if args.freeze:
        logger.info('Paritially freeze layers.')
        # One loop replaces seven copy-pasted freeze loops; model.wh stays
        # trainable, matching the original.
        for module in (model.pre, model.kps, model.cnvs, model.inters,
                       model.inters_, model.cnvs_, model.hm):
            for param in module.parameters():
                param.requires_grad = False
        # Frozen-backbone runs switch the size loss to L2.
        crit_wh = RegL2Loss()

    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=args.lr)
    optim_sched = ExponentialLR(optimizer, 0.95, last_epoch=-1)

    #model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    logger.info('Preparation done.')
    logger.info('******************')
    logger.info('Training starts...')

    start_time = time.time()
    min_loss = 0  # Overwritten on epoch 0 via the `epoch == 0` guard below.

    checkpointer.save('initial')
    first_ep = True

    for epoch in range(init_ep, end_ep):
        epoch_start_time = time.time()
        train_loss = 0
        current_loss = 0
        valid_hm_loss = 0
        valid_wh_loss = 0
        lambda_s = args.lambda_s  # * (1.03**epoch)

        # Training
        model.train()
        optimizer.zero_grad()
        for batch_idx, (data_img, data_hm, data_wh,
                        _) in enumerate(trainset_loader):
            if use_cuda:
                data_img = data_img.cuda()
                data_hm = data_hm.cuda()
                data_wh = data_wh.cuda()
            output = model(data_img)
            hm_loss = crit_hm(output[-1]['hm'], data_hm)
            wh_loss = crit_wh(output[-1]['wh'], data_wh)

            total_loss = hm_loss + lambda_s * wh_loss
            train_loss += (hm_loss.item() + args.lambda_s * wh_loss.item())
            total_loss.backward()
            # Gradient accumulation: step every 16th batch (every batch for
            # the first 10 of epoch 0, and always on the last batch).
            # BUG FIX: original used `(batch_idx % 16) is 0`; identity
            # comparison on an int is implementation-dependent — use `==`.
            if (first_ep and batch_idx < 10) or ((batch_idx % 16) == 0) or (
                    batch_idx == len(trainset_loader) - 1):
                logger.info(
                    'Gradient applied at batch #{}  '.format(batch_idx))
                optimizer.step()
                optimizer.zero_grad()

            print("Epoch: [{:2d}] [{:3d}], hm_loss: {:.3f}, wh_loss: {:.3f}, total_loss: {:.3f}"\
                .format((epoch + 1), (batch_idx + 1), hm_loss.item(), wh_loss.item(), total_loss.item()))

        optim_sched.step()

        # Validation
        model.eval()
        with torch.no_grad():
            for batch_idx, (data_img, data_hm, data_wh,
                            _) in enumerate(validset_loader):
                if use_cuda:
                    data_img = data_img.cuda()
                    data_hm = data_hm.cuda()
                    data_wh = data_wh.cuda()
                output = model(data_img)
                hm_loss = crit_hm(output[-1]['hm'], data_hm)
                wh_loss = crit_wh(output[-1]['wh'], data_wh)

                valid_hm_loss += hm_loss.item()
                valid_wh_loss += wh_loss.item()

        # len(...) is the idiomatic form of .__len__().
        valid_hm_loss = valid_hm_loss / len(validset)
        valid_wh_loss = valid_wh_loss / len(validset)
        train_loss = train_loss / len(trainset)
        current_loss = valid_hm_loss + args.lambda_s * valid_wh_loss

        # Resumed runs get a compound id so they don't overwrite originals.
        save_id = (args.resume_ep + '_' +
                   str(epoch)) if args.resume else str(epoch)
        # Checkpoint on improvement, every 5th epoch, and always 'latest'.
        if epoch == 0 or current_loss < min_loss:
            min_loss = current_loss
            checkpointer.save(save_id)
        elif (epoch % 5) == 4:
            checkpointer.save(save_id)
        checkpointer.save('latest')

        train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
        train_hist['valid_hm_loss'].append(valid_hm_loss)
        train_hist['valid_wh_loss'].append(args.lambda_s * valid_wh_loss)
        train_hist['valid_total_loss'].append(current_loss)
        train_hist['train_loss'].append(train_loss)
        plt.figure()
        plt.plot(train_hist['train_loss'], color='k')
        plt.plot(train_hist['valid_total_loss'], color='r')
        plt.plot(train_hist['valid_hm_loss'], color='b')
        plt.plot(train_hist['valid_wh_loss'], color='g')
        plt.ylabel('Loss')
        plt.xlabel('Epoch')
        plt.savefig('loss_fold{}.png'.format(args.crx_valid))
        plt.close()
        np.save('train_hist_{}.npy'.format(args.exp_name), train_hist)
        logger.info(
            "Epoch: [{:d}], valid_hm_loss: {:.3f}, valid_wh_loss: {:.3f}".
            format((epoch + 1), valid_hm_loss, args.lambda_s * valid_wh_loss))
        logger.info('Epoch exec time: {} min'.format(
            (time.time() - epoch_start_time) / 60))
        first_ep = False

    logger.info("Training finished.")
    logger.info("Total time cost: {} min.".format(
        (time.time() - start_time) / 60))