예제 #1
0
def main():
    """Train a sequence-to-sequence translation model end to end.

    Parses command-line options, loads the parallel corpus, records the
    vocabulary sizes and padding indices on ``args``, builds the model and
    hands it to ``Trainer`` with an Adam optimizer and a cross-entropy loss
    that ignores target padding.
    """
    cuda_ok = torch.cuda.is_available()
    args = parse()

    loader = DataLoader()
    extensions = tuple('.' + lang for lang in (args.src, args.tgt))  # ('.zh', '.en')
    train_iter, valid_iter = loader.load_translation(
        data_path=args.data_path,
        exts=extensions,
        batch_size=args.batch_size,
        dl_save_path=args.dl_path,
    )

    # Vocabulary sizes and padding indices are only known once data is loaded.
    args.n_src_words = len(loader.SRC.vocab)
    args.n_tgt_words = len(loader.TGT.vocab)
    args.src_pdx = loader.src_padding_index
    args.tgt_pdx = loader.tgt_padding_index
    print(args)

    model = build_model(args, cuda_ok=cuda_ok)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=1e-3,
                                 betas=(0.9, 0.98),
                                 eps=1e-9)
    # Padding positions in the target must not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=args.tgt_pdx,
                                    reduction='mean')

    trainer = Trainer(args,
                      model=model,
                      optimizer=optimizer,
                      criterion=criterion,
                      cuda_ok=cuda_ok)
    trainer.train(train_iter,
                  valid_iter,
                  n_epochs=args.n_epochs,
                  save_path=args.ckpt_path)
def main(**kwargs):
    """Configure options, load the data and run the ``Resnet2D`` model.

    Keyword arguments are forwarded to ``opt._parse`` so callers can
    override options before anything else runs.
    """
    opt._parse(kwargs)

    train_loader, test_loader, test_video = DLoader(opt).run()

    Resnet2D(opt, train_loader, test_loader, test_video).run()
예제 #3
0
def main():
    """Evaluate the last saved checkpoint on the files in ``file_list.txt``.

    Uses module-level ``args``, ``ckpts`` and ``out_dir``. It:

    - restricts the visible GPUs;
    - rebuilds the network named by ``args.net``;
    - restores weights from ``<ckpts>/model_last.tar``;
    - runs ``validate`` with gradients disabled;
    - logs the elapsed wall time in minutes.
    """
    # Set before any CUDA work so the device mask takes effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    net_cls = getattr(models, args.net)
    model = net_cls(**args.net_params).cuda()
    criterion = F.cross_entropy

    state = torch.load(os.path.join(ckpts, 'model_last.tar'))
    model.load_state_dict(state['state_dict'])

    dataset_cls = getattr(datasets, args.dataset)
    valid_set = dataset_cls(os.path.join(args.data_dir, 'file_list.txt'),
                            root=args.data_dir,
                            for_train=False)
    valid_loader = DataLoader(valid_set,
                              batch_size=1,
                              shuffle=False,
                              collate_fn=valid_set.collate,
                              num_workers=2,
                              pin_memory=True,
                              prefetch=False)

    began = time.time()
    model.eval()
    with torch.no_grad():
        validate(valid_loader, model, criterion, out_dir, valid_set.names)

    elapsed_minutes = (time.time() - began) / 60
    logging.info('total time {:.4f} minutes'.format(elapsed_minutes))
    def __init__(self):
        """Prepare evaluation loaders and the model from the parsed test options.

        Reads options via ``TestOptions().parse()`` and requires
        ``checkpoint_path`` to be non-empty and to exist on disk. It then:

        - builds an ordered validation loader that keeps the last partial batch;
        - builds a test loader;
        - takes the class count from the test dataset;
        - creates the model through ``ModelFactory``;
        - wraps the model with ``make_parallel`` when more than one GPU is
          requested.
        """
        self._opt = TestOptions().parse()
        # NOTE(review): ``assert`` is stripped under ``python -O``; kept so the
        # raised exception type (AssertionError) stays the same for callers.
        assert self._opt.checkpoint_path and os.path.exists(self._opt.checkpoint_path), \
            "checkpoint_path does not exist."
        factory = DatasetFactory()
        val_dataset = factory.get_by_name(self._opt, 'Validation')
        self.val_loader = DataLoader(val_dataset,
                                     batch_size=self._opt.batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=self._opt.n_threads_test)
        test_dataset = factory.get_by_name(self._opt, 'Test')
        # NOTE(review): the test loader shuffles and drops the last partial
        # batch, so not every test sample is visited -- confirm this is intended.
        self.test_loader = DataLoader(test_dataset,
                                      batch_size=self._opt.batch_size,
                                      shuffle=True,
                                      drop_last=True,
                                      num_workers=self._opt.n_threads_test)
        self.num_classes = test_dataset.n_classes
        self.model = ModelFactory().get_by_name(self._opt.model_name, self._opt, self.num_classes)

        if self._opt.num_gpus > 1:
            # The GPU count lives on the options object; ``self.num_gpus`` is
            # never assigned on this instance and raised AttributeError.
            self.model = make_parallel(self.model, self._opt.num_gpus)
예제 #5
0
def get_dataloader(cfg: object, mode: str) -> tuple:
    """Get dataloader function

    Loads the dataset and sampler for ``mode``, then wraps each split in a
    ``DataLoader``.

    Args:
        cfg: Config.
        mode: Mode.
            trainval: For training and validation.
            test: For test.

    Returns:
        ``(train_dataloader, val_dataloader)`` when ``mode == "trainval"``;
        the single test dataloader (not wrapped in a tuple) when
        ``mode == "test"``.

    Raises:
        ValueError: If ``mode`` is neither ``"trainval"`` nor ``"test"``.
    """
    # Validate before doing any (potentially expensive) dataset loading; an
    # unknown mode used to end in UnboundLocalError on ``dataloaders``.
    if mode not in ("trainval", "test"):
        raise ValueError(f"Unknown mode {mode!r}; expected 'trainval' or 'test'.")

    log.info(f"Loading {cfg.data.dataset.name} dataset...")

    dataset = get_dataset(cfg, mode)
    sampler = get_sampler(cfg, mode, dataset)

    if mode == "trainval":
        train_dataloader = DataLoader(cfg,
                                      dataset=dataset.train,
                                      sampler=sampler.train)
        val_dataloader = DataLoader(cfg,
                                    dataset=dataset.val,
                                    sampler=sampler.val)
        dataloaders = (train_dataloader, val_dataloader)
    else:
        # Returned bare, as before: the former ``(test_dataloader)`` was just a
        # parenthesized expression, not a one-element tuple.
        dataloaders = DataLoader(cfg,
                                 dataset=dataset.test,
                                 sampler=sampler.test)

    log.info(f"Successfully loaded {cfg.data.dataset.name} dataset.")

    return dataloaders
예제 #6
0
def get_dataloader(cfg: object, mode: str) -> tuple:
    """Get dataloader function

    For each split of ``mode``, loads the dataset, builds its sampler and
    wraps both in a ``DataLoader``.

    Args:
        cfg: Config.
        mode: Mode.
            trainval: For training and validation.
            test: For test.

    Returns:
        ``(train_dataloader, val_dataloader)`` for ``"trainval"``; the single
        test dataloader (not wrapped in a tuple) for ``"test"``.

    Raises:
        ValueError: If ``mode`` is neither ``"trainval"`` nor ``"test"``.
    """
    # Fail loudly: an unknown mode used to fall through and return None.
    if mode not in ("trainval", "test"):
        raise ValueError(f"Unknown mode {mode!r}; expected 'trainval' or 'test'.")

    log.info(f"Loading {cfg.data.dataset.name} dataset...")

    if mode == "trainval":
        train_dataset, val_dataset = get_dataset(cfg, mode="trainval")
        train_sampler = get_sampler(cfg, mode="train", dataset=train_dataset)
        val_sampler = get_sampler(cfg, mode="val", dataset=val_dataset)
        train_dataloader = DataLoader(cfg, dataset=train_dataset, sampler=train_sampler)
        val_dataloader = DataLoader(cfg, dataset=val_dataset, sampler=val_sampler)
        dataloaders = (train_dataloader, val_dataloader)
    else:
        test_dataset = get_dataset(cfg, mode="test")
        test_sampler = get_sampler(cfg, mode="test", dataset=test_dataset)
        dataloaders = DataLoader(cfg, dataset=test_dataset, sampler=test_sampler)

    log.info(f"Successfully loaded {cfg.data.dataset.name} dataset.")

    return dataloaders
    def __init__(self):
        """Prepare training/validation loaders, the model and its optimizer.

        Reads options via ``TrainOptions().parse()``. It then:

        - builds a shuffled training loader that drops the last partial batch;
        - builds an ordered validation loader;
        - takes the class count from the training dataset;
        - creates the model through ``ModelFactory``;
        - optionally applies regularization and multi-GPU wrapping;
        - creates a Keras optimizer.

        Raises:
            ValueError: if the ``optimizer`` option is neither ``'Adam'`` nor
                ``'SGD'``.
        """
        self._opt = TrainOptions().parse()
        factory = DatasetFactory()
        train_dataset = factory.get_by_name(self._opt, 'Train')
        self.train_loader = DataLoader(train_dataset,
                                       batch_size=self._opt.batch_size,
                                       shuffle=True,
                                       drop_last=True,
                                       num_workers=self._opt.n_threads_train)
        val_dataset = factory.get_by_name(self._opt, 'Validation')
        self.val_loader = DataLoader(val_dataset,
                                     batch_size=self._opt.batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=self._opt.n_threads_test)
        self.num_classes = train_dataset.n_classes
        self.model = ModelFactory().get_by_name(self._opt.model_name,
                                                self._opt, self.num_classes)

        if self._opt.reg_func is not None:
            self.model = set_model_regularization(self.model,
                                                  self._opt.reg_func,
                                                  self._opt.reg_layers,
                                                  self._opt.reg_bias)
        if self._opt.num_gpus > 1:
            # The GPU count lives on the options object; ``self.num_gpus`` is
            # never assigned on this instance and raised AttributeError.
            self.model = make_parallel(self.model, self._opt.num_gpus)
        if self._opt.optimizer == 'Adam':
            # ``lr=`` is the Keras 2.2.0 keyword; Keras 2.3+ spells it
            # ``learning_rate=``.
            self.optimizer = optimizers.Adam(
                lr=self._opt.init_lr,
                beta_1=self._opt.adam_b1,
                beta_2=self._opt.adam_b2)
        elif self._opt.optimizer == 'SGD':
            self.optimizer = optimizers.SGD(lr=self._opt.init_lr,
                                            momentum=self._opt.momentum,
                                            nesterov=True)
        else:
            # Otherwise ``self.optimizer`` would silently stay unset and fail later.
            raise ValueError(
                "Unsupported optimizer: {!r}".format(self._opt.optimizer))

        self.callback_list = []
예제 #8
0
    def setup(self, bottom, top):
        """Configure this Caffe Python data layer from ``self.param_str``.

        ``param_str`` must evaluate to a dict with:
            img_root: prefix for the split lists. The list for a phase is read
                from ``'{img_root}txtfiles/{phase}.txt'``. No separator is
                inserted, so ``img_root`` should end with one.
            phase: ``'train'``, ``'val'`` or ``'test'``.
            batch_size: samples per batch.
            randomize: whether to shuffle. Optional, default ``True``; forced
                to ``False`` for non-train phases.
            crop_size: crop shape. Optional, default ``[48, 48, 48]``.

        Raises:
            Exception: on an unknown phase, when the layer does not have exactly
                two tops (data and label), or when any bottom is defined.
        """
        # NOTE(review): ``eval`` runs arbitrary code taken from the prototxt.
        # If param_str is always a plain dict literal, ``ast.literal_eval`` is a
        # safer drop-in -- confirm before switching.
        params = eval(self.param_str)
        self.img_root = params['img_root']
        self.phase = params['phase']
        self.random = params.get('randomize', True)
        self.crop_size = params.get('crop_size', [48, 48, 48])
        self.batch_size = params['batch_size']
        if self.phase not in ['train', 'test', 'val']:
            raise Exception("phase must be train  or  val  or  test")
        # two tops: data and label
        if len(top) != 2:
            raise Exception("Need to define two tops: data and label.")
        # data layers have no bottoms
        if len(bottom) != 0:
            raise Exception("Do not define a bottom.")
        # Make evaluation deterministic. This is decided before the DataLoader
        # is built so its ``shuffle`` flag can honour it; previously the loader
        # always shuffled, even for val/test.
        if 'train' not in self.phase:
            self.random = False
        self.data = {}
        top[0].reshape(self.batch_size, 1, *self.crop_size)
        top[1].reshape(self.batch_size, 1, *self.crop_size)
        # Load the sample list for this phase; ``with`` closes the file handle.
        split_f = '{}txtfiles/{}.txt'.format(self.img_root, self.phase)
        with open(split_f, 'r') as split_file:
            self.indices = split_file.read().splitlines()
        # NOTE(review): ``augment=True`` is passed for every phase, so val/test
        # samples may still be randomly augmented -- confirm this is intended.
        self.dataset = Dataset(self.indices,
                               augment=True,
                               crop_size=self.crop_size,
                               randn=7)
        self.dataloader = iter(
            DataLoader(self.dataset,
                       batch_size=self.batch_size,
                       shuffle=self.random,
                       num_workers=1,
                       drop_last=True))
        self.idx = 0
        self.load_image = load_ct
        # randomization: also shuffle the in-memory sample list
        if self.random:
            random.shuffle(self.indices)
예제 #9
0
def _save_checkpoint(file_name, iteration, model, optimizer):
    """Write the iteration counter, model weights and optimizer state to ``file_name``."""
    torch.save(
        {
            'iter': iteration,
            'state_dict': model.state_dict(),
            'optim_dict': optimizer.state_dict(),
        }, file_name)


def main():
    """Train ``args.net`` on ``args.dataset`` with an iteration-driven sampler.

    Relies on module-level ``args`` (parsed options) and ``ckpts``
    (checkpoint directory). It:

    - seeds torch, CUDA, ``random`` and NumPy;
    - builds the network, optimizer and criterion by name;
    - optionally resumes from ``args.resume``;
    - trains until the loader is exhausted.

    A checkpoint is written every ``args.save_freq`` epochs (converted to
    iterations) and once more as ``model_last.tar`` at the end.
    """
    # setup environments and seeds
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # setup networks
    Network = getattr(models, args.net)
    model = Network(**args.net_params)
    model = model.cuda()

    optimizer = getattr(torch.optim, args.opt)(model.parameters(),
                                               **args.opt_params)
    criterion = getattr(criterions, args.criterion)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_iter = checkpoint['iter']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optim_dict'])
            # Checkpoints written by this function store 'iter' but no 'epoch'
            # key, so report the iteration (reading 'epoch' raised KeyError).
            print("=> loaded checkpoint '{}' (iter {})".format(
                args.resume, checkpoint['iter']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Data loading code
    Dataset = getattr(datasets, args.dataset)

    # The loader will get 1000 patches from 50 subjects for each sub epoch
    # each subject sample 20 patches
    train_list = os.path.join(args.data_dir, 'file_list.txt')
    train_set = Dataset(train_list,
                        root=args.data_dir,
                        for_train=True,
                        sample_size=args.patch_per_sample)

    # NOTE(review): ``num_iters`` defaults to a sample count while
    # ``start_iter`` counts loader iterations (batches); confirm the units
    # SSampler expects.
    num_iters = args.num_iters or len(train_set) * args.num_epochs
    num_iters -= args.start_iter

    train_sampler = SSampler(len(train_set), num_iters)

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              collate_fn=train_set.collate,
                              sampler=train_sampler,
                              num_workers=args.workers,
                              pin_memory=True,
                              prefetch=False)

    logging.info('-------------- New training session ----------------')

    start = time.time()

    enum_batches = len(train_set) / args.batch_size

    # Schedule keys and save_freq are given in epochs; convert to iterations.
    args.schedule = {
        int(k * enum_batches): v
        for k, v in args.schedule.items()
    }
    # Clamp to at least 1: a save interval shorter than one batch would
    # truncate to 0 and make ``% args.save_freq`` raise ZeroDivisionError.
    args.save_freq = max(1, int(enum_batches * args.save_freq))

    losses = AverageMeter()
    torch.set_grad_enabled(True)

    # Bound up front so the final checkpoint is valid even when the loader
    # yields no batches.
    i = args.start_iter - 1
    for i, (data, _label) in enumerate(train_loader, args.start_iter):

        adjust_learning_rate(optimizer, i)

        # ``data`` unpacks into three tensors (x1, x2, target); each is split
        # into chunks of ``args.batch_size`` and optimized chunk by chunk.
        for x1, x2, target in zip(*[d.split(args.batch_size) for d in data]):

            x1, x2, target = [
                t.cuda(non_blocking=True) for t in (x1, x2, target)
            ]

            # compute output
            output = model((x1, x2))  # nx5x9x9x9, target nx9x9x9
            loss = criterion(output, target)

            # measure accuracy and record loss
            losses.update(loss.item(), target.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if (i + 1) % args.save_freq == 0:
            file_name = os.path.join(ckpts, 'model_iter_{}.tar'.format(i + 1))
            _save_checkpoint(file_name, i + 1, model, optimizer)

        msg = 'Iter {0:}, Epoch {1:.4f}, Loss {2:.4f}'.format(
            i + 1, (i + 1) / enum_batches, losses.avg)
        logging.info(msg)

        losses.reset()

    file_name = os.path.join(ckpts, 'model_last.tar')
    _save_checkpoint(file_name, i + 1, model, optimizer)

    msg = 'total time: {} minutes'.format((time.time() - start) / 60)
    logging.info(msg)
예제 #10
0
                                            Config.image_channels))
    sample_outputs = network(sample_inputs, training=True)
    network.summary()


if __name__ == '__main__':
    # GPU settings
    gpus = tf.config.list_physical_devices("GPU")
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)

    # dataset
    train_dataset = DetectionDataset()
    train_data, train_size = train_dataset.generate_datatset()
    data_loader = DataLoader()
    steps_per_epoch = tf.math.ceil(train_size / Config.batch_size)

    # model
    centernet = CenterNet()
    print_model_summary(centernet)
    load_weights_from_epoch = Config.load_weights_from_epoch
    if Config.load_weights_before_training:
        centernet.load_weights(filepath=Config.save_model_dir +
                               "epoch-{}".format(load_weights_from_epoch))
        print("Successfully load weights!")
    else:
        load_weights_from_epoch = -1

    # optimizer
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
예제 #11
0
        im2 = np.fliplr(im2)

    angle = random.choice([0, 1, 2, 3])
    im1 = np.rot90(im1, angle)
    im2 = np.rot90(im2, angle)

    return im1.copy(), im2.copy()


if __name__ == '__main__':
    # Smoke test: build the train/test loaders and print the HR batch shapes.
    from config.config_lol_my_data_ps_pair_3 import cfg
    train_set = TrainDataset(None, None, cfg=cfg)
    print(len(train_set))
    train_loader = DataLoader(train_set,
                              batch_size=64,
                              num_workers=16,
                              shuffle=False,
                              drop_last=True,
                              timeout=0)

    valid_hr_file = '/workspace/nas_mengdongwei/dataset/div2k/div2k_valid_hr_paths.txt'
    valid_lr_file = '/workspace/nas_mengdongwei/dataset/div2k/div2k_valid_lr_paths.txt'
    # The LR path list used to be defined but never used (the HR list was
    # passed twice). Assumes TestDataset takes (hr_file, lr_file) in that
    # order, as the variable order suggests -- TODO confirm its signature.
    test_set = TestDataset(valid_hr_file, valid_lr_file)
    test_loader = DataLoader(test_set,
                             batch_size=1,
                             num_workers=16,
                             shuffle=False,
                             drop_last=True,
                             timeout=0)
    for sample in test_loader:
        lr, hr, cubic = sample
        print(hr.shape)
예제 #12
0
def train():
    """Train CenterNet with per-epoch validation, keeping only the best weights.

    Steps:

    - enable GPU memory growth;
    - read the train/val splits via ``DetectionDataset``;
    - delete ``Config.log_dir`` so summaries start fresh;
    - try to restore weights from ``Config.save_model_path``;
    - train for ``Config.epochs`` epochs with an exponentially decaying Adam
      learning rate.

    Weights are saved whenever the mean validation loss of an epoch improves
    on the best seen so far.
    """
    # Enable on-demand GPU memory allocation.
    gpus = tf.config.experimental.list_physical_devices("GPU")
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    # Read the data. Per the original note, generate_datatset returns the
    # batched string records read from the txt file plus the sample count.
    dataset = DetectionDataset()
    train_data, train_size = dataset.generate_datatset(mode="train")
    train_steps_per_epoch = tf.math.ceil(train_size / Config.batch_size)
    # Validation split.
    val_data, val_size = dataset.generate_datatset(mode="val")
    val_steps_per_epoch = tf.math.ceil(val_size / Config.batch_size)
    data_loader = DataLoader()  # decodes a batch of records into images and labels

    if os.path.exists(Config.log_dir):
        # Remove summaries left in the log directory by previous runs.
        shutil.rmtree(Config.log_dir)

    # Create the directory that will hold the saved model. ``makedirs`` also
    # creates missing parent directories, which ``os.mkdir`` could not.
    save_dir = os.path.dirname(Config.save_model_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    print(
        'Total on {}, train on {} samples, val on {} samples with batch size {}.'
        .format((train_size + val_size), train_size, val_size,
                Config.batch_size))
    # optimizer: the learning rate decays by a factor of 0.96 every
    # ``learning_rate_decay_epochs`` epochs' worth of steps.
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=1e-4,
        decay_steps=train_steps_per_epoch * Config.learning_rate_decay_epochs,
        decay_rate=0.96)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # Build the model.
    centernet = CenterNet()
    print_model_summary(centernet)
    try:
        centernet.load_weights(filepath=Config.save_model_path)
        print("Load weights...")
    except Exception as err:  # no usable checkpoint: start from scratch
        # Report the failure instead of printing a success-looking message.
        print("No weights loaded from {} ({}); training from scratch.".format(
            Config.save_model_path, err))

    # Metrics: running mean of the per-step losses.
    train_loss = tf.metrics.Mean(name='train_loss')
    valid_loss = tf.metrics.Mean(name='valid_loss')

    post_process = PostProcessing()
    # Best validation loss so far; only improving epochs are saved.
    best_test_loss = float('inf')

    # TensorBoard summary writer.
    summary_writer = tf.summary.create_file_writer(logdir=Config.log_dir)

    # Training.
    for epoch in range(1, Config.epochs + 1):
        train_loss.reset_states()
        valid_loss.reset_states()

        # Training split.
        for step, batch_data in enumerate(train_data):
            step_start_time = time.time()
            # images, and labels laid out as
            # [batch, max_boxes_per_image, xmin, ymin, xmax, ymax, class_id]
            images, labels = data_loader.read_batch_data(batch_data)
            with tf.GradientTape() as tape:
                # Forward pass in training mode.
                pred = centernet(images, training=True)
                # Compute the loss.
                loss_value = post_process.training_procedure(
                    batch_labels=labels, pred=pred)

            # Back-propagate the loss to every trainable variable and apply
            # the (gradient, variable) pairs with the optimizer.
            gradients = tape.gradient(target=loss_value,
                                      sources=centernet.trainable_variables)
            optimizer.apply_gradients(
                grads_and_vars=zip(gradients, centernet.trainable_variables))

            train_loss.update_state(values=loss_value)

            step_end_time = time.time()
            print("Epoch: {}/{}, step: {}/{}, loss: {}, time_cost: {:.3f}s".
                  format(epoch, Config.epochs, step, train_steps_per_epoch,
                         train_loss.result(), step_end_time - step_start_time))

            with summary_writer.as_default():
                tf.summary.scalar(
                    "steps_perbatch_train_loss",
                    train_loss.result(),
                    step=tf.cast(((epoch - 1) * train_steps_per_epoch + step),
                                 tf.int64))

        # Validation split.
        for step, batch_data in enumerate(val_data):
            step_start_time = time.time()
            # images, and labels laid out as
            # [batch, max_boxes_per_image, xmin, ymin, xmax, ymax, class_id]
            images, labels = data_loader.read_batch_data(batch_data)
            # Forward pass in inference mode (explicit, not left to Keras).
            pred = centernet(images, training=False)
            loss_value = post_process.training_procedure(batch_labels=labels,
                                                         pred=pred)

            valid_loss.update_state(loss_value)
            step_end_time = time.time()
            print(
                "--------Epoch: {}/{}, step: {}/{}, loss: {}, time_cost: {:.3f}s"
                .format(epoch, Config.epochs, step, val_steps_per_epoch,
                        valid_loss.result(), step_end_time - step_start_time))
            with summary_writer.as_default():
                tf.summary.scalar("steps_perbatch_val_loss",
                                  valid_loss.result(),
                                  step=tf.cast(
                                      (epoch - 1) * val_steps_per_epoch + step,
                                      tf.int64))

        # Per-epoch summaries for TensorBoard.
        with summary_writer.as_default():
            tf.summary.scalar("train_loss",
                              train_loss.result(),
                              step=optimizer.iterations)
            tf.summary.scalar('valid_loss',
                              valid_loss.result(),
                              step=optimizer.iterations)

        # Save only when the validation loss improves.
        if valid_loss.result() < best_test_loss:
            best_test_loss = valid_loss.result()
            centernet.save_weights(Config.save_model_path, save_format="tf")
            print("Update model's weights")
예제 #13
0
def main():
    """Train the MLP_RNN classifier with optional validation and early stopping.

    Relies on module-level ``args``. It:

    - derives the experiment name and save root;
    - builds the model, the BCE-with-logits loss (positive weight 25 for all
      53 classes) and the optimizer;
    - trains for ``args.epochs`` epochs.

    Without a validation loader, every epoch is checkpointed. Otherwise it
    validates every ``args.eval_freq`` epochs and on the last epoch, and stops
    once accuracy has not improved for ``args.early_stop`` evaluations.
    Progress is logged to ``<save_root>/<root_log>/<store_name>.txt``.

    Raises:
        ValueError: if ``args.optim`` is neither ``'SGD'`` nor ``'Adam'``.
    """
    USER_DATA_ROOT = '../user_data'
    if not args.store_name:
        args.store_name = '_'.join([
            'optim:{}'.format(args.optim),
            'batch_size:{}'.format(args.batch_size),
            'hidden_units:{}'.format(args.hidden_units)
        ])
    setattr(args, 'save_root', os.path.join(USER_DATA_ROOT, args.store_name))
    print("save experiment to :{}".format(args.save_root))
    check_rootfolders()
    num_class = 53
    setattr(args, 'num_class', num_class)
    # here we approximately set the pos weight as 25
    pos_weight = torch.ones([num_class]) * 25
    ###########################  Create the classifier ###################
    model = MLP_RNN(args.hidden_units, args.num_class)
    pytorch_total_params = sum(p.numel() for p in model.parameters()
                               if p.requires_grad)
    print("Total Params: {}".format(pytorch_total_params))
    if args.gpus is not None:
        if len(args.gpus) != 1:
            model = nn.DataParallel(model)
        model.cuda()
        pos_weight = pos_weight.cuda()
    # Build the loss only after ``pos_weight`` is on its final device:
    # ``Tensor.cuda()`` returns a new tensor, so a loss created earlier kept
    # the CPU copy and mismatched the GPU logits.
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    if args.optim == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.optim == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     weight_decay=args.weight_decay)
    else:
        raise ValueError("Unsupported optimizer: {!r}".format(args.optim))
    train_loader, val_loader = DataLoader(
        args.batch_size,
        lmdb_file=args.lmdb,
        val_ratio=args.val_ratio,
        train_num_workers=args.workers,
        val_num_workers=args.workers).create_dataloaders()
    log_path = os.path.join(args.save_root, args.root_log,
                            '{}.txt'.format(args.store_name))
    best_loss = float('inf')
    val_accum_epochs = 0
    best_acc = 0
    # ``with`` guarantees the log file is flushed and closed, including on
    # early stop or an exception during training.
    with open(log_path, 'w') as log:
        for epoch in range(args.epochs):
            adjust_learning_rate(optimizer, epoch, args.lr_steps)
            train(train_loader, model, criterion, optimizer, epoch, log)
            torch.cuda.empty_cache()
            if val_loader is None:
                # save every epoch model
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                    },
                    True,
                    False,
                    filename='MLP_RNN_{}'.format(epoch))
            elif (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
                loss_val, acc_val = validate(val_loader, model, criterion,
                                             (epoch + 1) * len(train_loader), log)
                is_best_loss = loss_val < best_loss
                best_loss = min(loss_val, best_loss)
                is_best_acc = acc_val > best_acc
                best_acc = max(acc_val, best_acc)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                    },
                    is_best_loss,
                    is_best_acc,
                    filename='MLP_RNN')
                # Count consecutive evaluations without an accuracy improvement.
                if not is_best_acc:
                    val_accum_epochs += 1
                else:
                    val_accum_epochs = 0
                if val_accum_epochs >= args.early_stop:
                    print("validation acc did not improve over {} epochs, stop".
                          format(args.early_stop))
                    break
예제 #14
0
import argparse
import os
import numpy as np
import cv2
import tensorflow as tf
from configuration import Config
from data.dataloader import DetectionDataset, DataLoader


def print_model_summary(network):
    """Build ``network`` with one random forward pass and print its summary.

    The dummy input is standard-normal noise of shape
    ``(Config.batch_size, size[0], size[1], Config.image_channels)``, where
    ``size`` is ``Config.get_image_size()``. The forward pass runs with
    ``training=True``.
    """
    dim0 = Config.get_image_size()[0]
    dim1 = Config.get_image_size()[1]
    dummy_shape = (Config.batch_size, dim0, dim1, Config.image_channels)
    network(tf.random.normal(shape=dummy_shape), training=True)
    network.summary()


if __name__ == "__main__":
    gpu = tf.config.list_physical_devices("GPU: 0")

    #get MOT dataset
    train_dataset = DetectionDataset()
    train_data, train_size = train_dataset.generate_datatset()

    data_loader = DataLoader()
    steps_per_epoch = tf.math.ceil(train_size / Config.batch_size)

    centernet = CenterNet()
    print_model_summary(centernet)