示例#1
0
    def __init__(self,
                 main_folder_path,
                 model,
                 train_df,
                 test_df,
                 num_classes,
                 target_dim,
                 device,
                 is_colab,
                 config,
                 args_text=None):
        """Initialize the trainer: optimizer, scheduler, file paths, datasets.

        Args:
            main_folder_path: root folder containing the project data.
            model: model instance to train.
            train_df: dataframe describing the training samples.
            test_df: dataframe describing the test samples.
            num_classes: number of target classes.
            target_dim: target image dimension for the datasets/visualizer.
            device: device used for training (reported in the init log line).
            is_colab: whether we run on Colab (affects output file paths).
            config: configuration providing optimizer/scheduler classes,
                model name and model-file naming.
            args_text: optional textual dump of the parsed CLI arguments.
        """
        self.main_folder_path = main_folder_path
        self.model = model
        self.train_df = train_df
        self.test_df = test_df
        self.device = device
        self.config = config
        self.num_classes = num_classes
        self.target_dim = target_dim
        self.is_colab = is_colab
        self.args_text = args_text
        if "faster" in self.config.model_name:
            # Special case of training the conventional model based on
            # Faster R-CNN: only optimize parameters that require gradients.
            params = [p for p in self.model.parameters() if p.requires_grad]
            self.optimizer = self.config.optimizer_class(
                params, **self.config.optimizer_config)
        else:
            self.optimizer = self.config.optimizer_class(
                self.model.parameters(), **self.config.optimizer_config)
        self.scheduler = self.config.scheduler_class(
            self.optimizer, **self.config.scheduler_config)
        self.model_file_path = self.get_model_file_path(
            self.is_colab,
            prefix=config.model_file_prefix,
            suffix=config.model_file_suffix)
        self.log_file_path = self.get_log_file_path(
            self.is_colab, suffix=config.model_file_suffix)
        self.epoch = 0
        self.visualize = visualize.Visualize(self.main_folder_path,
                                             self.target_dim,
                                             dest_folder='Images')

        # Use our dataset and the defined transformations.
        self.dataset = bus_dataset.BusDataset(self.main_folder_path,
                                              self.train_df, self.num_classes,
                                              self.target_dim,
                                              self.config.model_name, False,
                                              T.get_transform(train=True))
        self.dataset_test = bus_dataset.BusDataset(
            self.main_folder_path, self.test_df, self.num_classes,
            self.target_dim, self.config.model_name, False,
            T.get_transform(train=False))

        # TODO(ofekp): do we need this?
        # split the dataset in train and test set
        # indices = torch.randperm(len(dataset)).tolist()
        # dataset = torch.utils.data.Subset(dataset, indices[:-50])
        # dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])

        # Fix: log-message typo "initiallized" -> "initialized".
        self.log('Trainer initialized. Device is [{}]'.format(self.device))
示例#2
0
def get_data_loaders(train_ann_file, test_ann_file, batch_size, test_size, image_size, use_mask):
    """Create train and validation DataLoaders from COCO-style annotations.

    Returns (train_loader, val_loader, labels_enumeration), where the
    validation loader covers a random subset of `test_size` test samples.
    """
    def build_dataset(ann_file, train):
        # The image folder name is embedded in the annotation file name,
        # e.g. "instances_train2017.json" -> "train2017".
        split_name = ann_file.split('_')[1].split('.')[0]
        return CocoMask(
            root=Path.joinpath(Path(ann_file).parent.parent, split_name),
            annFile=ann_file,
            transforms=get_transform(train=train, image_size=image_size),
            use_mask=use_mask)

    dataset = build_dataset(train_ann_file, True)
    dataset_test = build_dataset(test_ann_file, False)

    labels_enumeration = dataset.coco.cats

    shuffled_indices = torch.randperm(len(dataset_test)).tolist()
    dataset_val = torch.utils.data.Subset(dataset_test,
                                          shuffled_indices[:test_size])

    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                              num_workers=6, collate_fn=safe_collate,
                              pin_memory=True)
    val_loader = DataLoader(dataset_val, batch_size=batch_size, shuffle=False,
                            num_workers=6, collate_fn=safe_collate,
                            pin_memory=True)

    return train_loader, val_loader, labels_enumeration
示例#3
0
    def __init__(self,
                 dataset,
                 train_batch,
                 test_batch,
                 path,
                 trans_arg,
                 train_transform=None,
                 noise_file=None):
        """Build the train/validation/test datasets and their transforms.

        :param dataset: Name of the dataset that is used
        :param train_batch: Size of a training and validation set
        :param test_batch: Size of test batch
        :param path: Path to which dataset is saved/loaded
        :param trans_arg: Argument forwarded to the transform factory
        :param train_transform: Name of transformation used on training set
            (if not set, a transform based on the dataset name is used)
        :param noise_file: Name of the file containing noise (only used in
            the poisoned CIFAR-10 dataset)
        """
        self.dataset = dataset
        self.path = path
        self.train_batch = train_batch
        self.test_batch = test_batch

        # When no explicit transform name is given, derive it from the
        # dataset name instead.
        transform_name = dataset if train_transform is None else train_transform
        self.train_transform = transforms.Compose(
            get_transform(transform_name, trans_arg))

        self.test_transform = transforms.ToTensor()
        self.noise_file = noise_file
        self.train_set, self.validation_set, self.test_set = (
            self._init_datasets())
示例#4
0
def search_once(config, policy):
    """Train once with the given augmentation policy; return the f1_mavg score."""
    model = get_model(config).cuda()
    criterion = get_loss(config)
    optimizer = get_optimizer(config, model.parameters())
    scheduler = get_scheduler(config, optimizer, -1)

    transforms = {
        'train': get_transform(config, 'train', params={'policies': policy}),
        'val': get_transform(config, 'val'),
    }
    dataloaders = {}
    for split in ['train', 'val']:
        dataloaders[split] = get_dataloader(config, split, transforms[split])

    score_dict = train(config, model, dataloaders, criterion, optimizer,
                       scheduler, None, 0)
    return score_dict['f1_mavg']
示例#5
0
def get_data_loaders(image_shape, train_path, test_path, train_batch_size, test_batch_size):
    """Build train/test DataLoaders over ImageFolder datasets.

    Both splits use the same transform built from `image_shape`; only the
    train loader shuffles.
    """
    train_transform = get_transform(image_shape)
    test_transform = get_transform(image_shape)

    # Datasets rooted at the given folders.
    train_dataset = datasets.ImageFolder(root=train_path,
                                         transform=train_transform)
    test_dataset = datasets.ImageFolder(root=test_path,
                                        transform=test_transform)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=train_batch_size,
                              shuffle=True)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=test_batch_size)

    return train_loader, test_loader
示例#6
0
def run(config, num_checkpoint, epoch_end, output_filename):
    """Average the last `num_checkpoint` checkpoints (SWA), refresh BN
    statistics over the train loader, and save the averaged model.
    """
    task = get_task(config)
    preprocess_opt = task.get_preprocess_opt()
    dataloader = get_dataloader(config, 'train',
                                get_transform(config, 'dev', **preprocess_opt))

    model = task.get_model()
    checkpoints = get_checkpoints(config, num_checkpoint, epoch_end)
    print('checkpoints:')
    print('\n'.join(checkpoints))

    # Bug fix: capture last_epoch from the first load too. With a single
    # checkpoint the loop below never runs and last_epoch was undefined,
    # raising NameError when formatting output_name.
    last_epoch, _ = utils.checkpoint.load_checkpoint(model, None,
                                                     checkpoints[0])
    for i, checkpoint in enumerate(checkpoints[1:]):
        model2 = get_task(config).get_model()
        last_epoch, _ = utils.checkpoint.load_checkpoint(
            model2, None, checkpoint)
        # Running average: after this step `model` is the mean of i+2 models.
        swa.moving_average(model, model2, 1. / (i + 2))

    with torch.no_grad():
        swa.bn_update(dataloader, model)

    output_name = '{}.{}.{:03d}'.format(output_filename, num_checkpoint,
                                        last_epoch)
    print('save {}'.format(output_name))
    utils.checkpoint.save_checkpoint(
        config,
        model,
        None,
        0,
        0,
        name=output_name,
        weights_dict={'state_dict': model.state_dict()})
示例#7
0
def get_loaders(args: Namespace) -> Tuple[DataLoader, DataLoader]:
    """Build train and validation DataLoaders for the bone-age task.

    Args:
        args: parsed CLI namespace; uses crop_center, crop_size,
            annotation_csv, dataset_split, data_dir, scale, model_type,
            batch_size and n_workers.

    Returns:
        (train_loader, val_loader) tuple.
    """
    crop_args = {'crop_center': args.crop_center, 'crop_size': args.crop_size}
    annotation_frame = pd.read_csv(args.annotation_csv)
    test_fold, n_folds = args.dataset_split
    train_df, test_df = split_dataset(annotation_frame,
                                      test_fold,
                                      n_folds,
                                      args.data_dir,
                                      gender='a')

    data_frames = {'train': train_df, 'val': test_df}
    # Augmentation only for the training phase.
    transforms = {
        phase: get_transform(augmentation=(phase == 'train'),
                             crop_dict=crop_args,
                             scale=args.scale)
        for phase in ['train', 'val']
    }
    # The 'gender' model type uses the raw target; others normalize it.
    datasets = {
        phase: BoneAgeDataset(bone_age_frame=data_frames[phase],
                              root=args.data_dir,
                              transform=transforms[phase],
                              target_transform=None if args.model_type
                              == 'gender' else normalize_target,
                              model_type=args.model_type)
        for phase in ['train', 'val']
    }
    # Fix: return a tuple to match the declared
    # Tuple[DataLoader, DataLoader] (the original returned a list).
    train_loader, val_loader = (
        DataLoader(datasets[phase],
                   batch_size=args.batch_size,
                   shuffle=(phase == 'train'),
                   num_workers=args.n_workers) for phase in ['train', 'val'])
    return train_loader, val_loader
示例#8
0
文件: test.py 项目: rosaann/landmark
def get_test_max_landmark_of_one_model(config, gi, best_model_idx, key_group):
    """For one group's trained model, keep the most likely landmark per test
    image and write the results to ./results/test/test_img_land_<gi>.csv.
    """
    test_img_list = gen_test_csv()
    print('test_img_list ', len(test_img_list))

    test_data_set = get_test_loader(config, test_img_list,
                                    get_transform(config, 'val'))
    model = get_model(config, gi)
    if torch.cuda.is_available():
        model = model.cuda()
    optimizer = get_optimizer(config, model.parameters())
    checkpoint = utils.checkpoint.get_model_saved(config, gi, best_model_idx)
    best_epoch, step = utils.checkpoint.load_checkpoint(
        model, optimizer, checkpoint)
    result_set = test_one_model(test_data_set, model, key_group)

    # Reduce each image's landmark->probability mapping to its argmax.
    records = []
    for img_id, probabilities in result_set.items():
        best_landmark = max(probabilities, key=probabilities.get)
        records.append((img_id, best_landmark, probabilities[best_landmark]))

    test_pd = pd.DataFrame.from_records(
        records, columns=['img_id', 'landmark_id', 'pers'])
    output_filename = os.path.join('./results/test/',
                                   'test_img_land_' + str(gi) + '.csv')
    test_pd.to_csv(output_filename, index=False)

    return
示例#9
0
def run(config):
    """Resume (or start) training for the configured task.

    Restores model and optimizer state from the latest checkpoint when one
    exists, then trains from the following epoch.
    """
    # Fix: removed the unused local `train_dir` (config.train.dir was read
    # but never used; SummaryWriter reads it directly below).
    task = get_task(config)
    optimizer = get_optimizer(config, task.get_model().parameters())

    checkpoint = utils.checkpoint.get_initial_checkpoint(config)
    if checkpoint is not None:
        last_epoch, step = utils.checkpoint.load_checkpoint(
            task.get_model(), optimizer, checkpoint)
    else:
        # No checkpoint: start from scratch.
        last_epoch, step = -1, -1

    print('from checkpoint: {} last epoch:{}'.format(checkpoint, last_epoch))
    scheduler = get_scheduler(config, optimizer, last_epoch)

    preprocess_opt = task.get_preprocess_opt()
    dataloaders = {
        split: get_dataloader(config, split,
                              get_transform(config, split, **preprocess_opt))
        for split in ['train', 'dev']
    }

    writer = SummaryWriter(config.train.dir)
    train(config, task, dataloaders, optimizer, scheduler, writer,
          last_epoch + 1)
示例#10
0
def run(config):
    """Build model/criterion/optimizer, restore the latest checkpoint if
    present, then train starting at the next epoch.
    """
    train_dir = config.train.dir

    model = get_model(config).cuda()
    criterion = get_loss(config)
    optimizer = get_optimizer(config, model.parameters())

    checkpoint = utils.checkpoint.get_initial_checkpoint(config)
    if checkpoint is None:
        # Fresh run: no epochs completed yet.
        last_epoch, step = -1, -1
    else:
        last_epoch, step = utils.checkpoint.load_checkpoint(
            model, optimizer, checkpoint)

    print('from checkpoint: {} last epoch:{}'.format(checkpoint, last_epoch))
    scheduler = get_scheduler(config, optimizer, last_epoch)

    print(config.data)
    train_loader = get_train_dataloader(config, get_transform(config))
    valid_loader = get_valid_dataloaders(config)[0]
    dataloaders = {'train': train_loader, 'val': valid_loader}

    writer = SummaryWriter(train_dir)
    train(config, model, dataloaders, criterion, optimizer, scheduler, writer,
          last_epoch + 1)
示例#11
0
    def __init__(self, root_dir, split='train', debug=False):
        """Initialize the dataset: choose transforms per split, load the
        sample list, and set the fixed camera intrinsic matrix.
        """
        self.root_dir = root_dir
        self.split = split
        # Presumably get_transform(True) enables train-time augmentation
        # while val/test share the eval transform — confirm in get_transform.
        self.transforms = {
            'train': get_transform(True),
            'val': get_transform(False),
            'test': get_transform(False),
        }

        self.data = list()
        self.indices = None
        self.load_data()
        # Camera intrinsic matrix. (1686.2379, 1354.9849) is the principal
        # point (relative to the image plane); camera resolution is 3384x2710.
        self.k = np.array(
            [[2304.5479, 0, 1686.2379], [0, 2305.8757, 1354.9849], [0, 0, 1]],
            dtype=np.float32)
        self.debug = debug
def inference_single_tta(config, task, preprocess_opt, split, fold, flip,
                         align, ret_dict):
    """Run one test-time-augmentation inference pass for a given
    (fold, flip, align) configuration and return the resulting id dict.
    """
    # These two config fields are mutated for this pass.
    config.transform.params.align = align
    config.data.params.landmark_ver = fold
    # The test split uses the 'test' transform; every other split uses 'all'.
    if split == 'test':
        transform = 'test'
    else:
        transform = 'all'
    dataloader = get_dataloader(
        config, split,
        get_transform(config, transform, flip=flip, **preprocess_opt))
    id_dict = inference(config, task, dataloader, ret_dict)
    return id_dict
示例#13
0
def main():
    """Convert the iMaterialist train and test datasets to HDF5 files."""
    args, args_text = parse_args()
    print("Args: {}".format(args_text))

    main_folder_path = '../'
    num_classes, train_df, test_df, categories_df = train.process_data(
        main_folder_path, args.data_limit)

    dataset_test = imat_dataset.IMATDataset(main_folder_path,
                                            test_df,
                                            num_classes,
                                            args.target_dim,
                                            "effdet",
                                            False,
                                            T.get_transform(train=False),
                                            gather_statistics=False)
    h5_test_writer = DatasetH5Writer(dataset_test,
                                     args.target_dim,
                                     "../imaterialist_test_" +
                                     str(args.target_dim) + ".hdf5",
                                     chunk_size=args.chunk_size,
                                     delete_existing=args.delete_existing)
    h5_test_writer.process()
    h5_test_writer.close()

    # Bug fix: the train dataset previously passed (False, "effdet") in
    # swapped positional order compared to the test dataset above; use the
    # same ("effdet", False) order for both calls.
    dataset = imat_dataset.IMATDataset(main_folder_path,
                                       train_df,
                                       num_classes,
                                       args.target_dim,
                                       "effdet",
                                       False,
                                       T.get_transform(train=False),
                                       gather_statistics=False)
    h5_writer = DatasetH5Writer(dataset,
                                args.target_dim,
                                "../imaterialist_" + str(args.target_dim) +
                                ".hdf5",
                                chunk_size=args.chunk_size,
                                delete_existing=args.delete_existing)
    h5_writer.process()
    h5_writer.close()

    print("All done.")
示例#14
0
def relative_translations(path, xi, yi):
    """Convert consecutive waypoints of a global path into poses expressed
    in the robot frame.

    Assumes each transform is a rotation followed by a translation. Returns
    one pose vector per consecutive waypoint pair.
    """
    waypoints = copy.deepcopy(path)
    robot_frame_path = []

    # Walk consecutive waypoint pairs.
    for current, following in zip(waypoints, waypoints[1:]):
        xj, yj = current[0], current[1]
        xk, yk = following[0], following[1]

        # Global->robot, then point->global, chained into point->robot.
        globalToRobot = tf.invert_transform(tf.get_transform(xj, yj, 0))
        pointToGlobal = tf.get_transform(xk, yk, 0)
        pointToRobot = tf.chain_transforms(globalToRobot, pointToGlobal)
        robot_frame_path.append(tf.get_pose_vec(pointToRobot))

    return robot_frame_path
示例#15
0
    def __init__(self, modelpath, image_size=(480,480), device="cpu", offset=64, use_test_aug=2, add_fdi_ndvi=False):
        """Load a trained segmentation model from `modelpath` and prepare the
        test-time transform.

        The model architecture name is taken from the checkpoint file name
        (text before the first '-').
        """
        self.image_size = image_size
        self.device = device
        # Crop this many border pixels to remove CNN edge effects.
        self.offset = offset

        in_channels = 14 if add_fdi_ndvi else 12
        model_name = os.path.basename(modelpath).split("-")[0].lower()
        self.model = get_model(model_name, inchannels=in_channels)
        self.model.load_state_dict(torch.load(modelpath)["model_state_dict"])
        self.model = self.model.to(device)
        self.transform = get_transform("test", add_fdi_ndvi=add_fdi_ndvi)
        self.use_test_aug = use_test_aug
示例#16
0
    def infer(test_image_data_path, test_meta_data_path):
        """Run TTA inference over the test set and return the predicted
        class index per sample as a numpy array.
        """
        test_meta_data = pd.read_csv(test_meta_data_path,
                                     delimiter=',',
                                     header=0)

        tta = args.tta
        num_classes = args.num_classes
        target_size = (args.input_size, args.input_size)
        batch_size = 200
        num_workers = 4
        device = 0

        transforms = get_transform(target_size,
                                   args.test_augments,
                                   args.augment_ratio,
                                   is_train=False)
        dataset = TestDataset(test_image_data_path,
                              test_meta_data,
                              transform=transforms)
        dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=num_workers)

        model_nsml.to(device)
        model_nsml.eval()
        total_predict = []

        for t in range(tta):
            print("tta {} predict".format(t + 1))
            prediction = np.zeros((len(dataloader.dataset), num_classes))

            with torch.no_grad():
                for i, image in enumerate(dataloader):
                    image = image.to(device)
                    output = model_nsml(
                        image)  # output shape (batch_num, num_classes)

                    prediction[i * batch_size:(i + 1) *
                               batch_size] = output.detach().cpu().numpy()

            # Bug fix: the original appended `prediction` and then deleted it
            # INSIDE the batch loop, which raised NameError on the second
            # batch and stored partially-filled arrays. Append the completed
            # TTA round exactly once, after all batches.
            total_predict.append(prediction)
            gc.collect()

        total_predict = np.mean(total_predict, axis=0)  # mean tta predictions
        predict_vector = np.argmax(total_predict,
                                   axis=1)  # index per sample

        return predict_vector  # numpy array of shape (num_samples,)
示例#17
0
def assign_angles(transpath):
    """For each consecutive pair of path segments, compute the relative
    rotation angle (in the previous segment's frame) and return the list.
    """
    # Fix: removed the dead locals `x = 0` / `y = 0` that were assigned on
    # every iteration but never read.
    currAngle = 0
    nextAngle = 0
    first = True
    angList = []
    for i in range(len(transpath) - 1):
        if (first):
            # First step: frame transform built directly from the angle.
            currAngle = match_angles(transpath[i])
            Tgr = tf.get_transform(0, 0, currAngle)
            first = False
        else:
            # Later steps: invert to express the next angle relatively.
            Tgr = tf.invert_transform(tf.get_transform(0, 0, currAngle))
        nextAngle = match_angles(transpath[i + 1])
        print(currAngle, nextAngle)

        Tpg = tf.get_transform(0, 0, nextAngle)
        Tpr = tf.chain_transforms(Tgr, Tpg)
        Ppr = tf.get_pose_vec(Tpr)
        theta = Ppr[2]
        angList.append(theta)
        currAngle = nextAngle
    return angList
def run(config, split, checkpoint_name, output_path):
    """Load a named checkpoint, run inference on `split`, and write the
    resulting dataframe to `output_path` as csv.
    """
    # Fix: removed the unused local `train_dir` (config.train.dir was read
    # but never used in this function).
    task = get_task(config)
    checkpoint = utils.checkpoint.get_checkpoint(config, checkpoint_name)
    last_epoch, step = utils.checkpoint.load_checkpoint(
        task.get_model(), None, checkpoint)

    print('from checkpoint: {} last epoch:{}'.format(checkpoint, last_epoch))

    preprocess_opt = task.get_preprocess_opt()
    dataloader = get_dataloader(config, split,
                                get_transform(config, split, **preprocess_opt))

    df = inference(config, task, dataloader)
    df.to_csv(output_path, index=False)
示例#19
0
def run(config):
    """Knowledge-distillation training: restore teacher and student
    checkpoints, then train the student against the teacher.
    """
    teacher_model = get_model(config, 'teacher').to(device)
    student_model = get_model(config, 'student').to(device)
    # Fix: message typo "nubmer" -> "number".
    print('The number of parameters : %d' % count_parameters(student_model))
    criterion = get_loss(config)

    # Teacher: weights restored only; it is not optimized (optimizer_t=None).
    optimizer_t = None
    checkpoint_t = utils.checkpoint.get_initial_checkpoint(config,
                                                           model_type='teacher')
    if checkpoint_t is not None:
        last_epoch_t, step_t = utils.checkpoint.load_checkpoint(
            teacher_model, optimizer_t, checkpoint_t, model_type='teacher')
    else:
        last_epoch_t, step_t = -1, -1
    print('teacher model from checkpoint: {} last epoch:{}'.format(
        checkpoint_t, last_epoch_t))

    # Student: restore weights and optimizer state when a checkpoint exists.
    optimizer_s = get_optimizer(config, student_model)
    checkpoint_s = utils.checkpoint.get_initial_checkpoint(config,
                                                           model_type='student')
    if checkpoint_s is not None:
        last_epoch_s, step_s = utils.checkpoint.load_checkpoint(
            student_model, optimizer_s, checkpoint_s, model_type='student')
    else:
        last_epoch_s, step_s = -1, -1
    print('student model from checkpoint: {} last epoch:{}'.format(
        checkpoint_s, last_epoch_s))

    scheduler_s = get_scheduler(config, optimizer_s, last_epoch_s)

    print(config.data)
    dataloaders = {'train': get_train_dataloader(config, get_transform(config)),
                   'val': get_valid_dataloader(config)}
    writer = SummaryWriter(config.train['student' + '_dir'])
    visualizer = get_visualizer(config)
    result = train(config, student_model, teacher_model, dataloaders,
                   criterion, optimizer_s, scheduler_s, writer,
                   visualizer, last_epoch_s + 1)

    print('best psnr : %.3f, best epoch: %d' %
          (result['best_psnr'], result['best_epoch']))
示例#20
0
文件: train.py 项目: rosaann/landmark
def run(config):
    """Train one model per landmark group, resuming each group from its
    latest checkpoint and skipping groups that already finished.
    """
    group_csv_dir = './data/group_csv/'
    writer = SummaryWriter(config.train.dir)
    # Collect per-group csv files, dropping the first match (as before).
    csv_files = list(
        glob.glob(os.path.join(group_csv_dir, 'data_train_group_*')))[1:]

    for ti, train_file in tqdm.tqdm(enumerate(csv_files)):
        # Extract the numeric group index from the file name.
        stem = train_file.split('/')[-1]
        stem = stem.replace('data_train_group_', '').replace('.csv', '')
        group_idx = int(stem)

        utils.prepare_train_directories(config, group_idx)

        model = get_model(config, group_idx)
        if torch.cuda.is_available():
            model = model.cuda()
        criterion = get_loss(config)
        optimizer = get_optimizer(config, model.parameters())

        checkpoint = utils.checkpoint.get_initial_checkpoint(config, group_idx)
        if checkpoint is None:
            last_epoch, step = -1, -1
        else:
            last_epoch, step = utils.checkpoint.load_checkpoint(
                model, optimizer, checkpoint)

        # Skip groups that already trained past the configured epoch count.
        if last_epoch > config.train.num_epochs:
            print('group -- ', str(group_idx), '-- index-', ti, '  ----已xl,跳过')
            continue
        print('from checkpoint: {} last epoch:{}'.format(checkpoint, last_epoch))
        print('group -- ', str(group_idx), '-- index-', ti)
        scheduler = get_scheduler(config, optimizer, last_epoch)

        dataloaders = {}
        for split in ['train', 'val']:
            dataloaders[split] = get_dataloader(
                config, group_idx, split, get_transform(config, split))

        train(config, group_idx, model, dataloaders, criterion, optimizer,
              scheduler, writer, last_epoch + 1)
示例#21
0
def run(config, num_checkpoint, epoch_end, output_filename):
    """Average the last `num_checkpoint` checkpoints (SWA), refresh BN
    statistics over the train loader, and save the averaged model.
    """
    dataloader = get_dataloader(config, 'train', get_transform(config, 'val'))

    model = get_model(config)
    if torch.cuda.is_available():
        model = model.cuda()
    checkpoints = get_checkpoints(config, num_checkpoint, epoch_end)

    # Bug fix: capture last_epoch from the first load too. With a single
    # checkpoint the loop below never runs and last_epoch was undefined,
    # raising NameError when formatting output_name.
    last_epoch, _ = utils.checkpoint.load_checkpoint(model, None,
                                                     checkpoints[0])
    for i, checkpoint in enumerate(checkpoints[1:]):
        model2 = get_model(config)
        if torch.cuda.is_available():
            model2 = model2.cuda()
        last_epoch, _ = utils.checkpoint.load_checkpoint(model2, None, checkpoint)
        # Running average: after this step `model` is the mean of i+2 models.
        swa.moving_average(model, model2, 1. / (i + 2))

    with torch.no_grad():
        swa.bn_update(dataloader, model)

    output_name = '{}.{}.{:03d}'.format(output_filename, num_checkpoint, last_epoch)
    print('save {}'.format(output_name))
    utils.checkpoint.save_checkpoint(config, model, None, 0, 0,
                                     name=output_name,
                                     weights_dict={'state_dict': model.state_dict()})
示例#22
0
            train_df = df
            valid_df = df
        else:
            NROWS = None
            meta_df = pd.read_csv(meta_path,
                                  delimiter=',',
                                  header=0,
                                  nrows=NROWS)
            df = make_validation(meta_df, label_path)

            train_df = df.loc[df['fold'] != args.fold_num].reset_index(
                drop=True)
            valid_df = df.loc[df['fold'] == args.fold_num].reset_index(
                drop=True)

        train_transforms = get_transform(target_size, args.train_augments,
                                         args.augment_ratio)
        valid_transforms = get_transform(target_size,
                                         args.test_augments,
                                         args.augment_ratio,
                                         is_train=False)
        train_loader = make_loader(train_df, image_dir, train_transforms,
                                   args.batch_size, args.num_workers)
        valid_loader = make_loader(valid_df, image_dir, valid_transforms,
                                   args.batch_size, args.num_workers)
        print("train batches: {}".format(len(train_loader)))
        print("valid batches: {}".format(len(valid_loader)))
        print("batch_size: {}".format(args.batch_size))
        print()

        best_val_acc = 0
        grad_clip_step = 100
示例#23
0
                        help="Single phrase describing experiment.")
    parser.add_argument("--debug",
                        action="store_true",
                        help="Fast dev run mode.")
    parser.add_argument(
        "--cli_args",
        type=str,
        default=str(argv),
        help="Store command line arguments. Don't change manually.")
    parser.add_argument("--no_fit",
                        action="store_true",
                        help="Do everything except starting the fit.")
    parser.add_argument("--cpu", action="store_true", help="Force using CPU.")
    hparams = parser.parse_args()

    hparams.train_transform, hparams.eval_transform = get_transform(
        hparams.transform_str)

    print("Hyperparameters")
    for k, v in vars(hparams).items():
        print(f"{k}: {v}")

    if hparams.task == "train":
        model = CNNT5(hparams=hparams)

        if hparams.debug:
            logger = False
            callbacks = None
        else:
            logger = NeptuneLogger(api_key=os.getenv('NEPTUNE_API_TOKEN'),
                                   project_name="dscarmo/layoutlmt5",
                                   experiment_name=hparams.experiment_name,
示例#24
0
    def __init__(self, args):
        """Set up training state from parsed CLI args: configuration values,
        output paths, data loaders, model, optimizer and LR scheduler.
        """

        # Training configurations
        self.method = args.method
        self.dataset = args.dataset
        self.dim = args.dim
        self.lr = args.lr
        self.batch_size = args.batch_size
        # Validation uses half the training batch size.
        self.val_batch_size = self.batch_size // 2
        self.iteration = args.iteration
        self.evaluation = args.evaluation
        self.show_iter = 1000
        self.update_epoch = 10
        self.balanced = args.balanced
        self.instances = args.instances
        self.cm = args.cm
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')

        # Run identifier (method_dataset_lr) used for figure directories.
        self.file_name = '{}_{}_{}'.format(
            self.method,
            self.dataset,
            self.lr,
        )
        print('========================================')
        print(json.dumps(vars(args), indent=2))
        print(self.file_name)

        # Paths
        # NOTE(review): root dir is hard-coded to /home/lyz — confirm before
        # running on another machine.
        self.root_dir = os.path.join('/', 'home', 'lyz')
        self.data_dir = os.path.join(self.root_dir, 'datasets', self.dataset)
        self.model_dir = self._get_path('./trained_model')
        self.code_dir = self._get_path(os.path.join('codes', self.dataset))
        self.fig_dir = self._get_path(
            os.path.join('fig', self.dataset, self.file_name))

        # Preparing data
        self.transforms = get_transform()
        self.datasets = get_datasets(dataset=self.dataset,
                                     data_dir=self.data_dir,
                                     transforms=self.transforms)
        # Class-mining sampler; only passed to the loaders when args.cm is set.
        self.cm_sampler = ClassMiningSampler(self.datasets['train'],
                                             batch_size=self.batch_size,
                                             n_instance=self.instances,
                                             balanced=self.balanced)
        self.data_loaders = get_data_loaders(
            datasets=self.datasets,
            batch_size=self.batch_size,
            val_batch_size=self.val_batch_size,
            n_instance=self.instances,
            balanced=self.balanced,
            cm=self.cm_sampler if self.cm else None)
        self.dataset_sizes = {
            x: len(self.datasets[x])
            for x in ['train', 'test']
        }

        # Set up model
        self.model = get_model(self.device, self.dim)

        # The final linear layer gets a 10x learning rate relative to the
        # google_net backbone parameters.
        self.optimizer = optim.SGD(
            [{
                'params': self.model.google_net.parameters()
            }, {
                'params': self.model.linear.parameters(),
                'lr': self.lr * 10,
                'momentum': 0.9
            }],
            lr=self.lr,
            momentum=0.9)
        self.scheduler = lr_scheduler.StepLR(self.optimizer,
                                             step_size=2000,
                                             gamma=0.5)
示例#25
0
def main():
    """Embed a gallery set and a query set with a trained model (multi-crop
    TTA) and retrieve, for every query image, the TOP_N most similar gallery
    images.

    Requires ``args.checkpoint`` to point at a valid checkpoint file;
    raises RuntimeError otherwise (inference is impossible without weights).
    """
    # init the args
    global best_pred, acclist_train, acclist_val
    args = Options().parse()

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print(args)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # Only the inference transform is needed here; train/val transforms are
    # discarded.
    _, _, transform_infer = transforms.get_transform(args.dataset)
    galleryset = datasets.get_dataset(args.dataset,
                                      root='/home/ace19/dl_data/materials/train',
                                      transform=transform_infer)
    queryset = datasets.get_dataset(args.dataset,
                                    split='eval',
                                    root='/home/ace19/dl_data/materials/query',
                                    transform=transform_infer)
    # No shuffling: feature order must stay aligned with the path lists
    # collected below.
    gallery_loader = DataLoader(
        galleryset, batch_size=args.batch_size, num_workers=args.workers)
    query_loader = DataLoader(
        queryset, batch_size=args.test_batch_size, num_workers=args.workers)

    # init the model
    model = model_zoo.get_model(args.model)
    print(model)

    if args.cuda:
        model.cuda()
        # Please use CUDA_VISIBLE_DEVICES to control the number of gpus
        model = nn.DataParallel(model)

    # check point: a missing or invalid checkpoint is a hard error.
    if args.checkpoint is not None:
        if os.path.isfile(args.checkpoint):
            print("=> loading checkpoint '{}'".format(args.checkpoint))
            checkpoint = torch.load(args.checkpoint)
            args.start_epoch = checkpoint['epoch'] + 1
            best_pred = checkpoint['best_pred']
            acclist_train = checkpoint['acclist_train']
            acclist_val = checkpoint['acclist_val']
            model.module.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.checkpoint, checkpoint['epoch']))
        else:
            raise RuntimeError("=> no infer checkpoint found at '{}'". \
                               format(args.checkpoint))
    else:
        raise RuntimeError("=> config \'args.checkpoint\' is '{}'". \
                           format(args.checkpoint))

    gallery_features_list = []
    gallery_path_list = []
    query_features_list = []
    query_path_list = []

    def retrieval():
        """Extract TTA-averaged features for both sets, then match."""
        model.eval()

        print(" ==> Loading gallery ... ")
        tbar = tqdm(gallery_loader, desc='\r')
        for batch_idx, (gallery_paths, data, gt) in enumerate(tbar):
            if args.cuda:
                data, gt = data.cuda(), gt.cuda()

            with torch.no_grad():
                # TTA: data is 5-D (batch, n_crops, c, h, w). Fuse batch
                # size and n_crops, run the model, then average over crops.
                batch_size, n_crops, c, h, w = data.size()
                features, _ = model(data.view(-1, c, h, w))
                features = features.view(batch_size, n_crops, -1).mean(1)
                gallery_features_list.extend(features)
                gallery_path_list.extend(gallery_paths)
        # end of for

        print(" ==> Loading query ... ")
        tbar = tqdm(query_loader, desc='\r')
        for batch_idx, (query_paths, data) in enumerate(tbar):
            if args.cuda:
                data = data.cuda()

            with torch.no_grad():
                # TTA (same scheme as the gallery pass above).
                batch_size, n_crops, c, h, w = data.size()
                features, _ = model(data.view(-1, c, h, w))
                features = features.view(batch_size, n_crops, -1).mean(1)
                query_features_list.extend(features)
                query_path_list.extend(query_paths)
        # end of for

        # Guard both sides: torch.stack() on an empty list would raise.
        if len(gallery_features_list) == 0:
            print('No gallery data!!')
            return

        if len(query_features_list) == 0:
            print('No query data!!')
            return

        # matching
        top_n_indice, top_n_distance = \
            match_n(TOP_N,
                    torch.stack(gallery_features_list).cpu(),
                    torch.stack(query_features_list).cpu())

        # Show n images from the gallery similar to the query image.
        show_retrieval_result(top_n_indice, top_n_distance, gallery_path_list, query_path_list)

    retrieval()
示例#26
0
文件: collection.py 项目: GetmeUK/h51
    def validate_variations(self, asset_type, variations):
        """
        Validate the given map of variations and return it if valid.

        `variations` must be a non-empty dict mapping a variation name to a
        non-empty list of `[transform_name, settings_dict]` pairs. The first
        problem found raises `APIError('invalid_request', ...)` with a
        human-readable hint; otherwise the map is returned unchanged.
        """

        # Check the structure of the variations is valid
        if not isinstance(variations, dict):
            raise APIError('invalid_request',
                           hint='Request body JSON must be an object.')

        if len(variations) == 0:
            raise APIError('invalid_request',
                           hint='At least one variation is required.')

        # Cap the number of variations a single request may create.
        elif len(variations) > self.config['MAX_VARIATIONS_PER_REQUEST']:
            raise APIError(
                'invalid_request',
                hint=('The maximum number of variations that can be added in '
                      'single request is '
                      f"{self.config['MAX_VARIATIONS_PER_REQUEST']}."))

        for name, transforms in variations.items():

            # Check the name of the variation is valid (it must survive
            # slugification unchanged, i.e. contain only allowed characters).
            slug = slugify(
                name,
                regex_pattern=ALLOWED_SLUGIFY_CHARACTERS,
            )

            # Unlike slugify we allow dashes at the start/end of the variation
            # name, so we strip dashes before the test.
            if slug != name.strip('-'):
                raise APIError('invalid_request',
                               hint=f'Not a valid variation name: {name}.')

            # Check the required number of transforms have been provided
            if len(transforms) == 0:
                raise APIError(
                    'invalid_request',
                    hint=('At least one transform per variation is required: '
                          f'{name}.'))

            for i, transform in enumerate(transforms):

                # Check transform structure: a [transform_name, settings] pair
                if not (len(transform) == 2 and isinstance(transform[0], str)
                        and isinstance(transform[1], dict)):
                    raise APIError(
                        'invalid_request',
                        hint=(f'Invalid transform structure: {transform} '
                              f'({name}).'))

                # Check the transform exists for this asset type
                transform_cls = get_transform(asset_type, transform[0])
                if not transform_cls:
                    raise APIError(
                        'invalid_request',
                        hint=(
                            f'Unknown transform: {asset_type}:{transform[0]} '
                            f'({name}).'))

                # Check only the last transform in the list is flagged as a
                # final transform.
                if transform_cls.final and i < len(transforms) - 1:
                    raise APIError(
                        'invalid_request',
                        hint=('Final transform not set as last transform: '
                              f'{asset_type}:{transform[0]} ({name}).'))

                # ...and conversely, the last transform must be final.
                if not transform_cls.final and i == len(transforms) - 1:
                    raise APIError(
                        'invalid_request',
                        hint=(f'Last transform in list is not final: {name}'))

                # Check the settings for the transform are correct by running
                # them through the transform's settings form (None values are
                # dropped so form defaults can apply).
                settings_form = transform_cls.get_settings_form_cls()(
                    MultiDict({
                        k: v
                        for k, v in transform[1].items() if v is not None
                    }))
                if not settings_form.validate():
                    raise APIError(
                        'invalid_request',
                        hint=('Invalid settings for transform: '
                              f'{asset_type}:{transform[0]} ({name}).'),
                        arg_errors=settings_form.errors)

        return variations
示例#27
0
def main():
    """Train and validate an image classifier.

    Builds balanced-sampled train / TTA validation loaders, optionally
    restores a checkpoint from ``args.pretrained``, then runs the epoch loop
    (or a single validation pass when ``args.eval`` is set). Progress is
    tracked in the module-level globals ``best_pred`` / ``acclist_train`` /
    ``acclist_val``.
    """
    # init the args
    global best_pred, acclist_train, acclist_val
    args = Options().parse()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print(args)

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # init dataloader
    transform_train, transform_val, _ = transforms.get_transform(args.dataset)
    trainset = datasets.get_dataset(args.dataset,
                                    root='/home/ace19/dl_data/materials/train',
                                    transform=transform_train)
    valset = datasets.get_dataset(
        args.dataset,
        root='/home/ace19/dl_data/materials/validation',
        transform=transform_val)

    # balanced sampling between classes
    train_loader = DataLoader(trainset,
                              batch_size=args.batch_size,
                              num_workers=args.workers,
                              sampler=ImbalancedDatasetSampler(trainset))
    # train_loader = DataLoader(
    #     trainset, batch_size=args.batch_size, shuffle=True,
    #     num_workers=args.workers, pin_memory=True)
    val_loader = DataLoader(valset,
                            batch_size=args.test_batch_size,
                            shuffle=False,
                            num_workers=args.workers)

    # init the backbone model; a checkpoint will overwrite the weights below,
    # so the ImageNet-pretrained backbone is only requested when starting
    # from scratch.
    if args.pretrained is not None:
        model = model_zoo.get_model(args.model, backbone=args.backbone)
    else:
        model = model_zoo.get_model(args.model,
                                    backbone_pretrained=True,
                                    backbone=args.backbone)
    print(model)

    # criterion and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
    #                             weight_decay=args.weight_decay)
    if args.cuda:
        model.cuda()
        criterion.cuda()
        # Please use CUDA_VISIBLE_DEVICES to control the number of gpus
        model = nn.DataParallel(model)
    # check point: restore model/optimizer state and training history.
    if args.pretrained is not None:
        if os.path.isfile(args.pretrained):
            print("=> loading checkpoint '{}'".format(args.pretrained))
            checkpoint = torch.load(args.pretrained)
            args.start_epoch = checkpoint['epoch'] + 1
            best_pred = checkpoint['best_pred']
            acclist_train = checkpoint['acclist_train']
            acclist_val = checkpoint['acclist_val']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.pretrained, checkpoint['epoch']))
        else:
            raise RuntimeError("=> no pretrained checkpoint found at '{}'". \
                               format(args.pretrained))

    scheduler = lr_scheduler.LR_Scheduler(args.lr_scheduler,
                                          args.lr, args.epochs,
                                          len(train_loader), args.lr_step)

    def train(epoch):
        """Run one training epoch and append its top-1 accuracy to
        ``acclist_train``."""
        model.train()
        losses = AverageMeter()
        top1 = AverageMeter()

        global best_pred, acclist_train, acclist_val

        tbar = tqdm(train_loader, desc='\r')
        for batch_idx, (_, images, targets) in enumerate(tbar):
            # Per-iteration LR update (scheduler is iteration-based).
            scheduler(optimizer, batch_idx, epoch, best_pred)
            # display_data(images)
            # TODO: Convert from list of 3D to 4D
            # images = np.stack(images, axis=1)
            # images = torch.from_numpy(images)

            if args.cuda:
                images, targets = images.cuda(), targets.cuda()

            # compute gradient and do SGD step
            optimizer.zero_grad()
            _, output = model(images)
            loss = criterion(output, targets)
            loss.backward()
            optimizer.step()

            acc1 = accuracy(output, targets)
            top1.update(acc1[0], images.size(0))
            losses.update(loss.item(), images.size(0))
            tbar.set_description('\rLoss: %.3f | Top1: %.3f' %
                                 (losses.avg, top1.avg))

        acclist_train += [top1.avg]

    def validate(epoch):
        """Validate with TTA (multi-crop averaging), print a confusion
        matrix, and checkpoint; updates ``best_pred`` / ``acclist_val``."""
        model.eval()
        top1 = AverageMeter()
        top5 = AverageMeter()
        confusion_matrix = torch.zeros(args.nclass, args.nclass)

        global best_pred, acclist_train, acclist_val
        is_best = False

        tbar = tqdm(val_loader, desc='\r')
        # TTA(TenCrop) input, target = batch # input is a 5d tensor, target is 2d
        # bs, ncrops, c, h, w = input.size()
        # result = model(input.view(-1, c, h, w))  # fuse batch size and ncrops
        # result_avg = result.view(bs, ncrops, -1).mean(1)  # avg over crops
        for batch_idx, (fnames, images, targets) in enumerate(tbar):
            # Convert from list of 3D to 4D
            # images = np.stack(images, axis=1)
            # images = torch.from_numpy(images)

            if args.cuda:
                images, targets = images.cuda(), targets.cuda()
                # images, targets = Variable(images), Variable(targets)
            with torch.no_grad():
                # _, output = model(images)

                # TTA
                batch_size, n_crops, c, h, w = images.size()
                # fuse batch size and ncrops
                _, output = model(images.view(-1, c, h, w))
                # avg over crops
                output = output.view(batch_size, n_crops, -1).mean(1)
                # accuracy
                # NOTE(review): topk=(1, 1) means "acc5"/"Top5" actually
                # re-measures top-1 — possibly deliberate when nclass < 5;
                # confirm before reporting Top5 figures.
                acc1, acc5 = accuracy(output, targets, topk=(1, 1))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))

                # confusion matrix
                _, preds = torch.max(output, 1)
                for t, p in zip(targets.view(-1), preds.view(-1)):
                    confusion_matrix[t.long(), p.long()] += 1

            tbar.set_description('Top1: %.3f | Top5: %.3f' %
                                 (top1.avg, top5.avg))
        # end of for

        print('\n----------------------------------')
        print('confusion matrix:\n', confusion_matrix)
        # get the per-class accuracy
        print('\nper-class accuracy(precision):\n',
              confusion_matrix.diag() / confusion_matrix.sum(1))
        print('----------------------------------\n')

        if args.eval:
            print('Top1 Acc: %.3f | Top5 Acc: %.3f ' % (top1.avg, top5.avg))
            return

        # save checkpoint (marking it best when top-1 improved)
        acclist_val += [top1.avg]
        if top1.avg > best_pred:
            best_pred = top1.avg
            is_best = True
        files.save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_pred': best_pred,
                'acclist_train': acclist_train,
                'acclist_val': acclist_val,
            },
            args=args,
            is_best=is_best)

    if args.eval:
        validate(args.start_epoch)
        # writer.close()
        return

    for epoch in range(args.start_epoch, args.epochs + 1):
        train(epoch)
        validate(epoch)
示例#28
0
def main():
    """Train a CARS196 embedding model with the HPHN triplet loss,
    periodically evaluate Recall@K, and checkpoint on improvement."""
    args = parser.parse_args()

    # Fixed metadata locations for the CARS196 split files.
    args.train_meta = './meta/CARS196/train.txt'
    args.test_meta = './meta/CARS196/test.txt'

    # Comma-separated CLI strings -> integer lists.
    args.lr_decay_epochs = [int(e) for e in args.lr_decay_epochs.split(',')]
    args.recallk = [int(k) for k in args.recallk.split(',')]

    # Pin the process to the requested GPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_idx)
    args.ctx = [mx.gpu(0)]

    print(args)

    # Seed every RNG in play for reproducibility.
    mx.random.seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    # Image transforms and data loaders.
    train_transform, test_transform = T.get_transform(
        image_size=args.image_size)
    train_loader, test_loader = D.get_data_loader(
        args.data_dir, args.train_meta, args.test_meta, train_transform,
        test_transform, args.batch_size, args.num_instances, args.num_workers)

    # Embedding model, hybridized for speed.
    model = Model(args.embed_dim, args.ctx)
    model.hybridize()

    # Hard-positive / hard-negative triplet loss.
    loss = HPHNTripletLoss(margin=args.margin,
                           soft_margin=False,
                           num_instances=args.num_instances,
                           n_inner_pts=args.n_inner_pts,
                           l2_norm=args.ee_l2norm)

    # TensorBoard logger.
    summary_writer = SummaryWriter(
        os.path.join(args.save_dir, 'tensorboard_log'))

    # Translate decay epochs into iteration counts for the LR schedule.
    print("steps in epoch:", args.lr_decay_epochs)
    steps = [epoch * len(train_loader) for epoch in args.lr_decay_epochs]
    print("steps in iter:", steps)
    lr_schedule = mx.lr_scheduler.MultiFactorScheduler(
        step=steps, factor=args.lr_decay_factor)
    lr_schedule.base_lr = args.lr

    # Optimizer for training.
    optimizer = mx.gluon.Trainer(model.collect_params(),
                                 'adam', {
                                     'learning_rate': args.lr,
                                     'wd': args.wd
                                 },
                                 kvstore=args.kvstore)

    # Trainer drives the epoch loop; evaluator computes Recall@K.
    trainer = Trainer(model,
                      loss,
                      optimizer,
                      train_loader,
                      summary_writer,
                      args.ctx,
                      summary_step=args.summary_step,
                      lr_schedule=lr_schedule)
    evaluator = Evaluator(model, test_loader, args.ctx)

    best_metrics = [0.0]  # all query

    global_step = args.start_epoch * len(train_loader)

    # Training loop.
    print("base lr mult:", args.base_lr_mult)
    for epoch in range(args.start_epoch, args.epochs):
        # Backbone learns at a multiple of the base learning rate.
        model.backbone.collect_params().setattr('lr_mult', args.base_lr_mult)

        trainer.train(epoch)
        global_step = (epoch + 1) * len(train_loader)
        if (epoch + 1) % args.eval_epoch_term == 0:
            previous_best = best_metrics[0]
            best_metrics = evaluate_and_log(summary_writer,
                                            evaluator,
                                            args.recallk,
                                            global_step,
                                            epoch + 1,
                                            best_metrics=best_metrics)
            # Checkpoint whenever the headline metric moved.
            if best_metrics[0] != previous_best:
                save_path = os.path.join(
                    args.save_dir, 'model_epoch_%05d.params' % (epoch + 1))
                model.save_parameters(save_path)
        sys.stdout.flush()
def run(task_args):
    """Run batched inference with a checkpointed detection model over every
    image in ``task_args.input_dataset_root`` and write the detections to a
    COCO-style JSON file, logging sample debug images to TensorBoard.

    The checkpoint supplies the label enumeration, the model configuration
    (type, image size, class count) and the trained weights.
    """
    writer = SummaryWriter(log_dir=task_args.log_dir)
    input_checkpoint = torch.load(task_args.input_checkpoint)
    labels_enum = input_checkpoint.get('labels_enumeration')
    model_configuration = input_checkpoint.get('configuration')
    model_weights = input_checkpoint.get('model')
    image_size = model_configuration.get('image_size')

    # Set the training device to GPU if available - if not set it to CPU
    device = torch.cuda.current_device() if torch.cuda.is_available(
    ) else torch.device('cpu')
    torch.backends.cudnn.benchmark = True if torch.cuda.is_available(
    ) else False  # optimization for fixed input size

    # Rebuild the model architecture recorded in the checkpoint.
    num_classes = model_configuration.get('num_classes')
    if model_configuration.get('model_type') == 'maskrcnn':
        model = get_model_instance_segmentation(
            num_classes,
            model_configuration.get('mask_predictor_hidden_layer'))
    elif model_configuration.get('model_type') == 'ssd':
        backbone = get_backbone(model_configuration.get('ssd_backbone'))
        model = SSD(backbone=backbone,
                    num_classes=num_classes,
                    loss_function=SSDLoss(num_classes))
        # Dry run materializes the SSD heads for the configured input size.
        model.dry_run(
            torch.rand(size=(1, 3, model_configuration.get('image_size'),
                             model_configuration.get('image_size'))) * 255)
    else:
        raise ValueError(
            'Only "maskrcnn" and "ssd" are supported as model type')

    # if there is more than one GPU, parallelize the model
    if torch.cuda.device_count() > 1:
        print("{} GPUs were detected - we will use all of them".format(
            torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)

    # copy the model to each device
    model.to(device)

    # NOTE(review): iou_types / use_mask are computed but never used below;
    # kept in case get_iou_types has side effects — confirm and remove.
    iou_types = get_iou_types(model)
    use_mask = True if "segm" in iou_types else False

    # Load pretrained model weights
    model.load_state_dict(model_weights)

    # set the model to inference mode
    model.eval()

    # Collect every image file in the dataset root.
    images_paths = []
    for file_type in ('*.png', '*.jpg', '*.jpeg'):
        images_paths.extend(
            glob.glob(os.path.join(task_args.input_dataset_root, file_type)))

    transforms = get_transform(train=False, image_size=image_size)

    path_to_json = os.path.join(task_args.output_dir, "inference_results.json")
    coco_like_anns = CocoLikeAnnotations()
    batch_images = []
    batch_paths = []
    batch_shapes = []

    # NOTE(review): a trailing partial batch (when len(images_paths) is not
    # a multiple of batch_size) is never run through the model.
    for i, image_path in enumerate(images_paths):
        img = Image.open(image_path).convert('RGB')
        # Original size is needed later to rescale boxes back.
        batch_shapes.append({'height': img.height, 'width': img.width})
        img, __ = transforms(img)
        batch_images.append(img)
        batch_paths.append(image_path)
        if len(batch_images) < task_args.batch_size:
            continue

        input_images = torch.stack(batch_images)

        with torch.no_grad():
            torch_out = model(input_images.to(device))

        for img_num, image in enumerate(input_images):
            # Keep only detections above the confidence threshold.
            # (fix: use the task_args parameter, not the undefined global
            # `args`)
            valid_detections = torch_out[img_num].get(
                'scores') >= task_args.detection_thresh
            img_boxes = torch_out[img_num].get(
                'boxes')[valid_detections].cpu().numpy()
            img_labels_ids = torch_out[img_num].get(
                'labels')[valid_detections].cpu().numpy()
            img_labels = [
                labels_enum[label]['name'] for label in img_labels_ids
            ]
            # 0-based index of this image within images_paths.
            image_id = (i + 1 - task_args.batch_size + img_num)
            orig_height = batch_shapes[img_num].get('height')
            orig_width = batch_shapes[img_num].get('width')

            coco_like_anns.update_images(file_name=Path(
                batch_paths[img_num]).name,
                                         height=orig_height,
                                         width=orig_width,
                                         id=image_id)

            # Rescale each box from model input size back to the original
            # image size and record it.
            for box, label, label_id in zip(img_boxes, img_labels,
                                            img_labels_ids):
                orig_box = rescale_box(image_size=image_size,
                                       orig_height=orig_height,
                                       orig_width=orig_width,
                                       box=box.copy())
                coco_like_anns.update_annotations(box=orig_box,
                                                  label_id=label_id,
                                                  image_id=image_id)

            if ((i + 1) / task_args.batch_size) % task_args.log_interval == 0:
                print('Batch {}: Saving detections of file {} to {}'.format(
                    int((i + 1) / task_args.batch_size),
                    Path(batch_paths[img_num]).name, path_to_json))

            if ((i + 1) / task_args.batch_size
                ) % task_args.debug_images_interval == 0:
                debug_image = draw_boxes(np.array(F.to_pil_image(image.cpu())),
                                         img_boxes,
                                         img_labels,
                                         color=(0, 150, 0))
                writer.add_image("inference/image_{}".format(img_num),
                                 debug_image, ((i + 1) / task_args.batch_size),
                                 dataformats='HWC')

        batch_images = []
        batch_paths = []
        # fix: also reset batch_shapes so per-image sizes stay aligned with
        # the current batch (previously it grew forever and every batch after
        # the first rescaled boxes against stale first-batch entries).
        batch_shapes = []

    coco_like_anns.dump_to_json(path_to_json=path_to_json)
示例#30
0
    def __init__(self, args):

        # Training configurations
        self.method = args.method
        self.dataset = args.dataset
        self.dim = args.dim
        self.lr_init = args.lr_init
        self.gamma_m = args.gamma_m
        self.gamma_s = args.gamma_s
        self.batch_size = args.batch_size
        self.val_batch_size = self.batch_size // 2
        self.iteration = args.iteration
        self.evaluation = args.evaluation
        self.show_iter = 1000
        self.update_epoch = args.update_epoch
        self.balanced = args.balanced
        self.instances = args.instances
        self.inter_test = args.intertest
        self.cm = args.cm
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.n_class = args.batch_size // args.instances
        self.classes = args.classes
        self.pretrained = args.pretrained
        self.model_save_interval = args.model_save_interval


        self.file_name = '{}_{}_{}'.format(
            self.method,
            self.dataset,
            self.iteration,
        )
        print('========================================')
        print(json.dumps(vars(args), indent=2))
        print(self.file_name)

        # Paths

        self.root_dir = os.path.join('/', 'data')
        self.data_dir = os.path.join(self.root_dir, self.dataset)
        self.model_dir = self._get_path('./trained_model')
        self.plot_dir = self._get_path('./plot_model')
        self.code_dir = self._get_path(os.path.join('codes', self.dataset))
        self.fig_dir = self._get_path(os.path.join('fig', self.dataset, self.file_name))

        # Preparing data
        self.transforms = get_transform()
        self.datasets = get_datasets(dataset=self.dataset, data_dir=self.data_dir, transforms=self.transforms)

        self.data_loaders = get_data_loaders(
            datasets=self.datasets,
            batch_size=self.batch_size,
            val_batch_size=self.val_batch_size,
            n_instance=self.instances,
            balanced=self.balanced,
            #cm=self.cm_sampler if self.cm else None
        )
        self.dataset_sizes = {x: len(self.datasets[x]) for x in ['train', 'test']}


        self.mean = (torch.zeros((self.classes,self.classes)).add(1.5)-1.0*torch.eye(self.classes)).to(self.device)
        self.std = (torch.zeros((self.classes,self.classes)).add(0.15)).to(self.device)
        self.last_delta_mean = torch.zeros((self.classes,self.classes)).to(self.device)
        self.last_delta_std = torch.zeros((self.classes,self.classes)).to(self.device)

        
        self.ndmodel = nd.NDfdml(n_class=self.n_class,batch_size=self.batch_size,instances=self.instances,pretrained=self.pretrained).to(self.device)
        
        
        optimizer_c = optim.SGD(
            [
                {'params': self.ndmodel.googlelayer.parameters()},
                {'params': self.ndmodel.embedding_layer.parameters(), 'lr': self.lr_init * 10, 'momentum': 0.9}
            ],
            lr=self.lr_init, momentum=0.9
        )


        self.scheduler = lr_scheduler.StepLR(optimizer_c, step_size=4000, gamma=0.9)