Example #1
    def __init__(self):
        self.cfg = cfg

        # Load data
        print('===> Loading data')
        self.train_loader = load_data(
            cfg.DATASET, 'train') if 'train' in cfg.PHASE else None
        self.eval_loader = load_data(cfg.DATASET,
                                     'eval') if 'eval' in cfg.PHASE else None
        self.test_loader = load_data(cfg.DATASET,
                                     'test') if 'test' in cfg.PHASE else None
        self.visualize_loader = load_data(
            cfg.DATASET, 'visualize') if 'visualize' in cfg.PHASE else None

        # Build model
        print('===> Building model')
        self.model, self.priorbox = create_model(cfg.MODEL)
        with torch.no_grad():
            self.priors = torch.Tensor(self.priorbox.forward())
        self.detector = Detect(cfg.POST_PROCESS, self.priors)

        # Utilize GPUs for computation
        self.use_gpu = torch.cuda.is_available()
        self.device = torch.device("cuda:0" if self.use_gpu else "cpu")
        if self.use_gpu:
            print('Utilize GPUs for computation')
            print('Number of GPUs available:', torch.cuda.device_count())
            # self.model.cuda()
            # self.priors.cuda()
            self.model = self.model.to(self.device)
            self.priors = self.priors.to(self.device)
            cudnn.benchmark = True
            # if torch.cuda.device_count() > 1:
            # self.model = torch.nn.DataParallel(self.model).module

        # Print the model architecture and parameters
        print('Model architectures:\n{}\n'.format(self.model))

        # print('Parameters and size:')
        # for name, param in self.model.named_parameters():
        #     print('{}: {}'.format(name, list(param.size())))

        # print trainable scope
        print('Trainable scope: {}'.format(cfg.TRAIN.TRAINABLE_SCOPE))
        trainable_param = self.trainable_param(cfg.TRAIN.TRAINABLE_SCOPE)
        self.optimizer = self.configure_optimizer(trainable_param,
                                                  cfg.TRAIN.OPTIMIZER)
        self.exp_lr_scheduler = self.configure_lr_scheduler(
            self.optimizer, cfg.TRAIN.LR_SCHEDULER)
        self.max_epochs = cfg.TRAIN.MAX_EPOCHS

        # metric
        self.criterion = MultiBoxLoss(cfg.MATCHER, self.priors, self.use_gpu)

        # Set the logger
        self.writer = SummaryWriter(log_dir=cfg.LOG_DIR)
        self.output_dir = cfg.EXP_DIR
        self.checkpoint = cfg.RESUME_CHECKPOINT
        self.checkpoint_prefix = cfg.CHECKPOINTS_PREFIX
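
This constructor is the device-aware variant: it creates a single torch.device and routes the model and priors through `.to()`. A minimal, self-contained sketch of that pattern, with an illustrative stand-in `Net` for the detector that create_model() returns:

import torch
import torch.nn as nn

class Net(nn.Module):  # stand-in for the detector returned by create_model()
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Net().to(device)                 # nn.Module.to() moves parameters in place
priors = torch.rand(8732, 4).to(device)  # Tensor.to() returns a copy: rebind it
print(next(model.parameters()).device, priors.device)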
Example #2
    def __init__(self):
        self.cfg = cfg

        # Load data
        logger.info('Loading data')
        self.loaders = {}
        if isinstance(cfg.PHASE, str):  # Python 3: no separate unicode type
            cfg.PHASE = [cfg.PHASE]

        for phase in cfg.PHASE:
            self.loaders[phase] = load_data(cfg.DATASET, phase)

        # Build model
        logger.info("Building model...")
        self.model, self.priorbox = create_model(cfg.MODEL)
        with torch.no_grad():
            self.priors = self.priorbox.forward()
        self.detector = Detect(cfg.POST_PROCESS, self.priors)

        # Utilize GPUs for computation
        self.use_gpu = torch.cuda.is_available()
        if self.use_gpu:
            logger.info('Utilize GPUs for computation')
            logger.info('Number of GPUs available: {}'.format(
                torch.cuda.device_count()))
            self.model.cuda()
            self.priors = self.priors.cuda()  # Tensor.cuda() returns a copy
            cudnn.benchmark = True
            # if torch.cuda.device_count() > 1:
            # self.model = torch.nn.DataParallel(self.model).module

        # Print the model architecture and parameters
        logger.info('Model architectures:\n{}\n'.format(self.model))

        logger.debug('Parameters and size:')
        for name, param in self.model.named_parameters():
            logger.debug('{}: {}'.format(name, list(param.size())))

        # print trainable scope
        logger.debug('Trainable scope: {}'.format(cfg.TRAIN.TRAINABLE_SCOPE))
        trainable_param = self.trainable_param(cfg.TRAIN.TRAINABLE_SCOPE)
        self.optimizer = self.configure_optimizer(trainable_param,
                                                  cfg.TRAIN.OPTIMIZER)
        self.exp_lr_scheduler = self.configure_lr_scheduler(
            self.optimizer, cfg.TRAIN.LR_SCHEDULER)
        self.max_epochs = cfg.TRAIN.MAX_EPOCHS

        # metric
        loss_func = FocalLoss if cfg.MATCHER.USE_FOCAL_LOSS else MultiBoxLoss
        self.criterion = loss_func(cfg.MATCHER, self.priors, self.use_gpu)

        # Set the logger
        self.writer = SummaryWriter(log_dir=cfg.LOG_DIR)
        self.output_dir = cfg.EXP_DIR
        self.checkpoint = cfg.RESUME_CHECKPOINT
        self.checkpoints_kept = cfg.TRAIN.CHECKPOINTS_KEPT
        self.checkpoint_prefix = cfg.CHECKPOINTS_PREFIX
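
This variant keeps one loader per phase in a dict after normalizing cfg.PHASE to a list. A small runnable sketch of the same normalization, with `load_data` stubbed out for illustration:

def load_data(dataset_cfg, phase):
    return '<loader for %s>' % phase  # stand-in for the real loader factory

def build_loaders(phase_cfg, dataset_cfg):
    # Accept either a single phase string or an iterable of phases.
    phases = [phase_cfg] if isinstance(phase_cfg, str) else list(phase_cfg)
    return {phase: load_data(dataset_cfg, phase) for phase in phases}

print(build_loaders('train', None))            # {'train': '<loader for train>'}
print(build_loaders(['train', 'eval'], None))  # both phases loaded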
Example #3
    def __init__(self):
        self.cfg = cfg

        # Load data
        print('===> Loading data')
        self.train_loader = load_data(cfg.DATASET, 'train') if 'train' in cfg.PHASE else None
        self.eval_loader = load_data(cfg.DATASET, 'eval') if 'eval' in cfg.PHASE else None
        self.test_loader = load_data(cfg.DATASET, 'test') if 'test' in cfg.PHASE else None
        self.visualize_loader = load_data(cfg.DATASET, 'visualize') if 'visualize' in cfg.PHASE else None

        if self.train_loader and hasattr(self.train_loader.dataset, "num_classes"):
            cfg.POST_PROCESS.NUM_CLASSES = cfg.MATCHER.NUM_CLASSES = cfg.MODEL.NUM_CLASSES = self.train_loader.dataset.num_classes
        elif self.eval_loader and hasattr(self.eval_loader.dataset, "num_classes"):
            cfg.POST_PROCESS.NUM_CLASSES = cfg.MATCHER.NUM_CLASSES = cfg.MODEL.NUM_CLASSES = self.eval_loader.dataset.num_classes
        elif self.test_loader and hasattr(self.test_loader.dataset, "num_classes"):
            cfg.POST_PROCESS.NUM_CLASSES = cfg.MATCHER.NUM_CLASSES = cfg.MODEL.NUM_CLASSES = self.test_loader.dataset.num_classes
        elif self.visualize_loader and hasattr(self.visualize_loader.dataset, "num_classes"):
            cfg.POST_PROCESS.NUM_CLASSES = cfg.MATCHER.NUM_CLASSES = cfg.MODEL.NUM_CLASSES = self.visualize_loader.dataset.num_classes

        # Build model
        print('===> Building model, num_classes is ' + str(cfg.MODEL.NUM_CLASSES))

        self.model, self.priorbox = create_model(cfg.MODEL, cfg.LOSS.CONF_DISTR)
        with torch.no_grad():
            self.priors = self.priorbox.forward()
        self.detector = Detect(cfg.POST_PROCESS, self.priors)

        # Utilize GPUs for computation
        self.use_gpu = torch.cuda.is_available()
        if self.use_gpu:
            print('Utilize GPUs for computation')
            print('Number of GPUs available:', torch.cuda.device_count())
            self.model.cuda()
            self.priors = self.priors.cuda()  # Tensor.cuda() returns a copy
            if torch.cuda.device_count() > 1:
                print('-----DataParallel-----------')
                self.model = torch.nn.DataParallel(self.model)
                self.model.cuda()
                #self.dp_model = torch.nn.DataParallel(self.model)
                #self.model = torch.nn.DataParallel(self.model).module
                #self.model = self.dp_model.module

            cudnn.benchmark = True

        # Print the model architecture and parameters
        #print('Model architectures:\n{}\n'.format(self.model))

        # print('Parameters and size:')
        # for name, param in self.model.named_parameters():
        #     print('{}: {}'.format(name, list(param.size())))

        # print trainable scope
        print('Trainable scope: {}'.format(cfg.TRAIN.TRAINABLE_SCOPE))
        trainable_param = self.trainable_param(cfg.TRAIN.TRAINABLE_SCOPE)
        self.optimizer = self.configure_optimizer(trainable_param, cfg.TRAIN.OPTIMIZER)
        self.exp_lr_scheduler = self.configure_lr_scheduler(self.optimizer, cfg.TRAIN.LR_SCHEDULER)
        self.max_epochs = cfg.TRAIN.MAX_EPOCHS

        # metric
        #self.criterion = MultiBoxLoss(cfg.MATCHER, self.priors, self.use_gpu)
        self.criterion = FocalLoss(cfg.MATCHER, self.priors, self.use_gpu, cfg.LOSS)

        # Set the logger
        self.writer = SummaryWriter(log_dir=cfg.LOG_DIR)
        self.output_dir = cfg.EXP_DIR
        self.checkpoint = cfg.RESUME_CHECKPOINT
        self.pretrained = cfg.PRETRAINED
        self.checkpoint_prefix = cfg.CHECKPOINTS_PREFIX
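
The four near-identical branches above propagate the dataset's class count into cfg.MODEL, cfg.MATCHER, and cfg.POST_PROCESS. A hedged sketch of the same idea written as a loop; the cfg and loader objects here are SimpleNamespace stand-ins, not the repo's config class:

from types import SimpleNamespace

def sync_num_classes(cfg, loaders):
    """Copy num_classes from the first loader whose dataset defines it."""
    for loader in loaders:
        if loader is not None and hasattr(loader.dataset, 'num_classes'):
            n = loader.dataset.num_classes
            cfg.POST_PROCESS.NUM_CLASSES = n
            cfg.MATCHER.NUM_CLASSES = n
            cfg.MODEL.NUM_CLASSES = n
            return n
    return None

# Toy usage with stand-in objects:
cfg = SimpleNamespace(POST_PROCESS=SimpleNamespace(NUM_CLASSES=None),
                      MATCHER=SimpleNamespace(NUM_CLASSES=None),
                      MODEL=SimpleNamespace(NUM_CLASSES=None))
loader = SimpleNamespace(dataset=SimpleNamespace(num_classes=21))
print(sync_num_classes(cfg, [None, loader]))  # -> 21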
Example #4
    def __init__(self):
        self.cfg = cfg

        # set up logging
        if not os.path.exists(cfg.LOG_DIR):
            os.mkdir(cfg.LOG_DIR)


        # utils.setup_logging(os.path.join(cfg.LOG_DIR, "log.txt"))
        logging.basicConfig(filename=os.path.join(cfg.LOG_DIR, "log.txt"),
                            level=logging.INFO)

        # Load data
        logging.info('===> Loading data, phase %s' % cfg.PHASE)
        self.train_loader = load_data(
            cfg.DATASET, 'train') if 'train' in cfg.PHASE else None
        self.eval_loader = load_data(cfg.DATASET,
                                     'eval') if 'eval' in cfg.PHASE else None
        self.test_loader = load_data(cfg.DATASET,
                                     'test') if 'test' in cfg.PHASE else None
        self.visualize_loader = load_data(
            cfg.DATASET, 'visualize') if 'visualize' in cfg.PHASE else None

        # Build model
        logging.info('===> Building model')
        self.model, self.priorbox = create_model(cfg.MODEL)
        self.priors = self.priorbox.forward()
        self.detector = Detect(cfg.POST_PROCESS, self.priors)

        # Utilize GPUs for computation
        self.use_gpu = torch.cuda.is_available()
        if self.use_gpu:
            logging.info('Utilize GPUs for computation')
            logging.info('Number of GPU available: %d' %
                         torch.cuda.device_count())
            self.model.cuda()
            self.priors = self.priors.cuda()  # Tensor.cuda() returns a copy
            cudnn.benchmark = True
            # if torch.cuda.device_count() > 1:
            # self.model = torch.nn.DataParallel(self.model).module

        # Print the model architecture and parameters
        # logging.info('Model architectures:\n{}\n'.format(self.model))

        # logging.info('Parameters and size:')
        # for name, param in self.model.named_parameters():
        #     logging.info('{}: {}'.format(name, list(param.size())))

        # log trainable scope
        logging.info('Trainable scope: {}'.format(cfg.TRAIN.TRAINABLE_SCOPE))
        trainable_param = self.trainable_param(cfg.TRAIN.TRAINABLE_SCOPE)
        self.optimizer = self.configure_optimizer(trainable_param,
                                                  cfg.TRAIN.OPTIMIZER)
        self.exp_lr_scheduler = self.configure_lr_scheduler(
            self.optimizer, cfg.TRAIN.LR_SCHEDULER)
        self.max_epochs = cfg.TRAIN.MAX_EPOCHS

        # metric
        self.criterion = MultiBoxLoss(cfg.MATCHER, self.priors, self.use_gpu)

        # Set the logger
        self.writer = SummaryWriter(log_dir=cfg.LOG_DIR)
        self.output_dir = cfg.EXP_DIR
        self.checkpoint = cfg.RESUME_CHECKPOINT
        self.checkpoint_prefix = cfg.CHECKPOINTS_PREFIX
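
This variant routes its messages through the logging module and writes them to LOG_DIR/log.txt. A minimal sketch of that setup, hardened slightly: os.makedirs(..., exist_ok=True) also creates missing parent directories, where the bare os.mkdir above would raise. The path here is illustrative:

import logging
import os

log_dir = './experiments/log'  # stand-in for cfg.LOG_DIR
os.makedirs(log_dir, exist_ok=True)
logging.basicConfig(filename=os.path.join(log_dir, 'log.txt'),
                    level=logging.INFO,
                    format='%(asctime)s %(levelname)s %(message)s')
logging.info('===> Loading data, phase %s', ['train'])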
Example #5
    def __init__(self, ifTrain=True):
        self.cfg = cfg

        # Load data
        print('===> Loading data')
        self.ifTrain = ifTrain
        if self.ifTrain:
            self.train_loader = load_data(
                cfg.DATASET, 'train') if 'train' in cfg.PHASE else None
            #self.eval_loader = load_data(cfg.DATASET, 'eval') if 'eval' in cfg.PHASE else None
        else:
            test_image_dir = os.path.join('./data/', 'ship_test_v2')
            #  transforms = transform.Compose([transform.Lambda(lambda x: cv2.cvtColor(np.asarray(x),cv2.COLOR_RGB2BGR)),transform.Resize([300,300]), transform.ToTensor()])

            #  test_set = torchvision.datasets.ImageFolder(test_image_dir, transform = transforms)

            #  self.test_loader = torch.utils.data.DataLoader(test_set,batch_size=8,shuffle=False,num_workers=8)
            self.train_loader = load_data(
                cfg.DATASET, 'train') if 'train' in cfg.PHASE else None
            #self.test_loader = load_data(cfg.DATASET, 'test') if 'test' in cfg.PHASE else None
        self.visualize_loader = load_data(
            cfg.DATASET, 'visualize') if 'visualize' in cfg.PHASE else None

        # Build model
        print('===> Building model')
        self.model, self.priorbox = create_model(cfg.MODEL)
        with torch.no_grad():
            self.priors = self.priorbox.forward()
        self.detector = Detect(cfg.POST_PROCESS, self.priors)

        # Utilize GPUs for computation
        self.use_gpu = torch.cuda.is_available()
        #self.use_gpu = False
        if self.use_gpu:
            print('Utilize GPUs for computation')
            print('Number of GPUs available:', torch.cuda.device_count())
            self.model.cuda()
            self.priors = self.priors.cuda()  # Tensor.cuda() returns a copy
            cudnn.benchmark = True
            if torch.cuda.device_count() > 1:
                # keep the wrapper itself; taking .module would undo DataParallel
                self.model = torch.nn.DataParallel(self.model)
        #device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        #device = torch.device("cpu")
        #self.model = self.model.to(device)
        # Print the model architecture and parameters
        print('Model architectures:\n{}\n'.format(self.model))

        # print('Parameters and size:')
        # for name, param in self.model.named_parameters():
        #     print('{}: {}'.format(name, list(param.size())))

        # print trainable scope
        print('Trainable scope: {}'.format(cfg.TRAIN.TRAINABLE_SCOPE))
        trainable_param = self.trainable_param(cfg.TRAIN.TRAINABLE_SCOPE)
        # print('trainable_param ', trainable_param)
        self.optimizer = self.configure_optimizer(trainable_param,
                                                  cfg.TRAIN.OPTIMIZER)
        self.exp_lr_scheduler = self.configure_lr_scheduler(
            self.optimizer, cfg.TRAIN.LR_SCHEDULER)
        self.max_epochs = cfg.TRAIN.MAX_EPOCHS

        # metric
        # print('priors ', self.priors)
        self.criterion = MultiBoxLoss(cfg.MATCHER, self.priors, self.use_gpu)

        # Set the logger
        self.writer = SummaryWriter(log_dir=cfg.LOG_DIR)
        self.output_dir = cfg.EXP_DIR
        self.checkpoint = cfg.RESUME_CHECKPOINT
        self.checkpoint_prefix = cfg.CHECKPOINTS_PREFIX
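
One multi-GPU pitfall worth calling out: torch.nn.DataParallel(model).module immediately unwraps the parallel wrapper, so the model silently runs on a single GPU again; the wrapper itself has to be kept for forward passes. A short sketch of the safe pattern, with nn.Linear standing in for the detector:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the detector
if torch.cuda.is_available():
    model = model.cuda()
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)  # run forward passes through the wrapper

# When the raw module is needed (e.g. for saving a checkpoint), unwrap explicitly:
raw = model.module if isinstance(model, nn.DataParallel) else model
torch.save(raw.state_dict(), '/tmp/checkpoint.pth')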
Example #6
    def __init__(self):

        self.cfg = cfg
        self.use_gpu = torch.cuda.is_available()
        if self.use_gpu:
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        # Load data
        print('===> Loading data')
        if cfg.PHASE[0] in ('train', 'train_mimic'):
            self.train_loader_1 = load_data(cfg.DATASET, 'train')
            print(cfg.DATASET.DATASET, len(self.train_loader_1))
            if cfg.DATASET2.DATASET in cfg.DATASET.DATASETS:
                self.train_loader_2 = load_data(cfg.DATASET2, 'train')
                print(cfg.DATASET2.DATASET, len(self.train_loader_2))
            else:
                self.train_loader_2 = None

        self.test_loader = load_data(cfg.DATASET,
                                     'test') if 'test' in cfg.PHASE1 else None
        self.corr_loader = load_data(
            cfg.DATASET,
            'correlation') if 'correlation' in cfg.PHASE1 else None

        print('===> Building model')
        self.model, self.priorbox, feature_maps = create_model(cfg.MODEL)
        sizes = []
        boxes = []
        for maps in feature_maps:
            sizes.append(maps[0] * maps[1])
        for box in cfg.MODEL.ASPECT_RATIOS:
            boxes.append(len(box) * 2)
        with torch.no_grad():
            self.priors = self.priorbox.forward()
        self.detector = Detect_fast(cfg.POST_PROCESS, self.priors)

        self.Discriminator = Discriminator(cfg.LOG_DIR, cfg.MODEL.NETS,
                                           cfg.MODEL_MIMIC.NETS,
                                           cfg.DISCTRIMINATOR, sizes, boxes)
        self.Correlation = Correlation(cfg.CORRELATION, sizes, boxes)
        self.Trainer = Trainer()
        self.Traier_mimic = Traier_mimic(cfg.TRAIN_MIMIC, sizes, boxes,
                                         cfg.DISCTRIMINATOR.TYPE)
        self.Tester = Tester(cfg.POST_PROCESS)

        if cfg.PHASE[0] in ('train_mimic', 'correlation'):
            self.model_mimic = create_model(cfg.MODEL_MIMIC)
            self.model_mimic.load_state_dict(
                torch.load(cfg.MODEL_MIMIC.WEIGHTS))

            self.DNet = self.Discriminator.create_discriminator()

        # Utilize GPUs for computation
        if self.use_gpu:
            print('Utilize GPUs for computation')
            print('Number of GPUs available:', torch.cuda.device_count())
            cudnn.benchmark = True
            # self.model_mimic = torch.nn.DataParallel(self.model_mimic)
            # for i in range(len(self.DNet)):
            #     self.DNet[i] = torch.nn.DataParallel(self.DNet[i])
            self.model.cuda()
            self.priors = self.priors.cuda()  # Tensor.cuda() returns a copy
            if cfg.PHASE[0] in ('train_mimic', 'correlation'):
                self.model_mimic.cuda()
                for i in range(len(self.DNet)):
                    self.DNet[i] = self.DNet[i].cuda()

        # if 'train_mimic' == cfg.PHASE[0]:
        # print('Model mimim architectures:\n{}\n'.format(self.model_mimic))
        # for i in range(len(self.DNet)):
        #   print('Hello')
        #   print(self.DNet[i])
        # print('Parameters and size:')
        # for name, param in self.model.named_parameters():
        #     print('{}: {}'.format(name, list(param.size())))

        # print trainable scope
        # print('Trainable scope: {}'.format(cfg.TRAIN.TRAINABLE_SCOPE))
        if cfg.PHASE[0] in ('train_mimic', 'correlation'):
            self.optimizer = self.configure_optimizer(self.model,
                                                      cfg.TRAIN.OPTIMIZER)
            self.DNet_optim = []
            for i in range(len(self.DNet)):
                self.DNet_optim.append(
                    self.configure_optimizer(self.DNet[i],
                                             cfg.DISCTRIMINATOR.OPTIMIZER))
            self.optimizer_GENERATOR = self.configure_optimizer(
                self.model, cfg.TRAIN_MIMIC.OPTIMIZER)
            self.exp_lr_scheduler_g = self.configure_lr_scheduler(
                self.optimizer_GENERATOR, cfg.TRAIN.LR_SCHEDULER)
        else:
            self.optimizer = self.configure_optimizer(self.model,
                                                      cfg.TRAIN.OPTIMIZER)

        self.phase = cfg.PHASE
        self.exp_lr_scheduler = self.configure_lr_scheduler(
            self.optimizer, cfg.TRAIN.LR_SCHEDULER)
        self.max_epochs = cfg.TRAIN.MAX_EPOCHS

        # metric
        self.criterion = MultiBoxLoss(cfg.MATCHER, self.priors, self.use_gpu)
        self.criterion_GaN = nn.BCELoss()
        self.pos = POSdata(cfg.MATCHER, self.priors, self.use_gpu)

        # Set the logger
        self.writer = SummaryWriter(log_dir=cfg.LOG_DIR)
        if not os.path.exists(cfg.LOG_DIR):
            os.mkdir(cfg.LOG_DIR)
        shutil.copyfile('./lib/utils/config_parse.py',
                        os.path.join(cfg.LOG_DIR, 'hyperparameters.py'))
        existing = os.listdir(cfg.LOG_DIR)
        for i in range(1, 100):
            if 'Correlation_' + str(i) + '.txt' not in existing:
                self.logger = os.path.join(cfg.LOG_DIR,
                                           'Correlation_' + str(i) + '.txt')
                self.loglosses = os.path.join(
                    cfg.LOG_DIR, 'Correlation_loss_' + str(i) + '.txt')
                break
        f = open(self.logger, 'w')
        f.close()
        f = open(self.loglosses, 'w')
        f.close()
        self.output_dir = cfg.LOG_DIR
        self.checkpoint = cfg.RESUME_CHECKPOINT
        self.checkpoint_prefix = cfg.CHECKPOINTS_PREFIX
        self.model.loc.apply(self.weights_init)
        self.model.conf.apply(self.weights_init)
        self.model.extras.apply(self.weights_init)
        if cfg.PHASE[0] == 'train':
            if torch.cuda.device_count() > 1:
                self.model = torch.nn.DataParallel(self.model)
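
The log-file search near the end scans for the first unused Correlation_<i>.txt index. A hedged sketch of the same idea as a reusable helper; the helper name and the 1000-slot limit are illustrative:

import os

def next_log_paths(log_dir, stem='Correlation', limit=1000):
    """Return (log_path, loss_path) for the first unused index under log_dir."""
    existing = set(os.listdir(log_dir))
    for i in range(1, limit):
        if '%s_%d.txt' % (stem, i) not in existing:
            return (os.path.join(log_dir, '%s_%d.txt' % (stem, i)),
                    os.path.join(log_dir, '%s_loss_%d.txt' % (stem, i)))
    raise RuntimeError('no free %s log slot under %s' % (stem, log_dir))

# Usage (log_dir must already exist):
# log_path, loss_path = next_log_paths(cfg.LOG_DIR)
# open(log_path, 'w').close(); open(loss_path, 'w').close()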
Example #7
    def __init__(self, phase="train"):
        self.cfg = cfg
        creat_log(self.cfg, phase=phase)
        for k, v in cfg.items():
            print(k, ": ", v)
            logging.info('{k}: {v}'.format(k=k, v=v))

        # Load data
        print('===> Loading data')
        logging.info('===> Loading data')
        self.train_loader = load_data(
            cfg.DATASET, 'train') if 'train' in cfg.PHASE else None
        self.eval_loader = load_data(cfg.DATASET,
                                     'eval') if 'eval' in cfg.PHASE else None
        self.test_loader = load_data(cfg.DATASET,
                                     'test') if 'test' in cfg.PHASE else None
        self.visualize_loader = load_data(
            cfg.DATASET, 'visualize') if 'visualize' in cfg.PHASE else None

        # Build model
        print('===> Building model')
        logging.info('===> Building model')
        self.model, self.priorbox = create_model(cfg.MODEL)
        with torch.no_grad():
            self.priors = self.priorbox.forward()
        self.detector = Detect(cfg.POST_PROCESS, self.priors)
        os.makedirs(self.cfg['EXP_DIR'], exist_ok=True)

        # Utilize GPUs for computation
        self.use_gpu = torch.cuda.is_available()
        print('Model architectures:\n{}\n'.format(self.model))
        logging.info('Model architectures:\n{}\n'.format(self.model))

        from lib.utils.torchsummary import summary
        summary_text = summary(
            self.model.cuda() if self.use_gpu else self.model,
            (3, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1]))
        logging.info('\n'.join(summary_text))
        # num_params = 0
        # for name, param in self.model.named_parameters():
        #     num_params += param.numel()
        #     # print(name, param.size(), param.numel())
        #     print("%40s %20s  %20s" % (name, num_params, param.numel()))
        # print(num_params/1e4)
        # df = torch_summarize_df(input_size=(3, 512, 512), model=self.model)
        # # df['name'], list(df['class_name']), df['input_shape'], df["output_shape"], list(df['nb_params'])
        # print(df)
        # for name, param in self.model.named_parameters():
        #     print(name, param.size())
        # from thop import profile
        #
        # flops, params = profile(self.model, input_size=(1, 3, 512, 128))
        # count = 0
        # for p in self.model.parameters():
        #     count += p.data.nelement()
        # self.multi_gpu = True
        self.multi_gpu = False
        if self.use_gpu:
            print('Utilize GPUs for computation')
            logging.info('Utilize GPUs for computation')
            # print('Number of GPU available', torch.cuda.device_count())
            self.model.cuda()
            self.priors = self.priors.cuda()  # Tensor.cuda() returns a copy
            cudnn.benchmark = True
            # os.environ['CUDA_VISIBLE_DEVICES'] = "4,5,6,7"  # "0,1,2,3,4,5,6,7"
            if torch.cuda.device_count() > 1 and self.multi_gpu:
                self.model = torch.nn.DataParallel(self.model.cuda())
                cudnn.benchmark = True
                # self.model = torch.nn.DataParallel(self.model).module

        # Print the model architecture and parameters

        # print('Parameters and size:')
        # for name, param in self.model.named_parameters():
        #     print('{}: {}'.format(name, list(param.size())))

        # print trainable scope
        print('Trainable scope: {}'.format(cfg.TRAIN.TRAINABLE_SCOPE))
        logging.info('Trainable scope: {}'.format(cfg.TRAIN.TRAINABLE_SCOPE))
        trainable_param = self.trainable_param(cfg.TRAIN.TRAINABLE_SCOPE)
        self.optimizer = self.configure_optimizer(trainable_param,
                                                  cfg.TRAIN.OPTIMIZER)
        self.exp_lr_scheduler = self.configure_lr_scheduler(
            self.optimizer, cfg.TRAIN.LR_SCHEDULER)
        self.max_epochs = cfg.TRAIN.MAX_EPOCHS

        # metric
        self.criterion = MultiBoxLoss(cfg.MATCHER, self.priors, self.use_gpu)

        # Set the logger
        self.writer = SummaryWriter(log_dir=cfg.LOG_DIR)
        self.output_dir = cfg.EXP_DIR
        self.checkpoint = cfg.RESUME_CHECKPOINT
        self.checkpoint_prefix = cfg.CHECKPOINTS_PREFIX
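
A detail that recurs across these examples: since PyTorch 0.4, Variable(...) is a deprecated no-op and volatile=True is ignored, so the prior boxes are best built under torch.no_grad() as plain tensors. A minimal sketch, with a toy PriorBox standing in for the repo's class:

import torch

class PriorBox:  # illustrative stand-in for the repo's prior-box generator
    def forward(self):
        return torch.rand(8732, 4)  # toy anchor boxes

priorbox = PriorBox()
with torch.no_grad():
    priors = priorbox.forward()  # plain Tensor, no autograd history
print(priors.shape, priors.requires_grad)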