def create_generator(args, params, input_size):
    """Build train/val DataLoaders used to visually check anchors image by image.

    No augmentation is applied here, because it would distort the boxes you
    are trying to inspect. The only transform kept is ``Resizer``, since
    images are resized during training anyway.

    Args:
        args: options object providing ``dataset_path`` and ``num_workers``.
        params: project config providing ``train_set`` and ``val_set``.
        input_size: target size handed to ``Resizer``.

    Returns:
        ``(training_generator, val_generator)``. Each loader yields one
        un-shuffled image per batch, collated by ``collater``.
    """
    # Both loaders share identical settings; ``**`` unpacking copies the dict,
    # so reusing it for two DataLoaders is safe.
    loader_kwargs = dict(batch_size=1,
                         shuffle=False,
                         drop_last=True,
                         collate_fn=collater,
                         num_workers=args.num_workers)

    def _make_loader(split_name):
        # Fresh Compose/Resizer per dataset, exactly as two separate calls would.
        dataset = CocoDataset(root_dir=args.dataset_path, set=split_name,
                              transform=transforms.Compose([Resizer(input_size)]))
        return DataLoader(dataset, **loader_kwargs)

    train_loader = _make_loader(params.train_set)
    val_loader = _make_loader(params.val_set)
    return train_loader, val_loader
def _get_val_data_loader(args):
    """Return a DataLoader over the COCO ``"val"`` split found in ``args.data_dir``.

    Images are normalized with ``args.mean`` / ``args.std`` and resized to
    ``_INPUT_SIZES[args.compound_coef]``. The batch size is
    ``args.batch_size * num_gpus``. ``num_gpus`` is not a parameter; it must
    be defined in an enclosing/module scope that is not visible here.

    When ``args.num_workers`` is negative, the worker count falls back to that
    same global batch size.
    """
    logger.info("Getting val data loader")
    val_transform = transforms.Compose([
        Normalizer(mean=args.mean, std=args.std),
        Resizer(_INPUT_SIZES[args.compound_coef]),
    ])
    dataset = CocoDataset(root_dir=args.data_dir, set="val", transform=val_transform)

    global_batch = args.batch_size * num_gpus
    if args.num_workers >= 0:
        worker_count = args.num_workers
    else:
        worker_count = global_batch

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=global_batch,
        shuffle=False,
        drop_last=True,
        collate_fn=collater,
        num_workers=worker_count,
    )
# Example 3
def train(opt):
    """Train an EfficientDet model on a COCO-style project dataset, with resume.

    Reads the project config from ``projects/{opt.project}.yml`` and builds
    the train/val DataLoaders. It then optionally resumes from a checkpoint
    and runs the training loop with:

    * periodic checkpoints,
    * validation every ``opt.val_interval`` epochs,
    * ``ReduceLROnPlateau`` scheduling on the mean epoch loss,
    * early stopping.

    ``opt.saved_path`` and ``opt.log_path`` are rewritten in place to
    project-specific subdirectories. TensorBoard logs are written under
    ``opt.log_path``.

    Args:
        opt: parsed options. Attributes read here: ``project``,
            ``saved_path``, ``log_path``, ``batch_size``, ``num_workers``,
            ``data_path``, ``compound_coef``, ``load_weights``,
            ``head_only``, ``debug``, ``optim``, ``lr``, ``num_epochs``,
            ``save_interval``, ``val_interval``, ``es_min_delta``,
            ``es_patience``.
    """
    params = Params(f'projects/{opt.project}.yml')

    if params.num_gpus == 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    else:
        torch.manual_seed(42)

    # Output locations are rewritten in place to project-specific folders.
    opt.saved_path = opt.saved_path + f'/{params.project_name}/'
    opt.log_path = opt.log_path + f'/{params.project_name}/tensorboard/'
    os.makedirs(opt.log_path, exist_ok=True)
    os.makedirs(opt.saved_path, exist_ok=True)

    # NOTE(review): training batches are NOT shuffled here. Confirm this is
    # intentional (e.g. for debugging) rather than a leftover.
    training_params = {
        'batch_size': opt.batch_size,
        'shuffle': False,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    val_params = {
        'batch_size': opt.batch_size,
        'shuffle': False,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    # Input resolution indexed by compound coefficient: D0..D7, then D7x.
    # The last entry used to read 1356. That is not a multiple of 128 like
    # every other entry (EfficientDet's P7 level has stride 2**7); it was a
    # typo for 1536, the D7/D7x resolution.
    input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536]
    data_root = os.path.join(opt.data_path, params.project_name)
    training_set = CocoDataset(root_dir=data_root,
                               set=params.train_set,
                               transform=transforms.Compose([
                                   Normalizer(mean=params.mean,
                                              std=params.std),
                                   Augmenter(),
                                   Resizer(input_sizes[opt.compound_coef])
                               ]))
    training_generator = DataLoader(training_set, **training_params)

    val_set = CocoDataset(root_dir=data_root,
                          set=params.val_set,
                          transform=transforms.Compose([
                              Normalizer(mean=params.mean, std=params.std),
                              Resizer(input_sizes[opt.compound_coef])
                          ]))
    val_generator = DataLoader(val_set, **val_params)

    # NOTE(review): eval() executes whatever the project .yml contains; only
    # use trusted config files (ast.literal_eval would be a safer parser).
    model = EfficientDetBackbone(num_classes=len(params.obj_list),
                                 compound_coef=opt.compound_coef,
                                 ratios=eval(params.anchors_ratios),
                                 scales=eval(params.anchors_scales))

    # load last weights
    if opt.load_weights is not None:
        if opt.load_weights.endswith('.pth'):
            weights_path = opt.load_weights
        else:
            weights_path = get_last_weights(opt.saved_path)
        # Checkpoints are saved as efficientdet-d{coef}_{epoch}_{step}.pth,
        # so the trailing number is the global step to resume from.
        try:
            last_step = int(
                os.path.basename(weights_path).split('_')[-1].split('.')[0])
        except (TypeError, ValueError):
            last_step = 0

        try:
            # strict=False tolerates heads whose shape differs (e.g. a
            # different number of classes).
            model.load_state_dict(torch.load(weights_path), strict=False)
        except RuntimeError as e:
            print(f'[Warning] Ignoring {e}')
            print(
                '[Warning] Don\'t panic if you see this, this might be because you load a pretrained weights with different number of classes. The rest of the weights should be loaded already.'
            )

        print(
            f'[Info] loaded weights: {os.path.basename(weights_path)}, resuming checkpoint from step: {last_step}'
        )
    else:
        last_step = 0
        print('[Info] initializing weights...')
        init_weights(model)

    # freeze backbone if train head_only
    if opt.head_only:

        def freeze_backbone(m):
            classname = m.__class__.__name__
            for ntl in ['EfficientNet', 'BiFPN']:
                if ntl in classname:
                    for param in m.parameters():
                        param.requires_grad = False

        model.apply(freeze_backbone)
        print('[Info] freezed backbone')

    # Use synchronized BatchNorm when training on several GPUs with fewer than
    # 4 images per GPU. Per-GPU BN statistics would be too noisy; sync BN
    # normalizes over the whole multi-GPU batch at a small speed cost.
    # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
    if params.num_gpus > 1 and opt.batch_size // params.num_gpus < 4:
        model.apply(replace_w_sync_bn)
        use_sync_bn = True
    else:
        use_sync_bn = False

    writer = SummaryWriter(
        opt.log_path +
        f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

    # Wrap the model with the loss so the loss is computed on each GPU,
    # reducing memory use on gpu0.
    model = ModelWithLoss(model, debug=opt.debug)

    if params.num_gpus > 0:
        model = model.cuda()
        if params.num_gpus > 1:
            model = CustomDataParallel(model, params.num_gpus)
            if use_sync_bn:
                patch_replication_callback(model)

    if opt.optim == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), opt.lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    opt.lr,
                                    momentum=0.9,
                                    nesterov=True)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    epoch = 0
    best_loss = 1e5
    best_epoch = 0
    step = max(0, last_step)
    model.train()

    num_iter_per_epoch = len(training_generator)
    # Resume position, derived ONCE from the checkpoint step. It used to be
    # recomputed from the live ``step`` each epoch. Batches skipped for a
    # zero/non-finite loss or an error do not advance ``step``, so that
    # recomputation made the following epoch skip almost all of its batches.
    if num_iter_per_epoch:
        resume_epoch, resume_iter = divmod(step, num_iter_per_epoch)
    else:
        resume_epoch, resume_iter = 0, 0

    try:
        for epoch in range(opt.num_epochs):
            if epoch < resume_epoch:
                continue
            iters_to_skip = resume_iter if epoch == resume_epoch else 0

            epoch_loss = []
            progress_bar = tqdm(training_generator)
            for iter_idx, data in enumerate(progress_bar):
                if iter_idx < iters_to_skip:
                    progress_bar.update()
                    continue
                try:
                    imgs = data['img']
                    annot = data['annot']

                    if params.num_gpus == 1:
                        # One GPU: move the batch here. Several GPUs:
                        # CustomDataParallel scatters it instead.
                        imgs = imgs.cuda()
                        annot = annot.cuda()

                    optimizer.zero_grad()
                    cls_loss, reg_loss = model(imgs,
                                               annot,
                                               obj_list=params.obj_list)
                    cls_loss = cls_loss.mean()
                    reg_loss = reg_loss.mean()

                    loss = cls_loss + reg_loss
                    if loss == 0 or not torch.isfinite(loss):
                        continue

                    loss.backward()
                    optimizer.step()

                    epoch_loss.append(float(loss))

                    progress_bar.set_description(
                        'Step: {}. Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. Total loss: {:.5f}'
                        .format(step, epoch, opt.num_epochs, iter_idx + 1,
                                num_iter_per_epoch, cls_loss.item(),
                                reg_loss.item(), loss.item()))
                    writer.add_scalars('Loss', {'train': loss}, step)
                    writer.add_scalars('Regression_loss', {'train': reg_loss},
                                       step)
                    writer.add_scalars('Classfication_loss',
                                       {'train': cls_loss}, step)

                    # log learning_rate
                    current_lr = optimizer.param_groups[0]['lr']
                    writer.add_scalar('learning_rate', current_lr, step)

                    step += 1

                    if step % opt.save_interval == 0 and step > 0:
                        save_checkpoint(
                            model,
                            f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth'
                        )
                        print('checkpoint...')

                except Exception as e:
                    # Best-effort: report the failing batch and keep training.
                    print('[Error]', traceback.format_exc())
                    print(e)
                    continue

            # np.mean([]) is NaN; don't feed the plateau scheduler an epoch
            # in which no batch produced a usable loss.
            if epoch_loss:
                scheduler.step(np.mean(epoch_loss))

            if epoch % opt.val_interval == 0:
                model.eval()
                loss_regression_ls = []
                loss_classification_ls = []
                with torch.no_grad():
                    for data in val_generator:
                        imgs = data['img']
                        annot = data['annot']

                        if params.num_gpus == 1:
                            imgs = imgs.cuda()
                            annot = annot.cuda()

                        cls_loss, reg_loss = model(imgs,
                                                   annot,
                                                   obj_list=params.obj_list)
                        cls_loss = cls_loss.mean()
                        reg_loss = reg_loss.mean()

                        loss = cls_loss + reg_loss
                        if loss == 0 or not torch.isfinite(loss):
                            continue

                        loss_classification_ls.append(cls_loss.item())
                        loss_regression_ls.append(reg_loss.item())

                cls_loss = np.mean(loss_classification_ls)
                reg_loss = np.mean(loss_regression_ls)
                loss = cls_loss + reg_loss

                print(
                    'Val. Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'
                    .format(epoch, opt.num_epochs, cls_loss, reg_loss, loss))
                writer.add_scalars('Loss', {'val': loss}, step)
                writer.add_scalars('Regression_loss', {'val': reg_loss}, step)
                writer.add_scalars('Classfication_loss', {'val': cls_loss},
                                   step)

                # Keep the checkpoint only when validation improves by more
                # than es_min_delta.
                if loss + opt.es_min_delta < best_loss:
                    best_loss = loss
                    best_epoch = epoch

                    save_checkpoint(
                        model,
                        f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth'
                    )

                model.train()

                # Early stopping (disabled when es_patience <= 0).
                if epoch - best_epoch > opt.es_patience > 0:
                    print(
                        '[Info] Stop training at epoch {}. The lowest loss achieved is {}'
                        .format(epoch, best_loss))
                    break
    except KeyboardInterrupt:
        save_checkpoint(
            model, f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth')
    finally:
        # Close exactly once, on normal exit, early stop, Ctrl-C or error.
        writer.close()
# Example 4
def train(opt):
    """Train an EfficientDet model with an external ``FocalLoss`` criterion.

    Reads the project config from ``projects/{opt.project}.yml`` and builds
    shuffled training and un-shuffled validation DataLoaders. It then
    optionally loads weights and trains with:

    * AdamW,
    * ``ReduceLROnPlateau`` on the mean epoch loss,
    * periodic checkpoints every ``opt.save_interval`` steps,
    * validation every ``opt.val_interval`` epochs,
    * early stopping.

    ``opt.saved_path`` and ``opt.log_path`` are rewritten in place to
    project-specific subdirectories.

    Args:
        opt: parsed options. Attributes read here: ``project``,
            ``saved_path``, ``log_path``, ``batch_size``, ``num_workers``,
            ``data_path``, ``compound_coef``, ``load_weights``,
            ``head_only``, ``lr``, ``num_epochs``, ``save_interval``,
            ``val_interval``, ``es_min_delta``, ``es_patience``.
    """
    params = Params(f'projects/{opt.project}.yml')

    if params.num_gpus == 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    else:
        torch.manual_seed(42)

    opt.saved_path = opt.saved_path + f'/{params.project_name}/'
    opt.log_path = opt.log_path + f'/{params.project_name}/tensorboard/'
    os.makedirs(opt.log_path, exist_ok=True)
    os.makedirs(opt.saved_path, exist_ok=True)

    training_params = {
        'batch_size': opt.batch_size,
        'shuffle': True,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    val_params = {
        'batch_size': opt.batch_size,
        'shuffle': False,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    # Input resolution indexed by compound coefficient (D0..D7).
    input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
    # os.path.join instead of '+': the old concatenation produced e.g.
    # 'datasetscoco' when data_path lacked a trailing separator.
    data_root = os.path.join(opt.data_path, params.project_name)
    training_set = CocoDataset(root_dir=data_root,
                               set=params.train_set,
                               transform=transforms.Compose([
                                   Normalizer(mean=params.mean,
                                              std=params.std),
                                   Augmenter(),
                                   Resizer(input_sizes[opt.compound_coef])
                               ]))
    training_generator = DataLoader(training_set, **training_params)

    val_set = CocoDataset(root_dir=data_root,
                          set=params.val_set,
                          transform=transforms.Compose([
                              Normalizer(mean=params.mean, std=params.std),
                              Resizer(input_sizes[opt.compound_coef])
                          ]))
    val_generator = DataLoader(val_set, **val_params)

    model = EfficientDetBackbone(num_anchors=9,
                                 num_classes=len(params.obj_list),
                                 compound_coef=opt.compound_coef)

    # load last weights
    if opt.load_weights is not None:
        if opt.load_weights.endswith('.pth'):
            weights_path = opt.load_weights
        else:
            weights_path = get_last_weights(opt.saved_path)
        # Checkpoints are saved as efficientdet-d{coef}_{epoch}_{step}.pth,
        # so the trailing number is the global step.
        try:
            last_step = int(
                os.path.basename(weights_path).split('_')[-1].split('.')[0])
        except (TypeError, ValueError):
            last_step = 0
        model.load_state_dict(torch.load(weights_path))
        print(
            f'loaded weights: {os.path.basename(weights_path)}, resuming checkpoint from step: {last_step}'
        )
    else:
        last_step = 0
        print('initializing weights...')
        init_weights(model)

    # freeze backbone if train head_only
    if opt.head_only:

        def freeze_backbone(m):
            classname = m.__class__.__name__
            for ntl in ['EfficientNet', 'BiFPN']:
                if ntl in classname:
                    for param in m.parameters():
                        param.requires_grad = False

        model.apply(freeze_backbone)
        print('freezed backbone')

    # Use synchronized BatchNorm when training on several GPUs with fewer than
    # 4 images per GPU; per-GPU BN statistics would be too noisy.
    # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
    if params.num_gpus > 1 and opt.batch_size // params.num_gpus < 4:
        model.apply(replace_w_sync_bn)

    writer = SummaryWriter(
        opt.log_path +
        f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

    if params.num_gpus > 0:
        model = model.cuda()
        model = CustomDataParallel(model, params.num_gpus)

    optimizer = torch.optim.AdamW(model.parameters(), opt.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    criterion = FocalLoss()

    epoch = 0
    best_loss = 1e5
    best_epoch = 0
    step = max(0, last_step)

    num_iter_per_epoch = len(training_generator)
    # The try wraps the WHOLE loop. Previously it sat inside the epoch loop,
    # so Ctrl-C saved a checkpoint and then carried on with the next epoch.
    try:
        for epoch in range(opt.num_epochs):
            model.train()
            epoch_loss = []
            progress_bar = tqdm(training_generator)
            for iter_idx, data in enumerate(progress_bar):
                try:
                    imgs = data['img']
                    annot = data['annot']

                    if params.num_gpus > 0:
                        annot = annot.cuda()

                    optimizer.zero_grad()
                    _, regression, classification, anchors = model(imgs)

                    cls_loss, reg_loss = criterion(
                        classification,
                        regression,
                        anchors,
                        annot,
                        # imgs=imgs, obj_list=params.obj_list  # uncomment this to debug
                    )

                    loss = cls_loss + reg_loss
                    if loss == 0 or not torch.isfinite(loss):
                        continue

                    loss.backward()
                    optimizer.step()

                    epoch_loss.append(float(loss))

                    progress_bar.set_description(
                        'Step: {}. Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. Total loss: {:.5f}'
                        .format(step, epoch + 1, opt.num_epochs, iter_idx + 1,
                                num_iter_per_epoch, cls_loss.item(),
                                reg_loss.item(), loss.item()))
                    writer.add_scalars('Loss', {'train': loss}, step)
                    writer.add_scalars('Regression_loss', {'train': reg_loss},
                                       step)
                    writer.add_scalars('Classfication_loss',
                                       {'train': cls_loss}, step)

                    # log learning_rate
                    current_lr = optimizer.param_groups[0]['lr']
                    writer.add_scalar('learning_rate', current_lr, step)

                    step += 1

                    # Checked per step. At epoch granularity (as before) it
                    # fired only when an epoch happened to end on a multiple
                    # of save_interval.
                    if step % opt.save_interval == 0 and step > 0:
                        save_checkpoint(
                            model,
                            f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth'
                        )

                except Exception as e:
                    # Best-effort: report the failing batch and keep training.
                    print(traceback.format_exc())
                    print(e)
                    continue

            # np.mean([]) is NaN; skip the scheduler when no batch succeeded.
            if epoch_loss:
                scheduler.step(np.mean(epoch_loss))

            if epoch % opt.val_interval == 0:
                model.eval()
                loss_regression_ls = []
                loss_classification_ls = []
                with torch.no_grad():
                    for data in val_generator:
                        imgs = data['img']
                        annot = data['annot']

                        if params.num_gpus > 0:
                            annot = annot.cuda()
                        _, regression, classification, anchors = model(imgs)
                        cls_loss, reg_loss = criterion(classification,
                                                       regression, anchors,
                                                       annot)

                        loss = cls_loss + reg_loss
                        if loss == 0 or not torch.isfinite(loss):
                            continue

                        loss_classification_ls.append(cls_loss.item())
                        loss_regression_ls.append(reg_loss.item())

                cls_loss = np.mean(loss_classification_ls)
                reg_loss = np.mean(loss_regression_ls)
                loss = cls_loss + reg_loss

                print(
                    'Val. Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'
                    .format(epoch + 1, opt.num_epochs, cls_loss, reg_loss,
                            loss))
                # Same main tag as the training loss, so train/val share a chart.
                writer.add_scalars('Loss', {'val': loss}, step)
                writer.add_scalars('Regression_loss', {'val': reg_loss}, step)
                writer.add_scalars('Classfication_loss', {'val': cls_loss},
                                   step)

                # Keep the checkpoint only when validation improves by more
                # than es_min_delta.
                if loss + opt.es_min_delta < best_loss:
                    best_loss = loss
                    best_epoch = epoch

                    save_checkpoint(
                        model,
                        f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth'
                    )

                # Early stopping (disabled when es_patience <= 0).
                if epoch - best_epoch > opt.es_patience > 0:
                    print(
                        'Stop training at epoch {}. The lowest loss achieved is {}'
                        .format(epoch, best_loss))
                    break
    except KeyboardInterrupt:
        save_checkpoint(
            model, f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth')
    finally:
        # Closed once, after training ends. Previously it was closed at the end
        # of every epoch and skipped entirely on early stop.
        writer.close()
    def start_training(self):
        if self.system_dict["params"]["num_gpus"] == 0:
            os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

        if torch.cuda.is_available():
            torch.cuda.manual_seed(42)
        else:
            torch.manual_seed(42)

        self.system_dict["params"]["saved_path"] = self.system_dict["params"][
            "saved_path"] + "/" + self.system_dict["params"][
                "project_name"] + "/"
        self.system_dict["params"]["log_path"] = self.system_dict["params"][
            "log_path"] + "/" + self.system_dict["params"][
                "project_name"] + "/tensorboard/"
        os.makedirs(self.system_dict["params"]["saved_path"], exist_ok=True)
        os.makedirs(self.system_dict["params"]["log_path"], exist_ok=True)

        training_params = {
            'batch_size': self.system_dict["params"]["batch_size"],
            'shuffle': True,
            'drop_last': True,
            'collate_fn': collater,
            'num_workers': self.system_dict["params"]["num_workers"]
        }

        val_params = {
            'batch_size': self.system_dict["params"]["batch_size"],
            'shuffle': False,
            'drop_last': True,
            'collate_fn': collater,
            'num_workers': self.system_dict["params"]["num_workers"]
        }

        input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
        training_set = CocoDataset(
            self.system_dict["dataset"]["train"]["root_dir"],
            self.system_dict["dataset"]["train"]["coco_dir"],
            self.system_dict["dataset"]["train"]["img_dir"],
            set_dir=self.system_dict["dataset"]["train"]["set_dir"],
            transform=transforms.Compose([
                Normalizer(mean=self.system_dict["params"]["mean"],
                           std=self.system_dict["params"]["std"]),
                Augmenter(),
                Resizer(
                    input_sizes[self.system_dict["params"]["compound_coef"]])
            ]))
        training_generator = DataLoader(training_set, **training_params)

        if (self.system_dict["dataset"]["val"]["status"]):
            val_set = CocoDataset(
                self.system_dict["dataset"]["val"]["root_dir"],
                self.system_dict["dataset"]["val"]["coco_dir"],
                self.system_dict["dataset"]["val"]["img_dir"],
                set_dir=self.system_dict["dataset"]["val"]["set_dir"],
                transform=transforms.Compose([
                    Normalizer(self.system_dict["params"]["mean"],
                               self.system_dict["params"]["std"]),
                    Resizer(input_sizes[self.system_dict["params"]
                                        ["compound_coef"]])
                ]))
            val_generator = DataLoader(val_set, **val_params)

        print("")
        print("")
        model = EfficientDetBackbone(
            num_classes=len(self.system_dict["params"]["obj_list"]),
            compound_coef=self.system_dict["params"]["compound_coef"],
            ratios=eval(self.system_dict["params"]["anchors_ratios"]),
            scales=eval(self.system_dict["params"]["anchors_scales"]))

        os.makedirs("pretrained_weights", exist_ok=True)

        if (self.system_dict["params"]["compound_coef"] == 0):
            if (not os.path.isfile(
                    self.system_dict["params"]["load_weights"])):
                print("Downloading weights")
                cmd = "wget https://github.com/zylo117/Yet-Another-Efficient-Pytorch/releases/download/1.0/efficientdet-d0.pth -O " + \
                            self.system_dict["params"]["load_weights"]
                os.system(cmd)
        elif (self.system_dict["params"]["compound_coef"] == 1):
            if (not os.path.isfile(
                    self.system_dict["params"]["load_weights"])):
                print("Downloading weights")
                cmd = "wget https://github.com/zylo117/Yet-Another-Efficient-Pytorch/releases/download/1.0/efficientdet-d1.pth -O " + \
                            self.system_dict["params"]["load_weights"]
                os.system(cmd)
        elif (self.system_dict["params"]["compound_coef"] == 2):
            if (not os.path.isfile(
                    self.system_dict["params"]["load_weights"])):
                print("Downloading weights")
                cmd = "wget https://github.com/zylo117/Yet-Another-Efficient-Pytorch/releases/download/1.0/efficientdet-d2.pth -O " + \
                            self.system_dict["params"]["load_weights"]
                os.system(cmd)
        elif (self.system_dict["params"]["compound_coef"] == 3):
            if (not os.path.isfile(
                    self.system_dict["params"]["load_weights"])):
                print("Downloading weights")
                cmd = "wget https://github.com/zylo117/Yet-Another-Efficient-Pytorch/releases/download/1.0/efficientdet-d3.pth -O " + \
                            self.system_dict["params"]["load_weights"]
                os.system(cmd)
        elif (self.system_dict["params"]["compound_coef"] == 4):
            if (not os.path.isfile(
                    self.system_dict["params"]["load_weights"])):
                print("Downloading weights")
                cmd = "wget https://github.com/zylo117/Yet-Another-Efficient-Pytorch/releases/download/1.0/efficientdet-d4.pth -O " + \
                            self.system_dict["params"]["load_weights"]
                os.system(cmd)
        elif (self.system_dict["params"]["compound_coef"] == 5):
            if (not os.path.isfile(
                    self.system_dict["params"]["load_weights"])):
                print("Downloading weights")
                cmd = "wget https://github.com/zylo117/Yet-Another-Efficient-Pytorch/releases/download/1.0/efficientdet-d5.pth -O " + \
                            self.system_dict["params"]["load_weights"]
                os.system(cmd)
        elif (self.system_dict["params"]["compound_coef"] == 6):
            if (not os.path.isfile(
                    self.system_dict["params"]["load_weights"])):
                print("Downloading weights")
                cmd = "wget https://github.com/zylo117/Yet-Another-Efficient-Pytorch/releases/download/1.0/efficientdet-d6.pth -O " + \
                            self.system_dict["params"]["load_weights"]
                os.system(cmd)
        elif (self.system_dict["params"]["compound_coef"] == 7):
            if (not os.path.isfile(
                    self.system_dict["params"]["load_weights"])):
                print("Downloading weights")
                cmd = "wget https://github.com/zylo117/Yet-Another-Efficient-Pytorch/releases/download/1.0/efficientdet-d7.pth -O " + \
                            self.system_dict["params"]["load_weights"]
                os.system(cmd)

        # load last weights
        if self.system_dict["params"]["load_weights"] is not None:
            if self.system_dict["params"]["load_weights"].endswith('.pth'):
                weights_path = self.system_dict["params"]["load_weights"]
            else:
                weights_path = get_last_weights(
                    self.system_dict["params"]["saved_path"])
            try:
                last_step = int(
                    os.path.basename(weights_path).split('_')[-1].split('.')
                    [0])
            except:
                last_step = 0

            try:
                ret = model.load_state_dict(torch.load(weights_path),
                                            strict=False)
            except RuntimeError as e:
                print(f'[Warning] Ignoring {e}')
                print(
                    '[Warning] Don\'t panic if you see this, this might be because you load a pretrained weights with different number of classes. The rest of the weights should be loaded already.'
                )

            print(
                f'[Info] loaded weights: {os.path.basename(weights_path)}, resuming checkpoint from step: {last_step}'
            )
        else:
            last_step = 0
            print('[Info] initializing weights...')
            init_weights(model)

        print("")
        print("")

        # freeze backbone if train head_only
        if self.system_dict["params"]["head_only"]:

            def freeze_backbone(m):
                classname = m.__class__.__name__
                for ntl in ['EfficientNet', 'BiFPN']:
                    if ntl in classname:
                        for param in m.parameters():
                            param.requires_grad = False

            model.apply(freeze_backbone)
            print('[Info] freezed backbone')

        print("")
        print("")

        if self.system_dict["params"]["num_gpus"] > 1 and self.system_dict[
                "params"]["batch_size"] // self.system_dict["params"][
                    "num_gpus"] < 4:
            model.apply(replace_w_sync_bn)
            use_sync_bn = True
        else:
            use_sync_bn = False

        writer = SummaryWriter(
            self.system_dict["params"]["log_path"] +
            f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

        model = ModelWithLoss(model, debug=self.system_dict["params"]["debug"])

        if self.system_dict["params"]["num_gpus"] > 0:
            model = model.cuda()
            if self.system_dict["params"]["num_gpus"] > 1:
                model = CustomDataParallel(
                    model, self.system_dict["params"]["num_gpus"])
                if use_sync_bn:
                    patch_replication_callback(model)

        if self.system_dict["params"]["optim"] == 'adamw':
            optimizer = torch.optim.AdamW(model.parameters(),
                                          self.system_dict["params"]["lr"])
        else:
            optimizer = torch.optim.SGD(model.parameters(),
                                        self.system_dict["params"]["lr"],
                                        momentum=0.9,
                                        nesterov=True)

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               patience=3,
                                                               verbose=True)

        epoch = 0
        best_loss = 1e5
        best_epoch = 0
        step = max(0, last_step)
        model.train()

        num_iter_per_epoch = len(training_generator)

        try:
            for epoch in range(self.system_dict["params"]["num_epochs"]):
                last_epoch = step // num_iter_per_epoch
                if epoch < last_epoch:
                    continue

                epoch_loss = []
                progress_bar = tqdm(training_generator)
                for iter, data in enumerate(progress_bar):
                    if iter < step - last_epoch * num_iter_per_epoch:
                        progress_bar.update()
                        continue
                    try:
                        imgs = data['img']
                        annot = data['annot']

                        if self.system_dict["params"]["num_gpus"] == 1:
                            # if only one gpu, just send it to cuda:0
                            # elif multiple gpus, send it to multiple gpus in CustomDataParallel, not here
                            imgs = imgs.cuda()
                            annot = annot.cuda()

                        optimizer.zero_grad()
                        cls_loss, reg_loss = model(
                            imgs,
                            annot,
                            obj_list=self.system_dict["params"]["obj_list"])
                        cls_loss = cls_loss.mean()
                        reg_loss = reg_loss.mean()

                        loss = cls_loss + reg_loss
                        if loss == 0 or not torch.isfinite(loss):
                            continue

                        loss.backward()
                        # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
                        optimizer.step()

                        epoch_loss.append(float(loss))

                        progress_bar.set_description(
                            'Step: {}. Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. Total loss: {:.5f}'
                            .format(step, epoch,
                                    self.system_dict["params"]["num_epochs"],
                                    iter + 1, num_iter_per_epoch,
                                    cls_loss.item(), reg_loss.item(),
                                    loss.item()))
                        writer.add_scalars('Loss', {'train': loss}, step)
                        writer.add_scalars('Regression_loss',
                                           {'train': reg_loss}, step)
                        writer.add_scalars('Classfication_loss',
                                           {'train': cls_loss}, step)

                        # log learning_rate
                        current_lr = optimizer.param_groups[0]['lr']
                        writer.add_scalar('learning_rate', current_lr, step)

                        step += 1

                        if step % self.system_dict["params"][
                                "save_interval"] == 0 and step > 0:
                            self.save_checkpoint(
                                model,
                                f'efficientdet-d{self.system_dict["params"]["compound_coef"]}_trained.pth'
                            )
                            #print('checkpoint...')

                    except Exception as e:
                        print('[Error]', traceback.format_exc())
                        print(e)
                        continue
                scheduler.step(np.mean(epoch_loss))

                if epoch % self.system_dict["params"][
                        "val_interval"] == 0 and self.system_dict["dataset"][
                            "val"]["status"]:
                    print("Running validation")
                    model.eval()
                    loss_regression_ls = []
                    loss_classification_ls = []
                    for iter, data in enumerate(val_generator):
                        with torch.no_grad():
                            imgs = data['img']
                            annot = data['annot']

                            if self.system_dict["params"]["num_gpus"] == 1:
                                imgs = imgs.cuda()
                                annot = annot.cuda()

                            cls_loss, reg_loss = model(
                                imgs,
                                annot,
                                obj_list=self.system_dict["params"]
                                ["obj_list"])
                            cls_loss = cls_loss.mean()
                            reg_loss = reg_loss.mean()

                            loss = cls_loss + reg_loss
                            if loss == 0 or not torch.isfinite(loss):
                                continue

                            loss_classification_ls.append(cls_loss.item())
                            loss_regression_ls.append(reg_loss.item())

                    cls_loss = np.mean(loss_classification_ls)
                    reg_loss = np.mean(loss_regression_ls)
                    loss = cls_loss + reg_loss

                    print(
                        'Val. Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'
                        .format(epoch,
                                self.system_dict["params"]["num_epochs"],
                                cls_loss, reg_loss, loss))
                    writer.add_scalars('Loss', {'val': loss}, step)
                    writer.add_scalars('Regression_loss', {'val': reg_loss},
                                       step)
                    writer.add_scalars('Classfication_loss', {'val': cls_loss},
                                       step)

                    if loss + self.system_dict["params"][
                            "es_min_delta"] < best_loss:
                        best_loss = loss
                        best_epoch = epoch

                        self.save_checkpoint(
                            model,
                            f'efficientdet-d{self.system_dict["params"]["compound_coef"]}_trained.pth'
                        )

                    model.train()

                    # Early stopping
                    if epoch - best_epoch > self.system_dict["params"][
                            "es_patience"] > 0:
                        print(
                            '[Info] Stop training at epoch {}. The lowest loss achieved is {}'
                            .format(epoch, best_loss))
                        break
        except KeyboardInterrupt:
            self.save_checkpoint(
                model,
                f'efficientdet-d{self.system_dict["params"]["compound_coef"]}_trained.pth'
            )
            writer.close()
        writer.close()

        print("")
        print("")
        print("Training complete")
Esempio n. 6
0
    'shuffle': False,
    'drop_last': True,
    'collate_fn': collater,
    'num_workers': opt.num_workers
}

# tf bilinear interpolation is different from any other's, just make do
input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536]
# Honour --force_input_size when it is given; otherwise use the input size
# associated with the chosen compound coefficient.
input_size = (input_sizes[opt.compound_coef]
              if opt.force_input_size is None else opt.force_input_size)

# Validation set: normalised, then resized to ``input_size`` (not
# ``input_sizes[opt.compound_coef]``) so that a forced input size is applied.
val_set = CocoDataset(root_dir=os.path.join(opt.data_path,
                                            params.project_name),
                      set=params.val_set,
                      transform=transforms.Compose([
                          Normalizer(mean=params.mean, std=params.std),
                          Resizer(input_size)
                      ]))
val_generator = DataLoader(val_set, **val_params)

# Build the backbone with the project's classes/anchors and load the trained
# weights for the requested checkpoint, mapped onto the CPU.
model = EfficientDetBackbone(compound_coef=opt.compound_coef,
                             num_classes=len(obj_list),
                             ratios=anchor_ratios,
                             scales=anchor_scales)
model.load_state_dict(
    torch.load(
        f'logs/{opt.project}/efficientdet-d{opt.compound_coef}_{opt.weights}.pth',
        map_location='cpu'))
# Evaluation only: disable gradient tracking and switch BN/dropout to eval mode.
model.requires_grad_(False)
model.eval()
Esempio n. 7
0
def _labelled_images_to_grid(images, **grid_kwargs):
    """Stack labelled debug images into a single TensorBoard image grid.

    Args:
        images: sequence of labelled images returned by the model in debug
            mode; the stacked array is treated as (N, H, W, C).
        **grid_kwargs: forwarded to ``torchvision.utils.make_grid``
            (e.g. ``nrow``).

    Returns:
        The grid tensor produced by ``torchvision.utils.make_grid``.
    """
    batch = torch.from_numpy(np.asarray(images))  # (N, H, W, C)
    batch.transpose_(1, 3)  # (N, C, W, H)
    batch.transpose_(2, 3)  # (N, C, H, W)
    return torchvision.utils.make_grid(batch, **grid_kwargs)


def _run_validation(model, val_generator, params, num_validation_steps,
                    viz_interval):
    """Compute validation losses over the first ``num_validation_steps`` batches.

    Every ``viz_interval``-th batch is run with ``model.debug = True`` so the
    model additionally returns labelled images for visualisation.

    Args:
        model: the loss-wrapped model (already in eval mode).
        val_generator: validation DataLoader yielding dicts with 'img',
            'annot', 'scale', 'new_w', 'new_h' and 'img_id'.
        params: project params (``num_gpus`` and ``obj_list`` are read).
        num_validation_steps: number of batches to consume (>= 1).
        viz_interval: visualise every this-many batches (>= 1).

    Returns:
        (mean classification loss, mean regression loss, list of labelled
        images collected for visualisation).
    """
    loss_classification_ls = []
    loss_regression_ls = []
    imgs_to_viz = []
    with torch.no_grad():
        for valiternum, valdata in enumerate(val_generator):
            imgs = valdata['img']
            annot = valdata['annot']
            if params.num_gpus >= 1:
                imgs = imgs.cuda()
                annot = annot.cuda()

            visualise = valiternum % viz_interval == 0
            model.debug = visualise
            cls_loss, reg_loss, val_imgs_labelled = model(
                imgs,
                annot,
                obj_list=params.obj_list,
                resizing_imgs_scales=valdata['scale'],
                new_ws=valdata['new_w'],
                new_hs=valdata['new_h'],
                imgs_ids=valdata['img_id'])
            if visualise:
                imgs_to_viz += list(val_imgs_labelled)

            # .mean() as in the training loop: with several GPUs the gathered
            # losses hold one value per replica and .item() alone would fail.
            loss_classification_ls.append(cls_loss.mean().item())
            loss_regression_ls.append(reg_loss.mean().item())

            # Consume exactly num_validation_steps batches.
            if valiternum + 1 >= num_validation_steps:
                break

    return (np.mean(loss_classification_ls), np.mean(loss_regression_ls),
            imgs_to_viz)


def train(opt):
    '''
    Input: get_args()
    Function: Train the model.

    Reads ``projects/<opt.project>.yml``, builds train/val loaders and an
    EfficientDet model, and trains it. Every ``opt.eval_percent_epoch``
    percent of an epoch a partial validation pass is run that logs losses,
    labelled images and (when ``opt.max_preds_toeval > 0``) per-class
    mAP@0.5 to TensorBoard, checkpoints on improvement, and applies early
    stopping (``opt.es_patience``/``opt.es_min_delta``).

    Side effects: rewrites ``opt.saved_path``, ``opt.log_path`` and
    ``opt.max_preds_toeval`` in place, creates output directories, writes
    checkpoints via ``save_checkpoint`` and a predictions json under
    ``<opt.data_path>/<opt.project>/predictions``.
    '''
    params = Params(f'projects/{opt.project}.yml')

    if params.num_gpus == 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    else:
        torch.manual_seed(42)

    opt.saved_path = opt.saved_path + f'/{params.project_name}/'
    opt.log_path = opt.log_path + f'/{params.project_name}/tensorboard/'
    os.makedirs(opt.log_path, exist_ok=True)
    os.makedirs(opt.saved_path, exist_ok=True)

    # evaluation json file (built from the ``opt`` argument, not a global).
    pred_folder = f'{opt.data_path}/{opt.project}/predictions'
    os.makedirs(pred_folder, exist_ok=True)
    evaluation_pred_file = f'{pred_folder}/instances_bbox_results.json'

    training_params = {
        'batch_size': opt.batch_size,
        'shuffle': True,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    val_params = {
        'batch_size': opt.batch_size,
        'shuffle': True,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    # One entry per compound coefficient; index 8 is included so that
    # compound_coef=8 does not raise IndexError.
    input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536]
    training_set = CocoDataset(root_dir=os.path.join(opt.data_path,
                                                     params.project_name),
                               set=params.train_set,
                               transform=torchvision.transforms.Compose([
                                   Normalizer(mean=params.mean,
                                              std=params.std),
                                   Augmenter(),
                                   Resizer(input_sizes[opt.compound_coef])
                               ]))
    training_generator = DataLoader(training_set, **training_params)

    val_set = CocoDataset(root_dir=os.path.join(opt.data_path,
                                                params.project_name),
                          set=params.val_set,
                          transform=torchvision.transforms.Compose([
                              Normalizer(mean=params.mean, std=params.std),
                              Resizer(input_sizes[opt.compound_coef])
                          ]))
    val_generator = DataLoader(val_set, **val_params)

    # NOTE(review): ``eval`` executes the config strings as Python code, so the
    # project yml must be trusted. ast.literal_eval is presumably avoided
    # because anchor scales may be arithmetic expressions — confirm.
    model = EfficientDetBackbone(num_classes=len(params.obj_list),
                                 compound_coef=opt.compound_coef,
                                 ratios=eval(params.anchors_ratios),
                                 scales=eval(params.anchors_scales))

    # load last weights
    if opt.load_weights is not None:
        if opt.load_weights.endswith('.pth'):
            weights_path = opt.load_weights
        else:
            weights_path = get_last_weights(opt.saved_path)
        try:
            # Checkpoint names end in ``_<step>.pth``; anything else restarts
            # the step counter at 0.
            last_step = int(
                os.path.basename(weights_path).split('_')[-1].split('.')[0])
        except Exception:
            last_step = 0

        try:
            _ = model.load_state_dict(torch.load(weights_path), strict=False)
        except RuntimeError as rerror:
            print(f'[Warning] Ignoring {rerror}')
            print('[Warning] Don\'t panic if you see this, '\
                  'this might be because you load a pretrained weights with different number of classes.'\
                  ' The rest of the weights should be loaded already.')

        print(
            f'[Info] loaded weights: {os.path.basename(weights_path)}, resuming checkpoint from step: {last_step}'
        )
    else:
        last_step = 0
        print('[Info] initializing weights...')
        init_weights(model)

    # freeze backbone if train head_only
    if opt.head_only:

        def freeze_backbone(mdl):
            classname = mdl.__class__.__name__
            for ntl in ['EfficientNet', 'BiFPN']:
                if ntl in classname:
                    for param in mdl.parameters():
                        param.requires_grad = False

        model.apply(freeze_backbone)
        print('[Info] freezed backbone')

    # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
    # Apply sync_bn when using multiple GPUs and the per-GPU batch size is
    # below 4: BN statistics over such small batches make training unstable
    # or slow to converge, so sync_bn normalises over the mini-batches of all
    # GPUs together, at a small cost in speed.
    if params.num_gpus > 1 and opt.batch_size // params.num_gpus < 4:
        model.apply(replace_w_sync_bn)
        use_sync_bn = True
    else:
        use_sync_bn = False

    writer = SummaryWriter(
        opt.log_path +
        f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

    # warp the model with loss function, to reduce the memory usage on gpu0 and speedup
    model = ModelWithLoss(model, debug=opt.debug)

    if params.num_gpus > 0:
        model = model.cuda()
        if params.num_gpus > 1:
            model = CustomDataParallel(model, params.num_gpus)
            if use_sync_bn:
                patch_replication_callback(model)

    if opt.optim == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), opt.lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    opt.lr,
                                    momentum=0.9,
                                    nesterov=True)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    epoch = 0
    best_loss = 1e5
    best_epoch = 0
    step = max(0, last_step)
    model.train()

    num_iter_per_epoch = len(training_generator)
    num_val_iter_per_epoch = len(val_generator)
    # Limit the no.of preds to #images in val.
    # Here, I averaged the #obj to 5 for computational efficacy
    if opt.max_preds_toeval > 0:
        opt.max_preds_toeval = len(val_generator) * opt.batch_size * 5

    # Evaluate/visualise every ``eval_interval`` training iterations. Clamped
    # to 1: a raw value of 0 made ``iternum % 0`` raise ZeroDivisionError on
    # every iteration, which the per-iteration except swallowed, so nothing
    # was ever trained.
    eval_interval = max(
        1, int(num_iter_per_epoch * (opt.eval_percent_epoch / 100)))
    # Validation batches per evaluation run (at least one).
    num_validation_steps = max(
        1, int(num_val_iter_per_epoch * (opt.eval_sampling_percent / 100)))
    # Visualise every ``viz_interval``-th validation batch, aiming for about
    # ``opt.num_visualize_images`` images; clamped to avoid modulo by zero.
    viz_interval = max(
        1, num_validation_steps //
        max(1, opt.num_visualize_images // opt.batch_size))

    # Set by early stopping; ends the epoch loop, not just the batch loop.
    stop_training = False
    try:
        for epoch in range(opt.num_epochs):
            last_epoch = step // num_iter_per_epoch
            if epoch < last_epoch:
                continue

            epoch_loss = []
            progress_bar = tqdm(training_generator)
            for iternum, data in enumerate(progress_bar):
                # When resuming, skip batches already seen in this epoch.
                if iternum < step - last_epoch * num_iter_per_epoch:
                    progress_bar.update()
                    continue
                try:
                    imgs = data['img']
                    annot = data['annot']
                    if params.num_gpus == 1:
                        # if only one gpu, just send it to cuda:0
                        # elif multiple gpus, send it to multiple gpus in CustomDataParallel, not here
                        imgs = imgs.cuda()
                        annot = annot.cuda()

                    optimizer.zero_grad()
                    # Only request labelled debug images on evaluation
                    # iterations, where they are logged below.
                    is_eval_iter = iternum % eval_interval == 0
                    model.debug = is_eval_iter
                    cls_loss, reg_loss, imgs_labelled = model(
                        imgs, annot, obj_list=params.obj_list)

                    cls_loss = cls_loss.mean()
                    reg_loss = reg_loss.mean()

                    loss = cls_loss + reg_loss
                    if loss == 0 or not torch.isfinite(loss):
                        continue

                    loss.backward()
                    optimizer.step()

                    epoch_loss.append(float(loss))

                    progress_bar.set_description(
                        'Step: {}. Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. Total loss: {:.5f}'
                        .format(step, epoch, opt.num_epochs, iternum + 1,
                                num_iter_per_epoch, cls_loss.item(),
                                reg_loss.item(), loss.item()))
                    writer.add_scalars('Loss', {'train': loss}, step)
                    writer.add_scalars('Regression_loss', {'train': reg_loss},
                                       step)
                    writer.add_scalars('Classfication_loss',
                                       {'train': cls_loss}, step)

                    if is_eval_iter and step > 0:
                        writer.add_image(
                            'Training_images',
                            _labelled_images_to_grid(imgs_labelled),
                            global_step=step)

                        # ---------------- EVAL ----------------
                        model.eval()
                        model.debug = False  # Don't print images in tensorboard now.

                        # remove json from the previous evaluation
                        if os.path.exists(evaluation_pred_file):
                            os.remove(evaluation_pred_file)

                        model.evalresults = []  # Empty the results for next evaluation.
                        try:
                            cls_loss, reg_loss, imgs_to_viz = _run_validation(
                                model, val_generator, params,
                                num_validation_steps, viz_interval)
                        finally:
                            # Always go back to training mode, otherwise an
                            # evaluation error would leave the model training
                            # in eval mode for the rest of the run.
                            model.train()
                        loss = cls_loss + reg_loss

                        print(
                            'Val. Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'
                            .format(epoch, opt.num_epochs, cls_loss, reg_loss,
                                    loss))
                        writer.add_scalars('Loss', {'val': loss}, step)
                        writer.add_scalars('Regression_loss',
                                           {'val': reg_loss}, step)
                        writer.add_scalars('Classfication_loss',
                                           {'val': cls_loss}, step)
                        # An empty image list cannot be transposed into a grid.
                        if imgs_to_viz:
                            writer.add_image(
                                'Eval_Images',
                                _labelled_images_to_grid(imgs_to_viz, nrow=2),
                                global_step=step)

                        if opt.max_preds_toeval > 0:
                            with open(evaluation_pred_file, 'w') as pred_fp:
                                json.dump(model.evalresults, pred_fp, indent=4)
                            try:
                                val_results = calc_mAP_fin(
                                    params.project_name,
                                    params.val_set,
                                    evaluation_pred_file,
                                    val_gt=f'{opt.data_path}/{opt.project}/annotations/instances_{params.val_set}.json')

                                metricname = 'Average Precision  (AP) @[ IoU = 0.50      | area =    all | maxDets = 100 ]'
                                for catgname in val_results:
                                    evalscore = val_results[catgname][
                                        metricname]
                                    writer.add_scalars(
                                        'mAP@IoU=0.5 and area=all',
                                        {f'{catgname}': evalscore}, step)
                            except Exception as exption:
                                print("Unable to perform evaluation", exption)

                        if loss + opt.es_min_delta < best_loss:
                            best_loss = loss
                            best_epoch = epoch

                            save_checkpoint(
                                model,
                                f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth'
                            )

                        # Early stopping: flag it so the epoch loop ends too.
                        if epoch - best_epoch > opt.es_patience > 0:
                            print(
                                '[Info] Stop training at epoch {}. The lowest loss achieved is {}'
                                .format(epoch, best_loss))
                            stop_training = True
                            break
                        # -------------- END EVAL --------------

                    # log learning_rate
                    current_lr = optimizer.param_groups[0]['lr']
                    writer.add_scalar('learning_rate', current_lr, step)

                    step += 1

                    if step % opt.save_interval == 0 and step > 0:
                        save_checkpoint(
                            model,
                            f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth'
                        )
                        print('checkpoint...')

                except Exception as exception:
                    print('[Error]', traceback.format_exc())
                    print(exception)
                    continue

            # Skip the plateau update when no batch produced a loss, rather
            # than feeding the scheduler NaN from np.mean([]).
            if epoch_loss:
                scheduler.step(np.mean(epoch_loss))
            if stop_training:
                break
    except KeyboardInterrupt:
        save_checkpoint(
            model, f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth')
    finally:
        writer.close()
def train(opt):
    """Train an EfficientDet model configured by the YAML file at ``opt.config``.

    Builds train/val ``CocoDataset`` loaders from the config, constructs an
    ``EfficientDetBackbone`` (optionally resuming from ``opt.load_weights``),
    wraps it in ``ModelWithLoss`` and runs the training loop with periodic
    validation, step-interval and best-loss checkpointing, LR reduction on
    plateau and optional early stopping.

    Args:
        opt: parsed options. Attributes read here: ``config``, ``batch_size``,
            ``num_workers``, ``compound_coef``, ``load_weights``, ``head_only``,
            ``debug``, ``optim``, ``lr``, ``num_epochs``, ``save_interval``,
            ``val_interval``, ``es_min_delta`` and ``es_patience``.
            ``opt.saved_path`` and ``opt.log_path`` are overwritten from the
            config's ``logdir``.

    Raises:
        ValueError: if the training loader yields no batches (e.g. fewer
            images than ``batch_size`` while ``drop_last=True``).
    """
    params = Params(opt.config)

    if params.num_gpus == 0:
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    else:
        torch.manual_seed(42)

    opt.saved_path = params.logdir
    opt.log_path = os.path.join(params.logdir, "tensorboard")
    os.makedirs(opt.saved_path, exist_ok=True)
    os.makedirs(opt.log_path, exist_ok=True)

    training_params = {
        "batch_size": opt.batch_size,
        "shuffle": True,
        "drop_last": True,
        "collate_fn": collater,
        "num_workers": opt.num_workers,
    }

    val_params = {
        "batch_size": opt.batch_size,
        "shuffle": False,
        "drop_last": True,
        "collate_fn": collater,
        "num_workers": opt.num_workers,
    }

    # Input resolution indexed by the compound coefficient (d0 .. d8).
    input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536]
    input_size = input_sizes[opt.compound_coef]

    training_set = CocoDataset(
        image_dir=params.image_dir,
        json_path=params.train_annotations,
        transform=transforms.Compose(
            [
                Normalizer(mean=params.mean, std=params.std),
                Augmenter(),
                Resizer(input_size),
            ]
        ),
    )
    training_generator = DataLoader(training_set, **training_params)

    num_iter_per_epoch = len(training_generator)
    if num_iter_per_epoch == 0:
        # Fail early with a clear message instead of a ZeroDivisionError in
        # the resume computation below.
        raise ValueError(
            "Training loader yields no batches; check the dataset size against batch_size "
            "(drop_last=True discards incomplete batches)."
        )

    if params.val_image_dir is None:
        params.val_image_dir = params.image_dir

    val_set = CocoDataset(
        image_dir=params.val_image_dir,
        json_path=params.val_annotations,
        transform=transforms.Compose(
            [Normalizer(mean=params.mean, std=params.std), Resizer(input_size)]
        ),
    )
    val_generator = DataLoader(val_set, **val_params)

    # NOTE(review): eval() executes arbitrary expressions from the config file;
    # only use trusted configs. It is presumably kept because the config may
    # hold arithmetic expressions that ast.literal_eval would reject.
    model = EfficientDetBackbone(
        num_classes=len(params.obj_list),
        compound_coef=opt.compound_coef,
        ratios=eval(params.anchors_ratios),
        scales=eval(params.anchors_scales),
    )

    # load last weights
    if opt.load_weights is not None:
        if opt.load_weights.endswith(".pth"):
            weights_path = opt.load_weights
        else:
            weights_path = get_last_weights(opt.saved_path)

        # Checkpoints written below are named "..._{epoch}_{step}.pth"; recover
        # the step so training resumes where it stopped. Other names start at 0.
        try:
            last_step = int(os.path.basename(weights_path).split("_")[-1].split(".")[0])
        except ValueError:
            last_step = 0

        try:
            # map_location="cpu": the model is still on the CPU here (it is moved
            # to the GPU further down), and this lets GPU-saved checkpoints load
            # on CPU-only machines.
            model.load_state_dict(torch.load(weights_path, map_location="cpu"), strict=False)
        except RuntimeError as e:
            print(f"[Warning] Ignoring {e}")
            print(
                "[Warning] Don't panic if you see this, this might be because you load a pretrained weights with different number of classes. The rest of the weights should be loaded already."
            )

        print(
            f"[Info] loaded weights: {os.path.basename(weights_path)}, resuming checkpoint from step: {last_step}"
        )
    else:
        last_step = 0
        print("[Info] initializing weights...")
        init_weights(model)

    # freeze backbone if train head_only
    if opt.head_only:

        def freeze_backbone(m):
            classname = m.__class__.__name__
            for ntl in ["EfficientNet", "BiFPN"]:
                if ntl in classname:
                    for param in m.parameters():
                        param.requires_grad = False

        model.apply(freeze_backbone)
        print("[Info] freezed backbone")

    # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
    # Apply sync_bn when using multiple GPUs with fewer than 4 images per GPU:
    # per-GPU batch norm statistics over so few samples make training unstable
    # or slow to converge. Sync BN normalizes over the combined mini-batch of
    # all GPUs, at the cost of slightly slower training.
    if params.num_gpus > 1 and opt.batch_size // params.num_gpus < 4:
        model.apply(replace_w_sync_bn)
        use_sync_bn = True
    else:
        use_sync_bn = False

    writer = SummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

    # Wrap the model with the loss function to reduce memory usage on gpu0 and
    # speed up multi-GPU training.
    model = ModelWithLoss(model, debug=opt.debug)

    if params.num_gpus > 0:
        model = model.cuda()
        if params.num_gpus > 1:
            model = CustomDataParallel(model, params.num_gpus)
            if use_sync_bn:
                patch_replication_callback(model)

    if opt.optim == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), opt.lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(), opt.lr, momentum=0.9, nesterov=True)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)

    epoch = 0
    best_loss = 1e5
    best_epoch = 0
    step = max(0, last_step)
    model.train()

    # Resume position, computed once from the checkpoint step. ``step`` does not
    # advance on skipped batches, so deriving the offset from the running step
    # every epoch would drift and skip batches in later epochs.
    resume_epoch = step // num_iter_per_epoch
    resume_iter = step - resume_epoch * num_iter_per_epoch

    try:
        for epoch in range(opt.num_epochs):
            if epoch < resume_epoch:
                continue
            skip_iters = resume_iter if epoch == resume_epoch else 0

            epoch_loss = []
            progress_bar = tqdm(training_generator)
            for it, data in enumerate(progress_bar):
                if it < skip_iters:
                    # Iterating the tqdm wrapper already advances the bar.
                    continue
                try:
                    imgs = data["img"]
                    annot = data["annot"]

                    if params.num_gpus == 1:
                        # With one GPU, move the batch to cuda:0 here; with several,
                        # CustomDataParallel scatters it instead.
                        imgs = imgs.cuda()
                        annot = annot.cuda()

                    optimizer.zero_grad()
                    cls_loss, reg_loss = model(imgs, annot, obj_list=params.obj_list)
                    cls_loss = cls_loss.mean()
                    reg_loss = reg_loss.mean()

                    loss = cls_loss + reg_loss
                    if loss == 0 or not torch.isfinite(loss):
                        continue

                    loss.backward()
                    optimizer.step()

                    epoch_loss.append(float(loss))

                    progress_bar.set_description(
                        "Step: {}. Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. Total loss: {:.5f}".format(
                            step,
                            epoch,
                            opt.num_epochs,
                            it + 1,
                            num_iter_per_epoch,
                            cls_loss.item(),
                            reg_loss.item(),
                            loss.item(),
                        )
                    )
                    writer.add_scalars("Loss", {"train": loss}, step)
                    writer.add_scalars("Regression_loss", {"train": reg_loss}, step)
                    writer.add_scalars("Classfication_loss", {"train": cls_loss}, step)

                    # log learning_rate
                    current_lr = optimizer.param_groups[0]["lr"]
                    writer.add_scalar("learning_rate", current_lr, step)

                    step += 1

                    if step % opt.save_interval == 0 and step > 0:
                        save_checkpoint(
                            model, f"efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth"
                        )
                        print("checkpoint...")

                except Exception as e:
                    # Best-effort: report the failing batch and keep training.
                    print("[Error]", traceback.format_exc())
                    print(e)
                    continue

            # An epoch with no recorded loss would feed NaN to the plateau
            # scheduler and count as a "bad" epoch, decaying the LR spuriously.
            if epoch_loss:
                scheduler.step(np.mean(epoch_loss))

            if epoch % opt.val_interval == 0:
                model.eval()
                loss_regression_ls = []
                loss_classification_ls = []
                with torch.no_grad():
                    for data in val_generator:
                        imgs = data["img"]
                        annot = data["annot"]

                        if params.num_gpus == 1:
                            imgs = imgs.cuda()
                            annot = annot.cuda()

                        cls_loss, reg_loss = model(imgs, annot, obj_list=params.obj_list)
                        cls_loss = cls_loss.mean()
                        reg_loss = reg_loss.mean()

                        loss = cls_loss + reg_loss
                        if loss == 0 or not torch.isfinite(loss):
                            continue

                        loss_classification_ls.append(cls_loss.item())
                        loss_regression_ls.append(reg_loss.item())

                if loss_classification_ls:
                    cls_loss = np.mean(loss_classification_ls)
                    reg_loss = np.mean(loss_regression_ls)
                    loss = cls_loss + reg_loss

                    print(
                        "Val. Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}".format(
                            epoch, opt.num_epochs, cls_loss, reg_loss, loss
                        )
                    )
                    writer.add_scalars("Loss", {"val": loss}, step)
                    writer.add_scalars("Regression_loss", {"val": reg_loss}, step)
                    writer.add_scalars("Classfication_loss", {"val": cls_loss}, step)

                    if loss + opt.es_min_delta < best_loss:
                        best_loss = loss
                        best_epoch = epoch

                        save_checkpoint(model, f"efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth")
                else:
                    print("[Warning] No valid validation batches; skipping validation metrics.")

                model.train()

                # Early stopping
                if epoch - best_epoch > opt.es_patience > 0:
                    print(
                        "[Info] Stop training at epoch {}. The lowest loss achieved is {}".format(
                            epoch, best_loss
                        )
                    )
                    break
    except KeyboardInterrupt:
        # On Ctrl-C, persist the current weights before exiting.
        save_checkpoint(model, f"efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth")
    finally:
        # Close exactly once, whether training finished, was interrupted or failed.
        writer.close()
# Standalone script section: iterate over one dataset split with only a
# Resizer transform (no Normalizer), presumably to accumulate per-channel
# pixel statistics (e.g. mean/std for the project config) — TODO confirm
# against the loop that follows.

# Input resolution indexed by the EfficientDet compound coefficient.
input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536]
compound_coef = 3
project_name = "fujiseal_stain"
data_path = "./datasets/"
# Selects the split to scan: True -> params.train_set, False -> params.val_set.
# NOTE(review): this module-level assignment rebinds the name ``train``,
# shadowing the ``train(opt)`` function defined earlier in this file.
train = True

params = Params(f"projects/{project_name}.yml")

if train:
    dataset_name = params.train_set
else:
    dataset_name = params.val_set

# Images are only resized (not normalized), so raw pixel values are observed.
image_dataset = CocoDataset(
    root_dir=os.path.join(data_path, params.project_name),
    set=dataset_name,
    transform=transforms.Compose([Resizer(input_sizes[compound_coef])]))

# One image per batch, no shuffling and no dropped batches, so every image in
# the split is visited exactly once.
image_loader = DataLoader(image_dataset,
                          batch_size=1,
                          shuffle=False,
                          drop_last=False,
                          num_workers=0)

# Running 3-channel accumulators. psum is updated by summing the image batch
# over axes [0, 1, 2], which presumably assumes channel-last images
# (batch, H, W, C) — TODO confirm. psum_sq presumably holds sums of squares.
psum = torch.tensor([0.0, 0.0, 0.0])
psum_sq = torch.tensor([0.0, 0.0, 0.0])

# Wall-clock start, presumably for timing the pass below.
start_time = time.time()
for inputs in image_loader:
    imgs = inputs['img']
    psum += imgs.sum(axis=[0, 1, 2])