def _get_train_data_loader(args):
    """Return a shuffled DataLoader over the COCO-format ``train`` split.

    Each sample is normalized, augmented and resized to the input size
    selected by ``args.compound_coef``. The loader's batch size is
    ``args.batch_size * num_gpus``, where ``num_gpus`` is a module-level
    name not defined in this function (presumably the GPU count -- confirm).

    Args:
        args: options object providing ``data_dir``, ``mean``, ``std``,
            ``compound_coef``, ``batch_size`` and ``num_workers``. When
            ``num_workers`` is negative, the combined batch size is used as
            the worker count instead.

    Returns:
        A ``torch.utils.data.DataLoader`` that shuffles, drops the last
        incomplete batch and collates with ``collater``.
    """
    logger.info("Getting train data loader")
    preprocessing = transforms.Compose([
        Normalizer(mean=args.mean, std=args.std),
        Augmenter(),
        Resizer(_INPUT_SIZES[args.compound_coef]),
    ])
    train_set = CocoDataset(
        root_dir=args.data_dir,
        set="train",
        transform=preprocessing,
    )

    combined_batch = args.batch_size * num_gpus
    if args.num_workers >= 0:
        worker_count = args.num_workers
    else:
        # A negative setting falls back to one worker per combined-batch slot.
        worker_count = combined_batch

    return torch.utils.data.DataLoader(
        train_set,
        batch_size=combined_batch,
        shuffle=True,
        drop_last=True,
        collate_fn=collater,
        num_workers=worker_count,
    )
# Example #2
# 0
def train(opt):
    """Train an EfficientDet model for the project named by ``opt.project``.

    Loads ``projects/{opt.project}.yml``, builds COCO-format train/val
    loaders, optionally resumes from a ``.pth`` checkpoint, then alternates
    training epochs with periodic validation. Along the way it logs to
    TensorBoard, saves checkpoints, steps a ReduceLROnPlateau scheduler on
    the mean epoch training loss and early-stops on the validation loss.

    Args:
        opt: Parsed options. Fields read here include ``project``,
            ``saved_path``, ``log_path``, ``data_path``, ``batch_size``,
            ``num_workers``, ``compound_coef``, ``load_weights``,
            ``head_only``, ``debug``, ``optim``, ``lr``, ``num_epochs``,
            ``save_interval``, ``val_interval``, ``es_min_delta`` and
            ``es_patience``.

    Side effects:
        Appends the project name to ``opt.saved_path`` / ``opt.log_path``
        and creates those directories. Sets ``CUDA_VISIBLE_DEVICES=-1`` when
        the project config requests 0 GPUs. Writes TensorBoard events and
        checkpoint files. A KeyboardInterrupt saves a checkpoint and ends
        training instead of propagating.
    """
    params = Params(f'projects/{opt.project}.yml')

    if params.num_gpus == 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    else:
        torch.manual_seed(42)

    opt.saved_path = opt.saved_path + f'/{params.project_name}/'
    opt.log_path = opt.log_path + f'/{params.project_name}/tensorboard/'
    os.makedirs(opt.log_path, exist_ok=True)
    os.makedirs(opt.saved_path, exist_ok=True)

    # NOTE(review): the training loader does not shuffle. That keeps batch
    # order reproducible, which the mid-epoch resume below relies on, but
    # confirm it is intentional -- training loaders usually shuffle.
    training_params = {
        'batch_size': opt.batch_size,
        'shuffle': False,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    val_params = {
        'batch_size': opt.batch_size,
        'shuffle': False,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    # NOTE(review): every size here is a multiple of 128 except the last
    # (1356, used only for compound_coef=8); confirm it is not a typo for 1536.
    input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1356]
    training_set = CocoDataset(root_dir=os.path.join(opt.data_path,
                                                     params.project_name),
                               set=params.train_set,
                               transform=transforms.Compose([
                                   Normalizer(mean=params.mean,
                                              std=params.std),
                                   Augmenter(),
                                   Resizer(input_sizes[opt.compound_coef])
                               ]))
    training_generator = DataLoader(training_set, **training_params)

    val_set = CocoDataset(root_dir=os.path.join(opt.data_path,
                                                params.project_name),
                          set=params.val_set,
                          transform=transforms.Compose([
                              Normalizer(mean=params.mean, std=params.std),
                              Resizer(input_sizes[opt.compound_coef])
                          ]))
    val_generator = DataLoader(val_set, **val_params)

    # SECURITY NOTE: eval() executes arbitrary expressions taken from the
    # project YAML; only use trusted project files.
    model = EfficientDetBackbone(num_classes=len(params.obj_list),
                                 compound_coef=opt.compound_coef,
                                 ratios=eval(params.anchors_ratios),
                                 scales=eval(params.anchors_scales))

    # load last weights
    if opt.load_weights is not None:
        if opt.load_weights.endswith('.pth'):
            weights_path = opt.load_weights
        else:
            weights_path = get_last_weights(opt.saved_path)
        # Checkpoints saved below are named '..._{epoch}_{step}.pth'; recover
        # the step to resume from. Any other file name restarts at step 0.
        try:
            last_step = int(
                os.path.basename(weights_path).split('_')[-1].split('.')[0])
        except (TypeError, ValueError):
            last_step = 0

        try:
            model.load_state_dict(torch.load(weights_path), strict=False)
        except RuntimeError as e:
            print(f'[Warning] Ignoring {e}')
            print(
                '[Warning] Don\'t panic if you see this, this might be because you load a pretrained weights with different number of classes. The rest of the weights should be loaded already.'
            )

        print(
            f'[Info] loaded weights: {os.path.basename(weights_path)}, resuming checkpoint from step: {last_step}'
        )
    else:
        last_step = 0
        print('[Info] initializing weights...')
        init_weights(model)

    # freeze backbone if train head_only
    if opt.head_only:

        def freeze_backbone(m):
            classname = m.__class__.__name__
            for ntl in ['EfficientNet', 'BiFPN']:
                if ntl in classname:
                    for param in m.parameters():
                        param.requires_grad = False

        model.apply(freeze_backbone)
        print('[Info] freezed backbone')

    # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
    # Apply sync BN when using multiple GPUs with fewer than 4 samples per
    # GPU (useful when GPU memory is limited). Tiny per-GPU batches make BN
    # statistics noisy and training unstable or slow to converge. Sync BN
    # normalizes over the mini-batch gathered from all GPUs, at a small
    # speed cost.
    if params.num_gpus > 1 and opt.batch_size // params.num_gpus < 4:
        model.apply(replace_w_sync_bn)
        use_sync_bn = True
    else:
        use_sync_bn = False

    writer = SummaryWriter(
        opt.log_path +
        f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

    # wrap the model with the loss function, to reduce memory usage on gpu0 and speed up
    model = ModelWithLoss(model, debug=opt.debug)

    if params.num_gpus > 0:
        model = model.cuda()
        if params.num_gpus > 1:
            model = CustomDataParallel(model, params.num_gpus)
            if use_sync_bn:
                patch_replication_callback(model)

    if opt.optim == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), opt.lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    opt.lr,
                                    momentum=0.9,
                                    nesterov=True)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    epoch = 0
    best_loss = 1e5
    best_epoch = 0
    step = max(0, last_step)
    model.train()

    num_iter_per_epoch = len(training_generator)

    try:
        # Resume position is computed ONCE from the restored step. Iterations
        # that are skipped (zero/non-finite loss, caught exceptions) do not
        # advance `step`. Recomputing this from the live `step` every epoch
        # would therefore skip most batches of later epochs.
        resume_epoch, resume_iter = divmod(step, num_iter_per_epoch)

        for epoch in range(opt.num_epochs):
            if epoch < resume_epoch:
                continue

            epoch_loss = []
            progress_bar = tqdm(training_generator)
            for iter_idx, data in enumerate(progress_bar):
                # Iterating the tqdm bar already advances it; no extra update().
                if epoch == resume_epoch and iter_idx < resume_iter:
                    continue
                try:
                    imgs = data['img']
                    annot = data['annot']

                    ################just GT check#########################
                    # img_sample = imgs[0,:,:,:]
                    # annot_sample = annot[0,:,:]
                    # img_out = img_sample.numpy()
                    # img_out = np.transpose(img_out, (1,2,0))
                    # img_out = cv2.cvtColor(img_out, cv2.COLOR_RGB2BGR)
                    # annot_out = annot_sample.numpy()
                    # count, _ = annot_out.shape
                    # for i in range(count):
                    #     if annot_out[i,4]  >= 0:
                    #         cv2.rectangle(img_out, (int(annot_out[i,0]),int(annot_out[i,1])), (int(annot_out[i,2]),int(annot_out[i,3])), (255,0,0), 1)
                    #
                    # cv2.imwrite("test.png", img_out*255)
                    #
                    #######################################################

                    if params.num_gpus == 1:
                        # if only one gpu, just send it to cuda:0
                        # elif multiple gpus, send it to multiple gpus in CustomDataParallel, not here
                        imgs = imgs.cuda()
                        annot = annot.cuda()

                    optimizer.zero_grad()
                    cls_loss, reg_loss = model(imgs,
                                               annot,
                                               obj_list=params.obj_list)
                    cls_loss = cls_loss.mean()
                    reg_loss = reg_loss.mean()

                    loss = cls_loss + reg_loss
                    if loss == 0 or not torch.isfinite(loss):
                        continue

                    loss.backward()
                    # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
                    optimizer.step()

                    epoch_loss.append(float(loss))

                    progress_bar.set_description(
                        'Step: {}. Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. Total loss: {:.5f}'
                        .format(step, epoch, opt.num_epochs, iter_idx + 1,
                                num_iter_per_epoch, cls_loss.item(),
                                reg_loss.item(), loss.item()))
                    writer.add_scalars('Loss', {'train': loss}, step)
                    writer.add_scalars('Regression_loss', {'train': reg_loss},
                                       step)
                    writer.add_scalars('Classfication_loss',
                                       {'train': cls_loss}, step)

                    # log learning_rate
                    current_lr = optimizer.param_groups[0]['lr']
                    writer.add_scalar('learning_rate', current_lr, step)

                    step += 1

                    if step % opt.save_interval == 0 and step > 0:
                        save_checkpoint(
                            model,
                            f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth'
                        )
                        print('checkpoint...')

                except Exception as e:
                    # Best effort: report the failing batch and keep training.
                    print('[Error]', traceback.format_exc())
                    print(e)
                    continue

            # np.mean([]) is NaN, and ReduceLROnPlateau would count it as a
            # "bad" epoch. Only step when at least one batch produced a loss.
            if epoch_loss:
                scheduler.step(np.mean(epoch_loss))

            if epoch % opt.val_interval == 0:
                model.eval()
                loss_regression_ls = []
                loss_classification_ls = []
                for data in val_generator:
                    with torch.no_grad():
                        imgs = data['img']
                        annot = data['annot']

                        if params.num_gpus == 1:
                            imgs = imgs.cuda()
                            annot = annot.cuda()

                        cls_loss, reg_loss = model(imgs,
                                                   annot,
                                                   obj_list=params.obj_list)
                        cls_loss = cls_loss.mean()
                        reg_loss = reg_loss.mean()

                        loss = cls_loss + reg_loss
                        if loss == 0 or not torch.isfinite(loss):
                            continue

                        loss_classification_ls.append(cls_loss.item())
                        loss_regression_ls.append(reg_loss.item())

                cls_loss = np.mean(loss_classification_ls)
                reg_loss = np.mean(loss_regression_ls)
                loss = cls_loss + reg_loss

                print(
                    'Val. Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'
                    .format(epoch, opt.num_epochs, cls_loss, reg_loss, loss))
                writer.add_scalars('Loss', {'val': loss}, step)
                writer.add_scalars('Regression_loss', {'val': reg_loss}, step)
                writer.add_scalars('Classfication_loss', {'val': cls_loss},
                                   step)

                if loss + opt.es_min_delta < best_loss:
                    best_loss = loss
                    best_epoch = epoch

                    save_checkpoint(
                        model,
                        f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth'
                    )

                model.train()

                # Early stopping
                if epoch - best_epoch > opt.es_patience > 0:
                    print(
                        '[Info] Stop training at epoch {}. The lowest loss achieved is {}'
                        .format(epoch, best_loss))
                    break
    except KeyboardInterrupt:
        save_checkpoint(
            model, f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth')
    finally:
        # Close exactly once on every exit path (normal end, early stop,
        # Ctrl+C, or an unexpected exception escaping the loop).
        writer.close()
# Example #3
# 0
def train(opt):
    """Train an EfficientDet model (explicit FocalLoss variant).

    Loads ``projects/{opt.project}.yml``, builds COCO-format train/val
    loaders, optionally resumes from a ``.pth`` checkpoint, and runs the
    train/validate loop. It logs to TensorBoard, saves checkpoints every
    ``opt.save_interval`` steps and on each new best validation loss, steps
    a ReduceLROnPlateau scheduler, and early-stops on the validation loss.

    Args:
        opt: Parsed options. Fields read here include ``project``,
            ``saved_path``, ``log_path``, ``data_path``, ``batch_size``,
            ``num_workers``, ``compound_coef``, ``load_weights``,
            ``head_only``, ``lr``, ``num_epochs``, ``save_interval``,
            ``val_interval``, ``es_min_delta`` and ``es_patience``.

    Side effects:
        Appends the project name to ``opt.saved_path`` / ``opt.log_path``
        and creates those directories. Sets ``CUDA_VISIBLE_DEVICES=-1`` when
        the project config requests 0 GPUs. Writes TensorBoard events and
        checkpoint files. A KeyboardInterrupt saves a checkpoint and stops
        training instead of propagating.
    """
    params = Params(f'projects/{opt.project}.yml')

    if params.num_gpus == 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    else:
        torch.manual_seed(42)

    opt.saved_path = opt.saved_path + f'/{params.project_name}/'
    opt.log_path = opt.log_path + f'/{params.project_name}/tensorboard/'
    os.makedirs(opt.log_path, exist_ok=True)
    os.makedirs(opt.saved_path, exist_ok=True)

    training_params = {
        'batch_size': opt.batch_size,
        'shuffle': True,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    val_params = {
        'batch_size': opt.batch_size,
        'shuffle': False,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
    # os.path.join rather than '+': plain concatenation drops the separator
    # when opt.data_path has no trailing slash.
    data_root = os.path.join(opt.data_path, params.project_name)
    training_set = CocoDataset(root_dir=data_root,
                               set=params.train_set,
                               transform=transforms.Compose([
                                   Normalizer(mean=params.mean,
                                              std=params.std),
                                   Augmenter(),
                                   Resizer(input_sizes[opt.compound_coef])
                               ]))
    training_generator = DataLoader(training_set, **training_params)

    val_set = CocoDataset(root_dir=data_root,
                          set=params.val_set,
                          transform=transforms.Compose([
                              Normalizer(mean=params.mean, std=params.std),
                              Resizer(input_sizes[opt.compound_coef])
                          ]))
    val_generator = DataLoader(val_set, **val_params)

    model = EfficientDetBackbone(num_anchors=9,
                                 num_classes=len(params.obj_list),
                                 compound_coef=opt.compound_coef)

    # load last weights
    if opt.load_weights is not None:
        if opt.load_weights.endswith('.pth'):
            weights_path = opt.load_weights
        else:
            weights_path = get_last_weights(opt.saved_path)
        # Checkpoints saved below are named '..._{epoch}_{step}.pth'; recover
        # the step to resume from. Any other file name restarts at step 0.
        try:
            last_step = int(
                os.path.basename(weights_path).split('_')[-1].split('.')[0])
        except (TypeError, ValueError):
            last_step = 0
        model.load_state_dict(torch.load(weights_path))
        print(
            f'loaded weights: {os.path.basename(weights_path)}, resuming checkpoint from step: {last_step}'
        )
    else:
        last_step = 0
        print('initializing weights...')
        init_weights(model)

    # freeze backbone if train head_only
    if opt.head_only:

        def freeze_backbone(m):
            classname = m.__class__.__name__
            for ntl in ['EfficientNet', 'BiFPN']:
                if ntl in classname:
                    for param in m.parameters():
                        param.requires_grad = False

        model.apply(freeze_backbone)
        print('freezed backbone')

    # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
    # Apply sync BN when using multiple GPUs with fewer than 4 samples per
    # GPU (useful when GPU memory is limited). Tiny per-GPU batches make BN
    # statistics noisy and training unstable or slow to converge. Sync BN
    # normalizes over the mini-batch gathered from all GPUs, at a small
    # speed cost.
    if params.num_gpus > 1 and opt.batch_size // params.num_gpus < 4:
        model.apply(replace_w_sync_bn)

    writer = SummaryWriter(
        opt.log_path +
        f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

    if params.num_gpus > 0:
        model = model.cuda()
        model = CustomDataParallel(model, params.num_gpus)

    optimizer = torch.optim.AdamW(model.parameters(), opt.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    criterion = FocalLoss()

    epoch = 0
    best_loss = 1e5
    best_epoch = 0
    step = max(0, last_step)
    model.train()

    num_iter_per_epoch = len(training_generator)
    # A single try around the whole loop has two effects. Ctrl+C saves a
    # checkpoint and stops training; a per-epoch try would move on to the
    # next epoch. The finally closes the writer exactly once on every exit.
    try:
        for epoch in range(opt.num_epochs):
            model.train()
            epoch_loss = []
            progress_bar = tqdm(training_generator)
            for iter_idx, data in enumerate(progress_bar):
                try:
                    imgs = data['img']
                    annot = data['annot']

                    if params.num_gpus > 0:
                        annot = annot.cuda()

                    optimizer.zero_grad()
                    _, regression, classification, anchors = model(imgs)

                    cls_loss, reg_loss = criterion(
                        classification,
                        regression,
                        anchors,
                        annot,
                        # imgs=imgs, obj_list=params.obj_list  # uncomment this to debug
                    )

                    loss = cls_loss + reg_loss
                    if loss == 0 or not torch.isfinite(loss):
                        continue

                    loss.backward()
                    # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
                    optimizer.step()

                    epoch_loss.append(float(loss))

                    progress_bar.set_description(
                        'Step: {}. Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. Total loss: {:.5f}'
                        .format(step, epoch + 1, opt.num_epochs, iter_idx + 1,
                                num_iter_per_epoch, cls_loss.item(),
                                reg_loss.item(), loss.item()))
                    writer.add_scalars('Loss', {'train': loss}, step)
                    writer.add_scalars('Regression_loss', {'train': reg_loss},
                                       step)
                    writer.add_scalars('Classfication_loss',
                                       {'train': cls_loss}, step)

                    # log learning_rate
                    current_lr = optimizer.param_groups[0]['lr']
                    writer.add_scalar('learning_rate', current_lr, step)

                    step += 1

                    # Checked per step. Testing only at epoch end would save
                    # only when an epoch happens to end on a multiple of
                    # save_interval.
                    if step % opt.save_interval == 0 and step > 0:
                        save_checkpoint(
                            model,
                            f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth'
                        )

                except Exception as e:
                    # Best effort: report the failing batch and keep training.
                    print(traceback.format_exc())
                    print(e)
                    continue

            # np.mean([]) is NaN, and ReduceLROnPlateau would count it as a
            # "bad" epoch. Only step when at least one batch produced a loss.
            if epoch_loss:
                scheduler.step(np.mean(epoch_loss))

            if epoch % opt.val_interval == 0:
                model.eval()
                loss_regression_ls = []
                loss_classification_ls = []
                for data in val_generator:
                    with torch.no_grad():
                        imgs = data['img']
                        annot = data['annot']

                        if params.num_gpus > 0:
                            annot = annot.cuda()
                        _, regression, classification, anchors = model(imgs)
                        cls_loss, reg_loss = criterion(classification,
                                                       regression, anchors,
                                                       annot)

                        loss = cls_loss + reg_loss
                        if loss == 0 or not torch.isfinite(loss):
                            continue

                        loss_classification_ls.append(cls_loss.item())
                        loss_regression_ls.append(reg_loss.item())

                cls_loss = np.mean(loss_classification_ls)
                reg_loss = np.mean(loss_regression_ls)
                loss = cls_loss + reg_loss

                print(
                    'Val. Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'
                    .format(epoch + 1, opt.num_epochs, cls_loss, reg_loss,
                            loss))
                # Same main tag as the training loss so train/val curves overlay.
                writer.add_scalars('Loss', {'val': loss}, step)
                writer.add_scalars('Regression_loss', {'val': reg_loss}, step)
                writer.add_scalars('Classfication_loss', {'val': cls_loss},
                                   step)

                if loss + opt.es_min_delta < best_loss:
                    best_loss = loss
                    best_epoch = epoch

                    save_checkpoint(
                        model,
                        f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth'
                    )

                    # onnx export is not tested.
                    # dummy_input = torch.rand(opt.batch_size, 3, 512, 512)
                    # if torch.cuda.is_available():
                    #     dummy_input = dummy_input.cuda()
                    # if isinstance(model, nn.DataParallel):
                    #     model.module.backbone_net.model.set_swish(memory_efficient=False)
                    #
                    #     torch.onnx.export(model.module, dummy_input,
                    #                       os.path.join(opt.saved_path, 'signatrix_efficientdet_coco.onnx'),
                    #                       verbose=False)
                    #     model.module.backbone_net.model.set_swish(memory_efficient=True)
                    # else:
                    #     model.backbone_net.model.set_swish(memory_efficient=False)
                    #
                    #     torch.onnx.export(model, dummy_input,
                    #                       os.path.join(opt.saved_path, 'signatrix_efficientdet_coco.onnx'),
                    #                       verbose=False)
                    #     model.backbone_net.model.set_swish(memory_efficient=True)

                # Early stopping: report the best loss seen, not the current one.
                if epoch - best_epoch > opt.es_patience > 0:
                    print(
                        'Stop training at epoch {}. The lowest loss achieved is {}'
                        .format(epoch, best_loss))
                    break
    except KeyboardInterrupt:
        save_checkpoint(
            model, f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth')
    finally:
        writer.close()
# Example #4
# 0
def train(opt):
    '''
    Input: get_args()
    Function: Train the model.
    '''
    params = Params(f'projects/{opt.project}.yml')

    if params.num_gpus == 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    else:
        torch.manual_seed(42)

    opt.saved_path = opt.saved_path + f'/{params.project_name}/'
    opt.log_path = opt.log_path + f'/{params.project_name}/tensorboard/'
    os.makedirs(opt.log_path, exist_ok=True)
    os.makedirs(opt.saved_path, exist_ok=True)

    # evaluation json file
    pred_folder = f'{OPT.data_path}/{OPT.project}/predictions'
    os.makedirs(pred_folder, exist_ok=True)
    evaluation_pred_file = f'{pred_folder}/instances_bbox_results.json'

    training_params = {
        'batch_size': opt.batch_size,
        'shuffle': True,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    val_params = {
        'batch_size': opt.batch_size,
        'shuffle': True,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
    training_set = CocoDataset(root_dir=os.path.join(opt.data_path,
                                                     params.project_name),
                               set=params.train_set,
                               transform=torchvision.transforms.Compose([
                                   Normalizer(mean=params.mean,
                                              std=params.std),
                                   Augmenter(),
                                   Resizer(input_sizes[opt.compound_coef])
                               ]))
    training_generator = DataLoader(training_set, **training_params)

    val_set = CocoDataset(root_dir=os.path.join(opt.data_path,
                                                params.project_name),
                          set=params.val_set,
                          transform=torchvision.transforms.Compose([
                              Normalizer(mean=params.mean, std=params.std),
                              Resizer(input_sizes[opt.compound_coef])
                          ]))
    val_generator = DataLoader(val_set, **val_params)

    model = EfficientDetBackbone(num_classes=len(params.obj_list),
                                 compound_coef=opt.compound_coef,
                                 ratios=eval(params.anchors_ratios),
                                 scales=eval(params.anchors_scales))

    # load last weights
    if opt.load_weights is not None:
        if opt.load_weights.endswith('.pth'):
            weights_path = opt.load_weights
        else:
            weights_path = get_last_weights(opt.saved_path)
        try:
            last_step = int(
                os.path.basename(weights_path).split('_')[-1].split('.')[0])
        except Exception as exception:
            last_step = 0

        try:
            _ = model.load_state_dict(torch.load(weights_path), strict=False)
        except RuntimeError as rerror:
            print(f'[Warning] Ignoring {rerror}')
            print('[Warning] Don\'t panic if you see this, '\
                  'this might be because you load a pretrained weights with different number of classes.'\
                  ' The rest of the weights should be loaded already.')

        print(
            f'[Info] loaded weights: {os.path.basename(weights_path)}, resuming checkpoint from step: {last_step}'
        )
    else:
        last_step = 0
        print('[Info] initializing weights...')
        init_weights(model)

    # freeze backbone if train head_only
    if opt.head_only:

        def freeze_backbone(mdl):
            classname = mdl.__class__.__name__
            for ntl in ['EfficientNet', 'BiFPN']:
                if ntl in classname:
                    for param in mdl.parameters():
                        param.requires_grad = False

        model.apply(freeze_backbone)
        print('[Info] freezed backbone')

    # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
    # apply sync_bn when using multiple gpu and batch_size per gpu is lower than 4
    #  useful when gpu memory is limited.
    # because when bn is disable, the training will be very unstable or slow to converge,
    # apply sync_bn can solve it,
    # by packing all mini-batch across all gpus as one batch and normalize, then send it back to all gpus.
    # but it would also slow down the training by a little bit.
    if params.num_gpus > 1 and opt.batch_size // params.num_gpus < 4:
        model.apply(replace_w_sync_bn)
        use_sync_bn = True
    else:
        use_sync_bn = False

    writer = SummaryWriter(
        opt.log_path +
        f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

    # warp the model with loss function, to reduce the memory usage on gpu0 and speedup
    model = ModelWithLoss(model, debug=opt.debug)

    if params.num_gpus > 0:
        model = model.cuda()
        if params.num_gpus > 1:
            model = CustomDataParallel(model, params.num_gpus)
            if use_sync_bn:
                patch_replication_callback(model)

    if opt.optim == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), opt.lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    opt.lr,
                                    momentum=0.9,
                                    nesterov=True)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    epoch = 0
    best_loss = 1e5
    best_epoch = 0
    step = max(0, last_step)
    model.train()

    num_iter_per_epoch = len(training_generator)
    num_val_iter_per_epoch = len(val_generator)
    # Limit the no.of preds to #images in val.
    # Here, I averaged the #obj to 5 for computational efficacy
    if opt.max_preds_toeval > 0:
        opt.max_preds_toeval = len(val_generator) * opt.batch_size * 5

    try:
        for epoch in range(opt.num_epochs):
            last_epoch = step // num_iter_per_epoch
            if epoch < last_epoch:
                continue

            epoch_loss = []
            progress_bar = tqdm(training_generator)
            for iternum, data in enumerate(progress_bar):
                if iternum < step - last_epoch * num_iter_per_epoch:
                    progress_bar.update()
                    continue
                try:
                    imgs = data['img']
                    annot = data['annot']
                    if params.num_gpus == 1:
                        # if only one gpu, just send it to cuda:0
                        # elif multiple gpus, send it to multiple gpus in CustomDataParallel, not here
                        imgs = imgs.cuda()
                        annot = annot.cuda()

                    optimizer.zero_grad()
                    if iternum % int(num_iter_per_epoch *
                                     (opt.eval_percent_epoch / 100)) != 0:
                        model.debug = False
                        cls_loss, reg_loss, _ = model(imgs,
                                                      annot,
                                                      obj_list=params.obj_list)
                    else:
                        model.debug = True
                        cls_loss, reg_loss, imgs_labelled = model(
                            imgs, annot, obj_list=params.obj_list)

                    cls_loss = cls_loss.mean()
                    reg_loss = reg_loss.mean()

                    loss = cls_loss + reg_loss
                    if loss == 0 or not torch.isfinite(loss):
                        continue

                    loss.backward()
                    # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
                    optimizer.step()

                    epoch_loss.append(float(loss))

                    progress_bar.set_description(
                        'Step: {}. Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. Total loss: {:.5f}'
                        .format(step, epoch, opt.num_epochs, iternum + 1,
                                num_iter_per_epoch, cls_loss.item(),
                                reg_loss.item(), loss.item()))
                    writer.add_scalars('Loss', {'train': loss}, step)
                    writer.add_scalars('Regression_loss', {'train': reg_loss},
                                       step)
                    writer.add_scalars('Classfication_loss',
                                       {'train': cls_loss}, step)

                    if iternum % int(
                            num_iter_per_epoch *
                        (opt.eval_percent_epoch / 100)) == 0 and step > 0:
                        # create grid of images
                        imgs_labelled = np.asarray(imgs_labelled)
                        imgs_labelled = torch.from_numpy(
                            imgs_labelled)  # (N, H, W, C)
                        imgs_labelled.transpose_(1, 3)  # (N, C, H, W)
                        imgs_labelled.transpose_(2, 3)
                        img_grid = torchvision.utils.make_grid(imgs_labelled)
                        # write to tensorboard
                        writer.add_image('Training_images',
                                         img_grid,
                                         global_step=step)
                        #########################################################start EVAL#####################################################
                        model.eval()
                        model.debug = False  # Don't print images in tensorboard now.

                        # remove json
                        if os.path.exists(evaluation_pred_file):
                            os.remove(evaluation_pred_file)

                        loss_regression_ls = []
                        loss_classification_ls = []
                        model.evalresults = [
                        ]  # Empty the results for next evaluation.
                        imgs_to_viz = []
                        num_validation_steps = int(
                            num_val_iter_per_epoch *
                            (opt.eval_sampling_percent / 100))
                        for valiternum, valdata in enumerate(val_generator):
                            with torch.no_grad():
                                imgs = valdata['img']
                                annot = valdata['annot']
                                resizing_imgs_scales = valdata['scale']
                                new_ws = valdata['new_w']
                                new_hs = valdata['new_h']
                                imgs_ids = valdata['img_id']

                                if params.num_gpus >= 1:
                                    imgs = imgs.cuda()
                                    annot = annot.cuda()

                                if valiternum % (num_validation_steps //
                                                 (opt.num_visualize_images //
                                                  opt.batch_size)) != 0:
                                    model.debug = False
                                    cls_loss, reg_loss, _ = model(
                                        imgs,
                                        annot,
                                        obj_list=params.obj_list,
                                        resizing_imgs_scales=
                                        resizing_imgs_scales,
                                        new_ws=new_ws,
                                        new_hs=new_hs,
                                        imgs_ids=imgs_ids)
                                else:
                                    model.debug = True
                                    cls_loss, reg_loss, val_imgs_labelled = model(
                                        imgs,
                                        annot,
                                        obj_list=params.obj_list,
                                        resizing_imgs_scales=
                                        resizing_imgs_scales,
                                        new_ws=new_ws,
                                        new_hs=new_hs,
                                        imgs_ids=imgs_ids)

                                    imgs_to_viz += list(val_imgs_labelled)

                                loss_classification_ls.append(cls_loss.item())
                                loss_regression_ls.append(reg_loss.item())

                            if valiternum > (num_validation_steps):
                                break

                        cls_loss = np.mean(loss_classification_ls)
                        reg_loss = np.mean(loss_regression_ls)
                        loss = cls_loss + reg_loss

                        print(
                            'Val. Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'
                            .format(epoch, opt.num_epochs, cls_loss, reg_loss,
                                    loss))
                        writer.add_scalars('Loss', {'val': loss}, step)
                        writer.add_scalars('Regression_loss',
                                           {'val': reg_loss}, step)
                        writer.add_scalars('Classfication_loss',
                                           {'val': cls_loss}, step)
                        # create grid of images
                        val_imgs_labelled = np.asarray(imgs_to_viz)
                        val_imgs_labelled = torch.from_numpy(
                            val_imgs_labelled)  # (N, H, W, C)
                        val_imgs_labelled.transpose_(1, 3)  # (N, C, H, W)
                        val_imgs_labelled.transpose_(2, 3)
                        val_img_grid = torchvision.utils.make_grid(
                            val_imgs_labelled, nrow=2)
                        # write to tensorboard
                        writer.add_image('Eval_Images', val_img_grid, \
                                         global_step=(step))

                        if opt.max_preds_toeval > 0:
                            json.dump(model.evalresults,
                                      open(evaluation_pred_file, 'w'),
                                      indent=4)
                            try:
                                val_results = calc_mAP_fin(params.project_name,\
                                                        params.val_set, evaluation_pred_file, \
                                                        val_gt=f'{OPT.data_path}/{OPT.project}/annotations/instances_{params.val_set}.json')

                                for catgname in val_results:
                                    metricname = 'Average Precision  (AP) @[ IoU = 0.50      | area =    all | maxDets = 100 ]'
                                    evalscore = val_results[catgname][
                                        metricname]
                                    writer.add_scalars(
                                        f'mAP@IoU=0.5 and area=all',
                                        {f'{catgname}': evalscore}, step)
                            except Exception as exption:
                                print("Unable to perform evaluation", exption)

                        if loss + opt.es_min_delta < best_loss:
                            best_loss = loss
                            best_epoch = epoch

                            save_checkpoint(
                                model,
                                f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth'
                            )

                        model.train()

                        # Early stopping
                        if epoch - best_epoch > opt.es_patience > 0:
                            print(
                                '[Info] Stop training at epoch {}. The lowest loss achieved is {}'
                                .format(epoch, best_loss))
                            break


#########################################################EVAL#####################################################

# log learning_rate
                    current_lr = optimizer.param_groups[0]['lr']
                    writer.add_scalar('learning_rate', current_lr, step)

                    step += 1

                    if step % opt.save_interval == 0 and step > 0:
                        save_checkpoint(
                            model,
                            f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth'
                        )
                        print('checkpoint...')

                except Exception as exception:
                    print('[Error]', traceback.format_exc())
                    print(exception)
                    continue
            scheduler.step(np.mean(epoch_loss))
    except KeyboardInterrupt:
        save_checkpoint(
            model, f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth')
        writer.close()
    writer.close()
    def start_training(self):
        """Train EfficientDet using the settings stored in ``self.system_dict``.

        Reads hyper-parameters from ``self.system_dict["params"]`` and dataset
        locations from ``self.system_dict["dataset"]``, then:

        * appends the project name to ``params["saved_path"]`` and
          ``params["log_path"]`` (mutating them in place) and creates both dirs;
        * builds the training ``DataLoader`` and, when
          ``dataset["val"]["status"]`` is truthy, the validation one;
        * builds the model, downloads pretrained weights with ``wget`` when
          ``params["load_weights"]`` names a missing file, and loads them;
        * trains for ``params["num_epochs"]`` epochs, logging to TensorBoard and
          saving ``efficientdet-d<coef>_trained.pth`` through
          ``self.save_checkpoint`` every ``save_interval`` steps, on a new best
          validation loss, and on Ctrl-C;
        * stops early when the best validation epoch is more than
          ``es_patience`` epochs old (only when ``es_patience > 0``).
        """
        import subprocess  # argument-list process spawn for the weights download

        params = self.system_dict["params"]
        dataset_cfg = self.system_dict["dataset"]

        if params["num_gpus"] == 0:
            os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

        if torch.cuda.is_available():
            torch.cuda.manual_seed(42)
        else:
            torch.manual_seed(42)

        params["saved_path"] = (params["saved_path"] + "/" +
                                params["project_name"] + "/")
        params["log_path"] = (params["log_path"] + "/" +
                              params["project_name"] + "/tensorboard/")
        os.makedirs(params["saved_path"], exist_ok=True)
        os.makedirs(params["log_path"], exist_ok=True)

        training_params = {
            'batch_size': params["batch_size"],
            'shuffle': True,
            'drop_last': True,
            'collate_fn': collater,
            'num_workers': params["num_workers"]
        }

        val_params = {
            'batch_size': params["batch_size"],
            'shuffle': False,
            'drop_last': True,
            'collate_fn': collater,
            'num_workers': params["num_workers"]
        }

        # Input resolution per compound coefficient (d0 .. d7).
        input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
        compound_coef = params["compound_coef"]
        input_size = input_sizes[compound_coef]

        train_cfg = dataset_cfg["train"]
        training_set = CocoDataset(
            train_cfg["root_dir"],
            train_cfg["coco_dir"],
            train_cfg["img_dir"],
            set_dir=train_cfg["set_dir"],
            transform=transforms.Compose([
                Normalizer(mean=params["mean"], std=params["std"]),
                Augmenter(),
                Resizer(input_size),
            ]))
        training_generator = DataLoader(training_set, **training_params)

        val_enabled = dataset_cfg["val"]["status"]
        if val_enabled:
            val_cfg = dataset_cfg["val"]
            # No Augmenter here: validation images are only normalized/resized.
            val_set = CocoDataset(
                val_cfg["root_dir"],
                val_cfg["coco_dir"],
                val_cfg["img_dir"],
                set_dir=val_cfg["set_dir"],
                transform=transforms.Compose([
                    Normalizer(params["mean"], params["std"]),
                    Resizer(input_size),
                ]))
            val_generator = DataLoader(val_set, **val_params)

        print("")
        print("")
        # NOTE(review): eval() executes arbitrary code taken from the config
        # strings; only use trusted configuration files.
        model = EfficientDetBackbone(
            num_classes=len(params["obj_list"]),
            compound_coef=compound_coef,
            ratios=eval(params["anchors_ratios"]),
            scales=eval(params["anchors_scales"]))

        os.makedirs("pretrained_weights", exist_ok=True)

        # Download the pretrained d0..d7 checkpoint when the requested weights
        # file is missing. A None load_weights means "initialize from scratch";
        # it must be skipped here because os.path.isfile(None) raises TypeError.
        weights_file = params["load_weights"]
        if (compound_coef in range(8) and weights_file is not None
                and not os.path.isfile(weights_file)):
            print("Downloading weights")
            url = ("https://github.com/zylo117/Yet-Another-Efficient-Pytorch/"
                   "releases/download/1.0/efficientdet-d{}.pth".format(compound_coef))
            # Argument list (no shell) so paths with spaces/metacharacters are safe.
            try:
                download_failed = subprocess.run(
                    ["wget", url, "-O", weights_file], check=False).returncode != 0
            except OSError:  # e.g. wget is not installed
                download_failed = True
            if download_failed:
                print("[Warning] Could not download " + url)

        # load last weights
        last_step = 0
        if weights_file is not None:
            if weights_file.endswith('.pth'):
                weights_path = weights_file
            else:
                weights_path = get_last_weights(params["saved_path"])
            try:
                # Checkpoints named "..._<step>.pth" resume from <step>.
                last_step = int(
                    os.path.basename(weights_path).split('_')[-1].split('.')[0])
            except (TypeError, ValueError):
                last_step = 0

            try:
                model.load_state_dict(torch.load(weights_path), strict=False)
            except RuntimeError as e:
                print(f'[Warning] Ignoring {e}')
                print(
                    "[Warning] Don't panic if you see this, this might be because you load a pretrained weights with different number of classes. The rest of the weights should be loaded already."
                )

            print(
                f'[Info] loaded weights: {os.path.basename(weights_path)}, resuming checkpoint from step: {last_step}'
            )
        else:
            print('[Info] initializing weights...')
            init_weights(model)

        print("")
        print("")

        # freeze backbone if train head_only
        if params["head_only"]:

            def freeze_backbone(m):
                classname = m.__class__.__name__
                for ntl in ['EfficientNet', 'BiFPN']:
                    if ntl in classname:
                        for param in m.parameters():
                            param.requires_grad = False

            model.apply(freeze_backbone)
            print('[Info] freezed backbone')

        print("")
        print("")

        # Sync BN only when several GPUs each receive fewer than 4 images.
        use_sync_bn = (params["num_gpus"] > 1 and
                       params["batch_size"] // params["num_gpus"] < 4)
        if use_sync_bn:
            model.apply(replace_w_sync_bn)

        writer = SummaryWriter(
            params["log_path"] +
            f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

        model = ModelWithLoss(model, debug=params["debug"])

        if params["num_gpus"] > 0:
            model = model.cuda()
            if params["num_gpus"] > 1:
                model = CustomDataParallel(model, params["num_gpus"])
                if use_sync_bn:
                    patch_replication_callback(model)

        if params["optim"] == 'adamw':
            optimizer = torch.optim.AdamW(model.parameters(), params["lr"])
        else:
            optimizer = torch.optim.SGD(model.parameters(),
                                        params["lr"],
                                        momentum=0.9,
                                        nesterov=True)

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               patience=3,
                                                               verbose=True)

        checkpoint_name = f'efficientdet-d{compound_coef}_trained.pth'
        epoch = 0
        best_loss = 1e5
        best_epoch = 0
        step = max(0, last_step)
        model.train()

        num_iter_per_epoch = len(training_generator)

        try:
            for epoch in range(params["num_epochs"]):
                # When resuming, skip whole epochs already covered by `step`.
                last_epoch = step // num_iter_per_epoch
                if epoch < last_epoch:
                    continue

                epoch_loss = []
                progress_bar = tqdm(training_generator)
                for iter_num, data in enumerate(progress_bar):
                    # ...and the already-trained iterations of the resumed epoch.
                    if iter_num < step - last_epoch * num_iter_per_epoch:
                        progress_bar.update()
                        continue
                    try:
                        imgs = data['img']
                        annot = data['annot']

                        if params["num_gpus"] == 1:
                            # if only one gpu, just send it to cuda:0
                            # elif multiple gpus, send it to multiple gpus in CustomDataParallel, not here
                            imgs = imgs.cuda()
                            annot = annot.cuda()

                        optimizer.zero_grad()
                        cls_loss, reg_loss = model(imgs,
                                                   annot,
                                                   obj_list=params["obj_list"])
                        cls_loss = cls_loss.mean()
                        reg_loss = reg_loss.mean()

                        loss = cls_loss + reg_loss
                        # Do not step on a zero or non-finite loss.
                        if loss == 0 or not torch.isfinite(loss):
                            continue

                        loss.backward()
                        optimizer.step()

                        epoch_loss.append(float(loss))

                        progress_bar.set_description(
                            'Step: {}. Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. Total loss: {:.5f}'
                            .format(step, epoch, params["num_epochs"],
                                    iter_num + 1, num_iter_per_epoch,
                                    cls_loss.item(), reg_loss.item(),
                                    loss.item()))
                        writer.add_scalars('Loss', {'train': loss}, step)
                        writer.add_scalars('Regression_loss',
                                           {'train': reg_loss}, step)
                        writer.add_scalars('Classfication_loss',
                                           {'train': cls_loss}, step)

                        # log learning_rate
                        current_lr = optimizer.param_groups[0]['lr']
                        writer.add_scalar('learning_rate', current_lr, step)

                        step += 1

                        if step % params["save_interval"] == 0 and step > 0:
                            self.save_checkpoint(model, checkpoint_name)

                    except Exception as e:
                        # Best-effort: report the failing batch and keep training.
                        print('[Error]', traceback.format_exc())
                        print(e)
                        continue

                # np.mean([]) is NaN, which ReduceLROnPlateau would count as a
                # bad epoch; only step when at least one batch was trained.
                if epoch_loss:
                    scheduler.step(np.mean(epoch_loss))

                if epoch % params["val_interval"] == 0 and val_enabled:
                    print("Running validation")
                    model.eval()
                    loss_regression_ls = []
                    loss_classification_ls = []
                    for data in val_generator:
                        with torch.no_grad():
                            imgs = data['img']
                            annot = data['annot']

                            if params["num_gpus"] == 1:
                                imgs = imgs.cuda()
                                annot = annot.cuda()

                            cls_loss, reg_loss = model(
                                imgs, annot, obj_list=params["obj_list"])
                            cls_loss = cls_loss.mean()
                            reg_loss = reg_loss.mean()

                            loss = cls_loss + reg_loss
                            if loss == 0 or not torch.isfinite(loss):
                                continue

                            loss_classification_ls.append(cls_loss.item())
                            loss_regression_ls.append(reg_loss.item())

                    cls_loss = np.mean(loss_classification_ls)
                    reg_loss = np.mean(loss_regression_ls)
                    loss = cls_loss + reg_loss

                    print(
                        'Val. Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'
                        .format(epoch, params["num_epochs"], cls_loss,
                                reg_loss, loss))
                    writer.add_scalars('Loss', {'val': loss}, step)
                    writer.add_scalars('Regression_loss', {'val': reg_loss},
                                       step)
                    writer.add_scalars('Classfication_loss', {'val': cls_loss},
                                       step)

                    if loss + params["es_min_delta"] < best_loss:
                        best_loss = loss
                        best_epoch = epoch
                        self.save_checkpoint(model, checkpoint_name)

                    model.train()

                    # Early stopping
                    if epoch - best_epoch > params["es_patience"] > 0:
                        print(
                            '[Info] Stop training at epoch {}. The lowest loss achieved is {}'
                            .format(epoch, best_loss))
                        break
        except KeyboardInterrupt:
            self.save_checkpoint(model, checkpoint_name)
        finally:
            writer.close()

        print("")
        print("")
        print("Training complete")
def train(opt):
    """Train an EfficientDet model described by ``opt`` and the YAML at ``opt.config``.

    ``opt`` supplies the run options read here: config, compound_coef,
    batch_size, num_workers, lr, optim, num_epochs, val_interval,
    save_interval, es_min_delta, es_patience, load_weights, head_only and
    debug. ``opt.saved_path`` and ``opt.log_path`` are overwritten from
    ``params.logdir``. If ``params.val_image_dir`` is None it is set to
    ``params.image_dir``.

    Checkpoints named ``efficientdet-d<coef>_<epoch>_<step>.pth`` are written via
    ``save_checkpoint`` every ``opt.save_interval`` steps, on a new best
    validation loss, and on Ctrl-C. Training stops early when the best
    validation epoch is more than ``opt.es_patience`` epochs old (only when
    ``es_patience > 0``).
    """
    params = Params(opt.config)

    if params.num_gpus == 0:
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    else:
        torch.manual_seed(42)

    opt.saved_path = params.logdir
    opt.log_path = os.path.join(params.logdir, "tensorboard")
    os.makedirs(opt.saved_path, exist_ok=True)
    os.makedirs(opt.log_path, exist_ok=True)

    training_params = {
        "batch_size": opt.batch_size,
        "shuffle": True,
        "drop_last": True,
        "collate_fn": collater,
        "num_workers": opt.num_workers,
    }

    val_params = {
        "batch_size": opt.batch_size,
        "shuffle": False,
        "drop_last": True,
        "collate_fn": collater,
        "num_workers": opt.num_workers,
    }

    # Input resolution per compound coefficient (d0 .. d8).
    input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536]
    training_set = CocoDataset(
        image_dir=params.image_dir,
        json_path=params.train_annotations,
        transform=transforms.Compose(
            [
                Normalizer(mean=params.mean, std=params.std),
                Augmenter(),
                Resizer(input_sizes[opt.compound_coef]),
            ]
        ),
    )
    training_generator = DataLoader(training_set, **training_params)

    if params.val_image_dir is None:
        params.val_image_dir = params.image_dir

    # No Augmenter here: validation images are only normalized/resized.
    val_set = CocoDataset(
        image_dir=params.val_image_dir,
        json_path=params.val_annotations,
        transform=transforms.Compose(
            [Normalizer(mean=params.mean, std=params.std), Resizer(input_sizes[opt.compound_coef])]
        ),
    )
    val_generator = DataLoader(val_set, **val_params)

    # NOTE(review): eval() executes arbitrary code taken from the config
    # strings; only use trusted configuration files.
    model = EfficientDetBackbone(
        num_classes=len(params.obj_list),
        compound_coef=opt.compound_coef,
        ratios=eval(params.anchors_ratios),
        scales=eval(params.anchors_scales),
    )

    # load last weights
    if opt.load_weights is not None:
        if opt.load_weights.endswith(".pth"):
            weights_path = opt.load_weights
        else:
            weights_path = get_last_weights(opt.saved_path)
        try:
            # Checkpoints named "..._<step>.pth" resume from <step>.
            last_step = int(os.path.basename(weights_path).split("_")[-1].split(".")[0])
        except (TypeError, ValueError):
            last_step = 0

        try:
            model.load_state_dict(torch.load(weights_path), strict=False)
        except RuntimeError as e:
            print(f"[Warning] Ignoring {e}")
            print(
                "[Warning] Don't panic if you see this, this might be because you load a pretrained weights with different number of classes. The rest of the weights should be loaded already."
            )

        print(
            f"[Info] loaded weights: {os.path.basename(weights_path)}, resuming checkpoint from step: {last_step}"
        )
    else:
        last_step = 0
        print("[Info] initializing weights...")
        init_weights(model)

    # freeze backbone if train head_only
    if opt.head_only:

        def freeze_backbone(m):
            classname = m.__class__.__name__
            for ntl in ["EfficientNet", "BiFPN"]:
                if ntl in classname:
                    for param in m.parameters():
                        param.requires_grad = False

        model.apply(freeze_backbone)
        print("[Info] freezed backbone")

    # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
    # Apply sync BN when using multiple GPUs with fewer than 4 images per GPU:
    # per-GPU batch statistics are too noisy at that size, so all mini-batches
    # are normalized together, at a small speed cost.
    use_sync_bn = params.num_gpus > 1 and opt.batch_size // params.num_gpus < 4
    if use_sync_bn:
        model.apply(replace_w_sync_bn)

    writer = SummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

    # wrap the model with the loss function, to reduce the memory usage on gpu0 and speed up
    model = ModelWithLoss(model, debug=opt.debug)

    if params.num_gpus > 0:
        model = model.cuda()
        if params.num_gpus > 1:
            model = CustomDataParallel(model, params.num_gpus)
            if use_sync_bn:
                patch_replication_callback(model)

    if opt.optim == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), opt.lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(), opt.lr, momentum=0.9, nesterov=True)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)

    epoch = 0
    best_loss = 1e5
    best_epoch = 0
    step = max(0, last_step)
    model.train()

    num_iter_per_epoch = len(training_generator)

    try:
        for epoch in range(opt.num_epochs):
            # When resuming, skip whole epochs already covered by `step`.
            last_epoch = step // num_iter_per_epoch
            if epoch < last_epoch:
                continue

            epoch_loss = []
            progress_bar = tqdm(training_generator)
            for iter_num, data in enumerate(progress_bar):
                # ...and the already-trained iterations of the resumed epoch.
                if iter_num < step - last_epoch * num_iter_per_epoch:
                    progress_bar.update()
                    continue
                try:
                    imgs = data["img"]
                    annot = data["annot"]

                    if params.num_gpus == 1:
                        # if only one gpu, just send it to cuda:0
                        # elif multiple gpus, send it to multiple gpus in CustomDataParallel, not here
                        imgs = imgs.cuda()
                        annot = annot.cuda()

                    optimizer.zero_grad()
                    cls_loss, reg_loss = model(imgs, annot, obj_list=params.obj_list)
                    cls_loss = cls_loss.mean()
                    reg_loss = reg_loss.mean()

                    loss = cls_loss + reg_loss
                    # Do not step on a zero or non-finite loss.
                    if loss == 0 or not torch.isfinite(loss):
                        continue

                    loss.backward()
                    optimizer.step()

                    epoch_loss.append(float(loss))

                    progress_bar.set_description(
                        "Step: {}. Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. Total loss: {:.5f}".format(
                            step,
                            epoch,
                            opt.num_epochs,
                            iter_num + 1,
                            num_iter_per_epoch,
                            cls_loss.item(),
                            reg_loss.item(),
                            loss.item(),
                        )
                    )
                    writer.add_scalars("Loss", {"train": loss}, step)
                    writer.add_scalars("Regression_loss", {"train": reg_loss}, step)
                    writer.add_scalars("Classfication_loss", {"train": cls_loss}, step)

                    # log learning_rate
                    current_lr = optimizer.param_groups[0]["lr"]
                    writer.add_scalar("learning_rate", current_lr, step)

                    step += 1

                    if step % opt.save_interval == 0 and step > 0:
                        save_checkpoint(
                            model, f"efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth"
                        )
                        print("checkpoint...")

                except Exception as e:
                    # Best-effort: report the failing batch and keep training.
                    print("[Error]", traceback.format_exc())
                    print(e)
                    continue

            # np.mean([]) is NaN, which ReduceLROnPlateau would count as a bad
            # epoch; only step when at least one batch was trained.
            if epoch_loss:
                scheduler.step(np.mean(epoch_loss))

            if epoch % opt.val_interval == 0:
                model.eval()
                loss_regression_ls = []
                loss_classification_ls = []
                for data in val_generator:
                    with torch.no_grad():
                        imgs = data["img"]
                        annot = data["annot"]

                        if params.num_gpus == 1:
                            imgs = imgs.cuda()
                            annot = annot.cuda()

                        cls_loss, reg_loss = model(imgs, annot, obj_list=params.obj_list)
                        cls_loss = cls_loss.mean()
                        reg_loss = reg_loss.mean()

                        loss = cls_loss + reg_loss
                        if loss == 0 or not torch.isfinite(loss):
                            continue

                        loss_classification_ls.append(cls_loss.item())
                        loss_regression_ls.append(reg_loss.item())

                cls_loss = np.mean(loss_classification_ls)
                reg_loss = np.mean(loss_regression_ls)
                loss = cls_loss + reg_loss

                print(
                    "Val. Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}".format(
                        epoch, opt.num_epochs, cls_loss, reg_loss, loss
                    )
                )
                writer.add_scalars("Loss", {"val": loss}, step)
                writer.add_scalars("Regression_loss", {"val": reg_loss}, step)
                writer.add_scalars("Classfication_loss", {"val": cls_loss}, step)

                if loss + opt.es_min_delta < best_loss:
                    best_loss = loss
                    best_epoch = epoch

                    save_checkpoint(model, f"efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth")

                model.train()

                # Early stopping
                if epoch - best_epoch > opt.es_patience > 0:
                    print(
                        "[Info] Stop training at epoch {}. The lowest loss achieved is {}".format(
                            epoch, best_loss
                        )
                    )
                    break
    except KeyboardInterrupt:
        save_checkpoint(model, f"efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth")
    finally:
        writer.close()
Пример #7
0
def train(opt):
    """Train EfficientDet on the wheat dataset described by a project YAML.

    Reads ``projects/{opt.project}.yml``, splits ``{data_dir}/train.csv`` into
    training and validation frames with ``get_train_val``, optionally resumes
    from a checkpoint, then runs the train / validate / checkpoint /
    early-stopping loop. Losses and the learning rate are logged to
    TensorBoard under ``opt.log_path``; on Ctrl-C a final checkpoint is saved
    and the function returns normally.

    Args:
        opt: parsed command-line options. Fields read here: ``project``,
            ``saved_path``, ``log_path``, ``batch_size``, ``num_workers``,
            ``compound_coef``, ``load_weights``, ``head_only``, ``debug``,
            ``optim``, ``lr``, ``num_epochs``, ``save_interval``,
            ``val_interval``, ``es_min_delta`` and ``es_patience``.
            ``saved_path`` and ``log_path`` are rewritten in place to
            project-specific sub-directories.
    """
    params = Params(f'projects/{opt.project}.yml')

    if params.num_gpus == 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    else:
        torch.manual_seed(42)

    opt.saved_path = opt.saved_path + f'/{params.project_name}/'
    opt.log_path = opt.log_path + f'/{params.project_name}/tensorboard/'
    os.makedirs(opt.log_path, exist_ok=True)
    os.makedirs(opt.saved_path, exist_ok=True)

    training_params = {
        'batch_size': opt.batch_size,
        'shuffle': True,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    val_params = {
        'batch_size': opt.batch_size,
        'shuffle': False,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    # Network input resolution, indexed by opt.compound_coef.
    input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
    train_df = pd.read_csv(os.path.join(params.data_dir, 'train.csv'))
    train_df, val_df = get_train_val(train_df)

    training_set = WheatDataset(dataframe=train_df,
                                image_dir=os.path.join(params.data_dir,
                                                       params.train_set),
                                transforms=transforms.Compose([
                                    Normalizer(mean=params.mean,
                                               std=params.std),
                                    Augmenter(),
                                    Resizer(input_sizes[opt.compound_coef])
                                ]))
    training_generator = DataLoader(training_set, **training_params)

    # NOTE(review): validation also runs Augmenter(); if that transform is
    # random, the validation loss that drives checkpointing and early
    # stopping varies between epochs -- confirm this is intended.
    val_set = WheatDataset(dataframe=val_df,
                           image_dir=os.path.join(params.data_dir,
                                                  params.train_set),
                           transforms=transforms.Compose([
                               Normalizer(mean=params.mean, std=params.std),
                               Augmenter(),
                               Resizer(input_sizes[opt.compound_coef])
                           ]))
    val_generator = DataLoader(val_set, **val_params)

    # NOTE: anchors_ratios / anchors_scales are Python literals stored as
    # strings in the project YAML; eval() trusts that file.
    model = EfficientDetBackbone(num_classes=len(params.obj_list),
                                 compound_coef=opt.compound_coef,
                                 ratios=eval(params.anchors_ratios),
                                 scales=eval(params.anchors_scales))

    # Resume from an explicit .pth file or from the latest checkpoint in
    # saved_path; otherwise start from freshly initialised weights.
    if opt.load_weights is not None:
        if opt.load_weights.endswith('.pth'):
            weights_path = opt.load_weights
        else:
            weights_path = get_last_weights(opt.saved_path)
        # Checkpoints are saved as ..._{epoch}_{step}.pth; recover the step.
        try:
            last_step = int(
                os.path.basename(weights_path).split('_')[-1].split('.')[0])
        except ValueError:
            last_step = 0

        try:
            # The model is still on the CPU here (it is moved to CUDA below),
            # so map tensors to the CPU; this also lets GPU-saved checkpoints
            # load on machines without CUDA.
            model.load_state_dict(torch.load(weights_path,
                                             map_location='cpu'),
                                  strict=False)
        except RuntimeError as e:
            print(f'[Warning] Ignoring {e}')
            print(
                '[Warning] Don\'t panic if you see this, this might be because you load a pretrained weights with different number of classes. The rest of the weights should be loaded already.'
            )

        print(
            f'[Info] loaded weights: {os.path.basename(weights_path)}, resuming checkpoint from step: {last_step}'
        )
    else:
        last_step = 0
        print('[Info] initializing weights...')
        init_weights(model)

    # freeze backbone if train head_only
    if opt.head_only:

        def freeze_backbone(m):
            # Stop gradients for every EfficientNet / BiFPN sub-module.
            classname = m.__class__.__name__
            for ntl in ['EfficientNet', 'BiFPN']:
                if ntl in classname:
                    for param in m.parameters():
                        param.requires_grad = False

        model.apply(freeze_backbone)
        print('[Info] freezed backbone')

    # Small per-GPU batches make plain BatchNorm statistics unreliable, so
    # switch to synchronised BN across GPUs in that case.
    if params.num_gpus > 1 and opt.batch_size // params.num_gpus < 4:
        model.apply(replace_w_sync_bn)
        use_sync_bn = True
    else:
        use_sync_bn = False

    writer = SummaryWriter(
        opt.log_path +
        f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

    # warp the model with loss function, to reduce the memory usage on gpu0 and speedup
    model = ModelWithLoss(model, debug=opt.debug)

    if params.num_gpus > 0:
        model = model.cuda()
        if params.num_gpus > 1:
            model = CustomDataParallel(model, params.num_gpus)
            if use_sync_bn:
                patch_replication_callback(model)

    if opt.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), opt.lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    opt.lr,
                                    momentum=0.9,
                                    nesterov=True)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    epoch = 0
    best_loss = 1e5
    best_epoch = 0
    step = max(0, last_step)
    model.train()

    num_iter_per_epoch = len(training_generator)

    try:
        for epoch in range(opt.num_epochs):
            # When resuming, skip the epochs the checkpoint already covers...
            last_epoch = step // num_iter_per_epoch
            if epoch < last_epoch:
                continue

            epoch_loss = []
            progress_bar = tqdm(training_generator)
            for batch_idx, data in enumerate(progress_bar):
                # ...and the already-trained batches of the resumed epoch.
                if batch_idx < step - last_epoch * num_iter_per_epoch:
                    progress_bar.update()
                    continue
                try:
                    imgs = data['image']
                    annot = data['bboxes']

                    if params.num_gpus == 1:
                        # if only one gpu, just send it to cuda:0
                        # elif multiple gpus, send it to multiple gpus in CustomDataParallel, not here
                        imgs = imgs.cuda()
                        annot = annot.cuda()

                    optimizer.zero_grad()
                    cls_loss, reg_loss = model(imgs,
                                               annot,
                                               obj_list=params.obj_list)
                    cls_loss = cls_loss.mean()
                    reg_loss = reg_loss.mean()

                    loss = cls_loss + reg_loss
                    if loss == 0 or not torch.isfinite(loss):
                        continue

                    loss.backward()
                    optimizer.step()

                    epoch_loss.append(float(loss))

                    progress_bar.set_description(
                        'Step: {}. Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. Total loss: {:.5f}'
                        .format(step, epoch, opt.num_epochs, batch_idx + 1,
                                num_iter_per_epoch, cls_loss.item(),
                                reg_loss.item(), loss.item()))
                    writer.add_scalars('Loss', {'train': loss}, step)
                    writer.add_scalars('Regression_loss', {'train': reg_loss},
                                       step)
                    writer.add_scalars('Classfication_loss',
                                       {'train': cls_loss}, step)

                    # log learning_rate
                    current_lr = optimizer.param_groups[0]['lr']
                    writer.add_scalar('learning_rate', current_lr, step)

                    step += 1

                    if step % opt.save_interval == 0 and step > 0:
                        save_checkpoint(
                            model,
                            f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth'
                        )
                        print('checkpoint...')

                except Exception as e:
                    # Best-effort: log the failing batch and keep training.
                    print('[Error]', traceback.format_exc())
                    print(e)
                    continue
            # ReduceLROnPlateau needs a real metric; np.mean([]) would be NaN
            # when every batch of the epoch was skipped or failed.
            if epoch_loss:
                scheduler.step(np.mean(epoch_loss))

            if epoch % opt.val_interval == 0:
                model.eval()
                loss_regression_ls = []
                loss_classification_ls = []
                for data in val_generator:
                    with torch.no_grad():
                        imgs = data['image']
                        annot = data['bboxes']

                        if params.num_gpus == 1:
                            imgs = imgs.cuda()
                            annot = annot.cuda()

                        cls_loss, reg_loss = model(imgs,
                                                   annot,
                                                   obj_list=params.obj_list)
                        cls_loss = cls_loss.mean()
                        reg_loss = reg_loss.mean()

                        loss = cls_loss + reg_loss
                        if loss == 0 or not torch.isfinite(loss):
                            continue

                        loss_classification_ls.append(cls_loss.item())
                        loss_regression_ls.append(reg_loss.item())

                cls_loss = np.mean(loss_classification_ls)
                reg_loss = np.mean(loss_regression_ls)
                loss = cls_loss + reg_loss

                print(
                    'Val. Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'
                    .format(epoch, opt.num_epochs, cls_loss, reg_loss, loss))
                writer.add_scalars('Loss', {'val': loss}, step)
                writer.add_scalars('Regression_loss', {'val': reg_loss}, step)
                writer.add_scalars('Classfication_loss', {'val': cls_loss},
                                   step)

                # Keep a checkpoint only for a meaningful improvement.
                if loss + opt.es_min_delta < best_loss:
                    best_loss = loss
                    best_epoch = epoch

                    save_checkpoint(
                        model,
                        f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth'
                    )

                model.train()

                # Early stopping (disabled when es_patience <= 0)
                if epoch - best_epoch > opt.es_patience > 0:
                    print(
                        '[Info] Stop training at epoch {}. The lowest loss achieved is {}'
                        .format(epoch, best_loss))
                    break

    except KeyboardInterrupt:
        save_checkpoint(
            model, f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth')
    finally:
        # Close exactly once, on normal exit, Ctrl-C and unexpected errors.
        writer.close()
Пример #8
0
def main(args):
    """Evaluate the eye-dataset checkpoint on the test split and log the result.

    Restores ``{args.weight_path}/pre_trained_weight.pth`` into a freshly
    built model, runs it over every image in ``{args.dataset_path}/test``
    (annotations from ``{args.dataset_path}/test.txt``) and writes the mean
    loss and classification accuracy to ``{args.weight_path}/test_log.out``,
    replacing any previous log. NaN is written when nothing was evaluated.

    Raises:
        AssertionError: if ``args.weight_path`` is empty.
    """
    print("Hi")
    # Validate the argument before it is used to build any file path.
    assert args.weight_path, 'must indicate the path of pre-trained weight'
    log_path = f"{args.weight_path}/test_log.out"
    if os.path.exists(log_path):
        os.remove(log_path)
    params = Params(f'projects/eye.yml')

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # drop_last=False: every test image must be scored, including those in
    # the final partial batch.
    test_params = {
        'batch_size': args.batch_size,
        'shuffle': False,
        'drop_last': False,
        'collate_fn': collater,
        'num_workers': args.num_workers
    }

    input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]

    model = EfficientDetBackbone(num_classes=len(params.obj_list),
                                 compound_coef=args.compound_coef,
                                 ratios=eval(params.anchors_ratios),
                                 scales=eval(params.anchors_scales))
    init_weights(model)
    model = ModelWithLoss(model)
    model = model.cuda()

    # NOTE(review): the optimizer and scheduler are restored from the
    # checkpoint but never used for evaluation below.
    if args.optim == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), args.lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=0.9,
                                    nesterov=True)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=5,
                                                           verbose=True)

    # Read the checkpoint once instead of once per state dict.
    checkpoint = torch.load(f'{args.weight_path}/pre_trained_weight.pth')
    model.model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    test_img_list = glob.glob(f'{args.dataset_path}/test/*')
    test_anno_txt_path = f'{args.dataset_path}/test.txt'

    # NOTE(review): Augmenter() also runs at test time; if it is random the
    # reported metrics vary between runs -- confirm this is intended.
    test_transform = transforms.Compose([
        Augmenter(),
        Normalizer(mean=params.mean, std=params.std),
        Resizer(input_sizes[args.compound_coef])
    ])

    test_set = EyeDataset(test_img_list, test_anno_txt_path, test_transform)
    test_generator = DataLoader(test_set, **test_params)

    model.eval()
    with torch.no_grad():
        total = 0
        total_correct = 0
        total_loss_ls = []
        for data in test_generator:
            imgs = data['img'].cuda()
            annot = data['annot'].cuda()

            reg_loss, cls_head_loss, cls_correct_num, total_num = model(
                imgs, annot, obj_list=params.obj_list)

            total_correct += cls_correct_num
            total += total_num
            reg_loss = reg_loss.mean()
            loss = reg_loss + cls_head_loss
            total_loss_ls.append(loss.item())
        # Report NaN instead of crashing / warning when nothing was scored.
        total_loss = np.mean(total_loss_ls) if total_loss_ls else float('nan')
        acc = total_correct / total * 100 if total else float('nan')
        with open(log_path, 'a') as fp:
            fp.write(f'Testing loss: {total_loss:.6f} | acc: {acc:.2f}\n')
Пример #9
0
def train(args):
    """Train the eye-dataset model on a grouped 80/20 train/validation split.

    Images under ``{args.dataset_path}/train`` are grouped into "normal"
    (path contains ``"n_"``) and "yellow" (everything else); the first fifth
    of each shuffled group is held out for validation. Model, optimizer and
    scheduler start from ``{args.weight_path}/init_weight.pth``. After every
    epoch the training and validation loss / accuracy are appended to
    ``{args.weight_path}/train_log.txt`` and, whenever the validation loss
    improves, the full training state is written to
    ``{args.weight_path}/pre_trained_weight.pth``. Both output files are
    deleted first if they already exist.

    Raises:
        AssertionError: if ``args.weight_path`` is empty.
    """
    assert args.weight_path, 'must indicate the path of initial weight'
    log_path = f'{args.weight_path}/train_log.txt'
    best_weight_path = f'{args.weight_path}/pre_trained_weight.pth'
    if os.path.exists(log_path):
        os.remove(log_path)
    if os.path.exists(best_weight_path):
        os.remove(best_weight_path)
    print("Hi")
    params = Params(f'projects/eye.yml')

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    train_params = {'batch_size': args.batch_size,
                    'shuffle': True,
                    'drop_last': True,
                    'collate_fn': collater,
                    'num_workers': args.num_workers}

    # Validation must score every held-out image, so keep the last partial
    # batch.
    val_params = {'batch_size': args.batch_size,
                  'shuffle': False,
                  'drop_last': False,
                  'collate_fn': collater,
                  'num_workers': args.num_workers}

    input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]

    model = EfficientDetBackbone(num_classes=len(params.obj_list), compound_coef=args.compound_coef,
                                 ratios=eval(params.anchors_ratios), scales=eval(params.anchors_scales))
    init_weights(model)

    # warp the model with loss function, to reduce the memory usage on gpu0 and speedup
    model = ModelWithLoss(model)
    model = model.cuda()

    if args.optim == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), args.lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=0.9, nesterov=True)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=args.patience, verbose=True)  # unit is epoch

    # Split each group separately so both are represented in validation.
    img_list = glob.glob(f"{args.dataset_path}/train/*")
    normal_img_list = []
    yellow_img_list = []
    for img in img_list:
        if 'n_' in img:
            normal_img_list.append(img)
        else:
            yellow_img_list.append(img)
    random.shuffle(normal_img_list)
    random.shuffle(yellow_img_list)
    normal_val_num = int(len(normal_img_list) / 5)
    yellow_val_num = int(len(yellow_img_list) / 5)
    train_img_list = normal_img_list[normal_val_num:] + yellow_img_list[yellow_val_num:]
    val_img_list = normal_img_list[:normal_val_num] + yellow_img_list[:yellow_val_num]
    # Validation images come from the training folder, so they share its
    # annotation file.
    train_anno_txt_path = f"{args.dataset_path}/train.txt"
    val_anno_txt_path = f"{args.dataset_path}/train.txt"

    train_transform = transforms.Compose([Augmenter(),
                                          randomScaleWidth(),
                                          randomBlur(),
                                          Normalizer(mean=params.mean, std=params.std),
                                          Resizer(input_sizes[args.compound_coef])])

    val_transform = transforms.Compose([Augmenter(),
                                        Normalizer(mean=params.mean, std=params.std),
                                        Resizer(input_sizes[args.compound_coef])])

    train_set = EyeDataset(train_img_list, train_anno_txt_path, train_transform)
    val_set = EyeDataset(val_img_list, val_anno_txt_path, val_transform)

    train_generator = DataLoader(train_set, **train_params)
    val_generator = DataLoader(val_set, **val_params)

    # Read the initial checkpoint once instead of once per state dict.
    init_checkpoint = torch.load(f'{args.weight_path}/init_weight.pth')
    model.model.load_state_dict(init_checkpoint["model_state_dict"])
    optimizer.load_state_dict(init_checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(init_checkpoint["scheduler_state_dict"])

    best_val_loss = 1e5
    for epoch in range(args.epoch):
        model.train()
        total_loss_ls = []
        total_correct = 0
        total = 0
        for data in train_generator:
            imgs = data['img'].cuda()
            annot = data['annot'].cuda()

            optimizer.zero_grad()
            reg_loss, cls_head_loss, cls_correct_num, total_num = model(imgs, annot, obj_list=params.obj_list)
            total_correct += cls_correct_num
            total += total_num
            reg_loss = reg_loss.mean()
            loss = cls_head_loss + reg_loss

            # Keep NaN/inf out of the epoch mean: one non-finite batch would
            # otherwise turn the scheduler metric and the log into NaN.
            if not torch.isfinite(loss):
                continue
            total_loss_ls.append(loss.item())
            if loss == 0:
                continue

            loss.backward()
            optimizer.step()
        total_loss = np.mean(total_loss_ls) if total_loss_ls else float('nan')
        if total_loss_ls:
            scheduler.step(total_loss)
        train_acc = total_correct / total * 100 if total else float('nan')

        with open(log_path, 'a') as fp:
            fp.write(f'Epoch: {epoch}  loss: {total_loss:.6f} | acc: {train_acc:.2f}\n')

        model.eval()
        with torch.no_grad():
            total = 0
            total_correct = 0
            total_loss_ls = []
            for data in val_generator:
                imgs = data['img'].cuda()
                annot = data['annot'].cuda()

                reg_loss, cls_head_loss, cls_correct_num, total_num = model(imgs, annot, obj_list=params.obj_list)
                total += total_num
                total_correct += cls_correct_num
                reg_loss = reg_loss.mean()
                loss = cls_head_loss + reg_loss
                # A NaN mean would never compare below best_val_loss, so a
                # single bad batch must not block checkpointing.
                if torch.isfinite(loss):
                    total_loss_ls.append(loss.item())

            total_loss = np.mean(total_loss_ls) if total_loss_ls else float('nan')
            val_acc = total_correct / total * 100 if total else float('nan')
            with open(log_path, 'a') as fp:
                fp.write(f'Epoch: {epoch}  loss: {total_loss:.6f} | acc: {val_acc:.2f}\n\n')
            if total_loss < best_val_loss:
                best_val_loss = total_loss
                torch.save({
                    "model_state_dict": model.model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    }, best_weight_path)
def kfold_split(items, fold, k):
    """Return ``(train_part, held_out_part)`` of *items* for fold *fold* of *k*.

    The held-out part is ``items[fold * n // k:(fold + 1) * n // k]`` with
    ``n = len(items)``; the training part is every other item, in the
    original order. Across ``fold = 0 .. k - 1`` the held-out parts are
    disjoint, cover every item exactly once and differ in size by at most
    one. A fold is therefore empty only when there are fewer items than
    folds, unlike ceil-sized chunks, which can leave several trailing folds
    empty.

    >>> kfold_split([0, 1, 2, 3, 4], 1, 2)
    ([0, 1], [2, 3, 4])
    """
    n = len(items)
    start = fold * n // k
    stop = (fold + 1) * n // k
    return items[:start] + items[stop:], items[start:stop]


def train(args):
    """Run grouped k-fold (k=10) cross-validation on the eye dataset.

    Images under ``{args.dataset_path}/train`` whose path contains ``"n_"``
    form the "normal" group and all others the "yellow" group. Each group is
    shuffled and split into k folds separately, so every held-out fold draws
    from both groups. For every fold the model, optimizer and scheduler are
    reset to the state saved in ``{args.saved_path}/init_weight.pth``, then
    trained for ``args.epoch`` epochs and evaluated after each epoch.
    Per-epoch results and the final-epoch averages over all folds are
    appended to ``./logs/{present_time}/cv_log.txt``.
    """
    print("Hi")
    present_time = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
    params = Params(f'projects/eye.yml')

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    # Seed every RNG in play. torch.manual_seed covers the CPU generator,
    # which DataLoader shuffling draws from (and presumably init_weights on
    # the still-CPU model); the CUDA seeds alone leave both unseeded.
    torch.manual_seed(20)
    torch.cuda.manual_seed(20)
    torch.cuda.manual_seed_all(20)
    np.random.seed(20)
    random.seed(20)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    prepare_dir(args, present_time)

    training_params = {
        'batch_size': args.batch_size,
        'shuffle': True,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': args.num_workers
    }

    # Evaluation must score every held-out image, so keep the final partial
    # batch (with drop_last=True a small fold could yield no batches at all).
    val_params = {
        'batch_size': args.batch_size,
        'shuffle': False,
        'drop_last': False,
        'collate_fn': collater,
        'num_workers': args.num_workers
    }

    input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]

    model = EfficientDetBackbone(num_classes=len(params.obj_list),
                                 compound_coef=args.compound_coef,
                                 ratios=eval(params.anchors_ratios),
                                 scales=eval(params.anchors_scales))
    init_weights(model)

    # warp the model with loss function, to reduce the memory usage on gpu0 and speedup
    model = ModelWithLoss(model)
    model = model.cuda()

    if args.optim == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), args.lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=0.9,
                                    nesterov=True)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=args.patience, verbose=True)  # unit is epoch

    # Snapshot the untrained state to disk so every fold restarts from it.
    init_weight_path = f"{args.saved_path}/init_weight.pth"
    torch.save(
        {
            "model_state_dict": model.model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
        }, init_weight_path)

    k = 10
    train_img_list = glob.glob(f"{args.dataset_path}/train/*")
    normal_img_list = []
    yellow_img_list = []
    for img in train_img_list:
        if 'n_' in img:
            normal_img_list.append(img)
        else:
            yellow_img_list.append(img)
    random.shuffle(normal_img_list)
    random.shuffle(yellow_img_list)

    # Held-out images also come from the training folder and its annotations.
    train_anno_txt_path = f"{args.dataset_path}/train.txt"
    test_anno_txt_path = f"{args.dataset_path}/train.txt"
    log_file = f'./logs/{present_time}/cv_log.txt'

    last_acc = []
    last_loss = []
    for i in range(k):
        # Fresh objects per fold, read once instead of once per state dict.
        init_checkpoint = torch.load(init_weight_path)
        model.model.load_state_dict(init_checkpoint["model_state_dict"])
        optimizer.load_state_dict(init_checkpoint["optimizer_state_dict"])
        scheduler.load_state_dict(init_checkpoint["scheduler_state_dict"])
        model.train()

        normal_train, normal_test = kfold_split(normal_img_list, i, k)
        yellow_train, yellow_test = kfold_split(yellow_img_list, i, k)
        sub_train_img_list = normal_train + yellow_train
        sub_test_img_list = normal_test + yellow_test
        random.shuffle(sub_train_img_list)
        random.shuffle(sub_test_img_list)
        print("---")
        for img in sub_test_img_list:
            print(img)
        print("---")

        train_transform = transforms.Compose([
            Augmenter(),
            randomScaleWidth(),
            randomBlur(),
            randomBrightness(),
            randomHue(),
            randomSaturation(),
            Normalizer(mean=params.mean, std=params.std),
            Resizer(input_sizes[args.compound_coef])
        ])
        test_transform = transforms.Compose([
            Augmenter(),
            Normalizer(mean=params.mean, std=params.std),
            Resizer(input_sizes[args.compound_coef])
        ])

        train_set = EyeDataset(sub_train_img_list, train_anno_txt_path,
                               train_transform)
        test_set = EyeDataset(sub_test_img_list, test_anno_txt_path,
                              test_transform)
        training_generator = DataLoader(train_set, **training_params)
        val_generator = DataLoader(test_set, **val_params)

        for epoch in range(args.epoch):
            model.train()
            total_correct = 0
            total = 0
            total_loss_ls = []
            for data in training_generator:
                imgs = data['img'].cuda()
                annot = data['annot'].cuda()

                optimizer.zero_grad()
                reg_loss, cls_head_loss, cls_correct_num, total_num = model(
                    imgs, annot, obj_list=params.obj_list)
                total_correct += cls_correct_num
                total += total_num
                reg_loss = reg_loss.mean()
                loss = cls_head_loss + reg_loss

                # Keep NaN/inf out of the epoch mean: one non-finite batch
                # would otherwise turn the scheduler metric into NaN.
                if not torch.isfinite(loss):
                    continue
                total_loss_ls.append(loss.item())
                if loss == 0:
                    continue

                loss.backward()
                optimizer.step()
            total_loss = np.mean(total_loss_ls) if total_loss_ls else float('nan')
            if total_loss_ls:
                scheduler.step(total_loss)
            train_acc = total_correct / total * 100 if total else float('nan')
            with open(log_file, 'a') as fp:
                fp.write(f"Epoch: {i}/{epoch}/{args.epoch}\n")
                fp.write(
                    f"Training loss: {total_loss:.6f} | acc: {train_acc:.2f}\n"
                )

            model.eval()
            with torch.no_grad():
                total = 0
                total_correct = 0
                total_loss_ls = []
                for data in val_generator:
                    imgs = data['img'].cuda()
                    annot = data['annot'].cuda()

                    reg_loss, cls_head_loss, cls_correct_num, total_num = model(
                        imgs, annot, obj_list=params.obj_list)
                    total_correct += cls_correct_num
                    total += total_num
                    reg_loss = reg_loss.mean()
                    loss = reg_loss + cls_head_loss
                    if torch.isfinite(loss):
                        total_loss_ls.append(loss.item())
                total_loss = np.mean(total_loss_ls) if total_loss_ls else float('nan')
                test_acc = total_correct / total * 100 if total else float('nan')

                with open(log_file, 'a') as fp:
                    fp.write(
                        f"Testing loss: {total_loss:.6f} | acc: {test_acc:.2f}\n\n"
                    )

                # Only the last epoch of each fold enters the CV average.
                if epoch == args.epoch - 1:
                    last_loss.append(total_loss)
                    last_acc.append(test_acc)

    with open(log_file, 'a') as fp:
        fp.write("\n===========\n\n")
        fp.write(f"Avg. loss: {np.mean(np.array(last_loss)):.2f}\n")
        fp.write(f"Avg. accuracy: {np.mean(np.array(last_acc)):.2f}\n")