Example #1
def run(config):
    # train_dir = config.train.dir

    # model_segmenter = get_model(config.model_segmenter.name)
    model_segmenter = XResNet2()  # alternatives tried: NetX2(), LinkNet(1)
    if torch.cuda.is_available():
        model_segmenter = model_segmenter.cuda()
    criterion_segmenter = get_loss(config.loss_segmenter)
    optimizer_segmenter = get_optimizer(config.optimizer_segmenter.name,
                                        model_segmenter.parameters(),
                                        config.optimizer_segmenter.params)

    ####
    checkpoint_segmenter = get_initial_checkpoint(config.train_segmenter.dir)
    if checkpoint_segmenter is not None:
        last_epoch, step = load_checkpoint(model_segmenter,
                                           optimizer_segmenter,
                                           checkpoint_segmenter)
    else:
        last_epoch, step = -1, -1

    print('from segmenter checkpoint: {} last epoch:{}'.format(
        checkpoint_segmenter, last_epoch))
    #  scheduler = get_scheduler(config, optimizer, last_epoch)
    print('config.train ', config.train)
    writer = SummaryWriter(config.train.writer_dir)

    scheduler = 'none'  # placeholder: no LR scheduler is used for the segmenter
    #  train_classifier_dataloaders = get_dataloader(config.data_classifier, './data/data_train.csv',config.train_classifier.batch_size, 'train',config.transform_classifier.num_preprocessor, get_transform(config.transform_classifier, 'train'))
    #  eval_classifier_dataloaders = get_dataloader(config.data_classifier, './data/data_val.csv',config.eval_classifier.batch_size, 'val', config.transform_classifier.num_preprocessor, get_transform(config.transform_classifier, 'val'))
    #  test_dataloaders = get_dataloader(config.data_classifier,'./data/data_test.csv', get_transform(config, 'test'))

    #  train_classifier(config, model_classifier, train_classifier_dataloaders,eval_classifier_dataloaders, criterion_classifier, optimizer_classifier, scheduler,
    #        writer, last_epoch+1)

    # override the configured loss: the segmenter is trained with plain MSE
    criterion_segmenter = nn.MSELoss()

    train_segmenter_dataloaders = get_dataloader(
        config.train_segmenter.batch_size, 'train')

    eval_segmenter_dataloaders = get_dataloader(
        config.train_segmenter.batch_size, 'val')

    train_segmenter(config, model_segmenter, train_segmenter_dataloaders,
                    eval_segmenter_dataloaders, criterion_segmenter,
                    optimizer_segmenter, scheduler, writer, last_epoch + 1)
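This snippet resumes from the most recent checkpoint when one exists. A minimal sketch of what get_initial_checkpoint and load_checkpoint could look like follows; the file layout and the key names in the saved dict are assumptions, not the repository's actual format.

import os
import torch

def get_initial_checkpoint(train_dir):
    # newest .pth file in the training directory, or None when starting fresh
    paths = [os.path.join(train_dir, f)
             for f in os.listdir(train_dir) if f.endswith('.pth')]
    return max(paths, key=os.path.getmtime) if paths else None

def load_checkpoint(model, optimizer, checkpoint_path):
    # restore weights and optimizer state, return (last_epoch, step)
    state = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])
    return state.get('epoch', -1), state.get('step', -1)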
Example #2
def main():
    args = parse_args()
    if args.config_file is None:
        raise Exception('no configuration file')

    config = utils.config.load(args.config_file)

    model_segmenter = XResNet()
    if torch.cuda.is_available():
        model_segmenter = model_segmenter.cuda()

    optimizer_segmenter = get_optimizer(config.optimizer_segmenter.name,
                                        model_segmenter.parameters(),
                                        config.optimizer_segmenter.params)
    ####
    checkpoint = get_model_saved(config.train_segmenter.dir, 4809)  # checkpoint saved at step 4809
    best_epoch, step = load_checkpoint(model_segmenter, optimizer_segmenter,
                                       checkpoint)

    test_segmenter_dataloaders = get_test_dataloader(10)

    test_segmenter(config, model_segmenter, test_segmenter_dataloaders)
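get_model_saved plays the same role as get_initial_checkpoint in Example #1, but selects the checkpoint saved at a specific step (here 4809) instead of the newest one. A minimal sketch, under the same hypothetical naming scheme in which the step number ends the filename:

import os

def get_model_saved(train_dir, step):
    # return the checkpoint whose filename records the requested step
    for name in sorted(os.listdir(train_dir)):
        if name.endswith(f'_{step}.pth'):
            return os.path.join(train_dir, name)
    raise FileNotFoundError(f'no checkpoint for step {step} in {train_dir}')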
Example #3
def run(config_file):
    config = load_config(config_file)
    config.work_dir = '/home/koga/workspace/kaggle_bengali/result/' + config.work_dir
    os.makedirs(config.work_dir, exist_ok=True)
    os.makedirs(config.work_dir + "/checkpoints", exist_ok=True)
    print('working directory:', config.work_dir)
    logger = get_logger(config.work_dir + "/log.txt")

    all_transforms = {}
    all_transforms['train'] = Transform(
        size=config.data.image_size,
        affine=config.transforms.affine,
        autoaugment_ratio=config.transforms.autoaugment_ratio,
        threshold=config.transforms.threshold,
        sigma=config.transforms.sigma,
        blur_ratio=config.transforms.blur_ratio,
        noise_ratio=config.transforms.noise_ratio,
        cutout_ratio=config.transforms.cutout_ratio,
        grid_distortion_ratio=config.transforms.grid_distortion_ratio,
        random_brightness_ratio=config.transforms.random_brightness_ratio,
        piece_affine_ratio=config.transforms.piece_affine_ratio,
        ssr_ratio=config.transforms.ssr_ratio,
        grid_mask_ratio=config.transforms.grid_mask_ratio,
        augmix_ratio=config.transforms.augmix_ratio,
    )
    all_transforms['valid'] = Transform(size=config.data.image_size)

    dataloaders = {
        phase: make_loader(
            phase=phase,
            df_path=config.train.dfpath,
            batch_size=config.train.batch_size,
            num_workers=config.num_workers,
            idx_fold=config.data.params.idx,
            fold_csv=config.data.params.fold_csv,
            transforms=all_transforms[phase],
            # debug=config.debug
            crop=config.transforms.crop)
        for phase in ['train', 'valid']
    }
    model = MODEL_LIST[config.model.version](back_bone=config.model.back_bone,
                                             out_dim=1295)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
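        # linear-scaling heuristic: DataParallel multiplies the effective
        # batch size by the GPU count, so the learning rate grows to match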
        config.optimizer.params.lr *= torch.cuda.device_count()
        torch.backends.cudnn.benchmark = True

    model = model.to(device)

    criterion = get_criterion(config)
    optimizer = get_optimizer(config, model)
    scheduler = get_scheduler(optimizer, config)

    accumulate_step = 1
    if config.train.accumulation_size > 0:
        accumulate_step = config.train.accumulation_size // config.train.batch_size

    best_valid_recall = 0.0
    if config.train.resume:
        print('resume checkpoints')
        checkpoint = torch.load("/home/koga/workspace/kaggle_bengali/result/" +
                                config.train.path)
        model.load_state_dict(fix_model_state_dict(checkpoint['checkpoint']))
        # model.load_state_dict(checkpoint['checkpoint'])
        # best_valid_recall = checkpoint['best_valid_recall']

    # if config.train.earlyStopping:
    #     early_stopping = EarlyStopping(patience=patience, verbose=True)

    valid_recall = 0.0

    for epoch in range(1, config.train.num_epochs + 1):
        print(f'epoch {epoch} start')
        logger.info(f'epoch {epoch} start ')

        metric_train = do_train(model, dataloaders["train"], criterion,
                                optimizer, device, config, epoch,
                                accumulate_step)
        torch.cuda.empty_cache()
        metrics_eval = do_eval(model, dataloaders["valid"], criterion, device)
        torch.cuda.empty_cache()
        valid_recall = metrics_eval["valid_metric"]

        scheduler.step(metrics_eval["valid_loss"])

        print(f'epoch: {epoch} ', metric_train, metrics_eval)
        logger.info(f'epoch: {epoch} {metric_train} {metrics_eval}')
        if valid_recall > best_valid_recall:
            print(f"save checkpoint: best_recall:{valid_recall}")
            logger.info(f"save checkpoint: best_recall:{valid_recall}")
            torch.save(
                {
                    'checkpoint': model.state_dict(),
                    'epoch': epoch,
                    'best_valid_recall': valid_recall,
                }, config.work_dir + "/checkpoints/" + f"{epoch}.pth")
            best_valid_recall = valid_recall

        torch.cuda.empty_cache()
        gc.collect()
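do_train itself is not shown, but accumulate_step implies gradient accumulation: gradients from several mini-batches are summed before each optimizer step, so config.train.accumulation_size acts as the effective batch size. A minimal sketch of that inner loop, assuming do_train follows the standard PyTorch pattern (all names below are illustrative):

def train_one_epoch(model, loader, criterion, optimizer, device, accumulate_step):
    model.train()
    optimizer.zero_grad()
    for i, (images, targets) in enumerate(loader):
        images, targets = images.to(device), targets.to(device)
        loss = criterion(model(images), targets)
        # divide so the accumulated gradient matches one large-batch step
        (loss / accumulate_step).backward()
        if (i + 1) % accumulate_step == 0:
            optimizer.step()
            optimizer.zero_grad()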
Example #4
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args = get_args()
    cfg = Config.fromfile(args.config)
    cfg.device = device

    train = pd.read_csv(cfg.train_csv)
    camera_matrix_inv = np.linalg.inv(kaggle.camera_matrix)

    if False:  # disabled: exploratory data-analysis plots
        points_df = pd.DataFrame()
        for col in ['x', 'y', 'z', 'yaw', 'pitch', 'roll']:
            arr = []
            for ps in train['PredictionString']:
                coords = kaggle.str2coords(ps)
                arr += [c[col] for c in coords]
            points_df[col] = arr

        log.info(f'len(points_df): {len(points_df)}')
        log.info(points_df.head())

        img = imread(opj(cfg.train_images, train.iloc[0]['ImageId'] + '.jpg'))
        # plt.figure(figsize=(15,8))
        # plt.imshow(img)
        # plt.show()

        # log.info(train.head())
        # log.info(kaggle.camera_matrix)
        pred_string = train.iloc[0]['PredictionString']
        coords = kaggle.str2coords(pred_string)
        # log.info(coords)

        lens = [len(kaggle.str2coords(s)) for s in train['PredictionString']]

        ############
        plt.figure(figsize=(15, 6))
        sns.countplot(lens)
        # plt.xlabel('Number of cars in image')
        # plt.show()
        plt.savefig('eda/number_cars_in_image.png')

        ############
        plt.figure(figsize=(15, 6))
        sns.distplot(functools.reduce(lambda a, b: a + b,
                                      [[c['x'] for c in kaggle.str2coords(s)]
                                       for s in train['PredictionString']]),
                     bins=500)
        # sns.distplot([kaggle.str2coords(s)[0]['x'] for s in train['PredictionString']]);
        plt.xlabel('x')
        # plt.show()
        plt.savefig('eda/x.png')

        ############
        plt.figure(figsize=(15, 6))
        sns.distplot(functools.reduce(lambda a, b: a + b,
                                      [[c['y'] for c in kaggle.str2coords(s)]
                                       for s in train['PredictionString']]),
                     bins=500)
        plt.xlabel('y')
        # plt.show()
        plt.savefig('eda/y.png')

        ############
        plt.figure(figsize=(15, 6))
        sns.distplot(functools.reduce(lambda a, b: a + b,
                                      [[c['z'] for c in kaggle.str2coords(s)]
                                       for s in train['PredictionString']]),
                     bins=500)
        plt.xlabel('z')
        # plt.show()
        plt.savefig('eda/z.png')

        ############
        plt.figure(figsize=(15, 6))
        sns.distplot(
            functools.reduce(lambda a, b: a + b,
                             [[c['yaw'] for c in kaggle.str2coords(s)]
                              for s in train['PredictionString']]))
        plt.xlabel('yaw')
        # plt.show()
        plt.savefig('eda/yaw.png')

        ############
        plt.figure(figsize=(15, 6))
        sns.distplot(
            functools.reduce(lambda a, b: a + b,
                             [[c['roll'] for c in kaggle.str2coords(s)]
                              for s in train['PredictionString']]))
        plt.xlabel('roll')
        # plt.show()
        plt.savefig('eda/roll.png')

        ############
        plt.figure(figsize=(15, 6))
        sns.distplot(
            functools.reduce(lambda a, b: a + b,
                             [[c['pitch'] for c in kaggle.str2coords(s)]
                              for s in train['PredictionString']]))
        plt.xlabel('pitch')
        # plt.show()
        plt.savefig('eda/pitch.png')

        ############
        plt.figure(figsize=(15, 6))
        sns.distplot(
            functools.reduce(lambda a, b: a + b, [[
                kaggle.rotate(c['roll'], np.pi) for c in kaggle.str2coords(s)
            ] for s in train['PredictionString']]))
        plt.xlabel('roll rotated by pi')
        # plt.show()
        plt.savefig('eda/roll_rotated_by_pi.png')

        plt.figure(figsize=(14, 14))
        plt.imshow(
            imread(opj(cfg.train_images,
                       train.iloc[2217]['ImageId'] + '.jpg')))
        plt.scatter(*kaggle.get_img_coords(
            train.iloc[2217]['PredictionString']),
                    color='red',
                    s=100)
        # plt.show()
        # log.info(kaggle.get_img_coords(train.iloc[2217]['PredictionString']))

        ############
        xs, ys = [], []

        for ps in train['PredictionString']:
            x, y = kaggle.get_img_coords(ps)
            xs += list(x)
            ys += list(y)

        plt.figure(figsize=(18, 18))
        plt.imshow(imread(
            opj(cfg.train_images, train.iloc[2217]['ImageId'] + '.jpg')),
                   alpha=0.3)
        plt.scatter(xs, ys, color='red', s=10, alpha=0.2)
        # plt.show()
        plt.savefig('eda/xs-ys_distribution.png')

        ############
        # view distribution from the sky
        road_width = 3
        road_xs = [
            -road_width, road_width, road_width, -road_width, -road_width
        ]
        road_ys = [0, 0, 500, 500, 0]

        plt.figure(figsize=(16, 16))
        plt.axes().set_aspect(1)
        plt.xlim(-50, 50)
        plt.ylim(0, 100)

        # View road
        plt.fill(road_xs, road_ys, alpha=0.2, color='gray')
        plt.plot([road_width / 2, road_width / 2], [0, 100],
                 alpha=0.4,
                 linewidth=4,
                 color='white',
                 ls='--')
        plt.plot([-road_width / 2, -road_width / 2], [0, 100],
                 alpha=0.4,
                 linewidth=4,
                 color='white',
                 ls='--')

        # View cars
        # plt.scatter(points_df['x'], np.sqrt(points_df['z']**2 + points_df['y']**2), color='red', s=10, alpha=0.1)
        # plt.savefig('eda/view_from_sky.png')

        ############
        fig = px.scatter_3d(points_df,
                            x='x',
                            y='y',
                            z='z',
                            color='pitch',
                            range_x=(-50, 50),
                            range_y=(0, 50),
                            range_z=(0, 250),
                            opacity=0.1)
        # fig.show()

        zy_slope = LinearRegression()
        X = points_df[['z']]
        y = points_df[['y']]
        zy_slope.fit(X, y)
        print('MAE without x:', mean_absolute_error(y, zy_slope.predict(X)))

        # Will use this model later
        xzy_slope = LinearRegression()
        X = points_df[['x', 'z']]
        y = points_df['y']
        xzy_slope.fit(X, y)
        print('MAE with x:', mean_absolute_error(y, xzy_slope.predict(X)))
        print('\ndy/dx = {:.3f} \ndy/dz = {:.3f}'.format(*xzy_slope.coef_))

        plt.figure(figsize=(16, 16))
        plt.xlim(0, 500)
        plt.ylim(0, 100)
        plt.scatter(points_df['z'], points_df['y'], label='Real points')
        X_line = np.linspace(0, 500, 10)
        plt.plot(X_line,
                 zy_slope.predict(X_line.reshape(-1, 1)),
                 color='orange',
                 label='Regression')
        plt.legend()
        plt.xlabel('z coordinate')
        plt.ylabel('y coordinate')
        plt.savefig('eda/linear_regression.png')

        # 3d view
        n_rows = 6
        for idx in range(n_rows):
            fig, axes = plt.subplots(1, 2, figsize=(20, 20))
            img = imread(
                opj(cfg.train_images, train['ImageId'].iloc[idx] + '.jpg'))
            axes[0].imshow(img)
            img_vis = kaggle.visualize(
                img, kaggle.str2coords(train['PredictionString'].iloc[idx]))
            axes[1].imshow(img_vis)
            # plt.show()
            plt.savefig(f'eda/img-view_coords_{idx}.png')

    if False:  # disabled: mask / regression-target visualization
        img0 = imread(opj(cfg.train_images, train.iloc[0]['ImageId'] + '.jpg'))
        img = kaggle.preprocess_image(img0)

        print(train.iloc[0]['PredictionString'])
        mask, regr = kaggle.get_mask_and_regr(
            img0, train.iloc[0]['PredictionString'])
        # print('img.shape', img.shape, 'std:', np.std(img))
        # print('mask.shape', mask.shape, 'std:', np.std(mask))
        # print('regr.shape', regr.shape, 'std:', np.std(regr))

        plt.figure(figsize=(16, 16))
        plt.title('Processed image')
        plt.imshow(img)
        # plt.show()
        plt.savefig('eda/processed_image.png')

        plt.figure(figsize=(16, 16))
        plt.title('Detection Mask')
        plt.imshow(mask)
        # plt.show()
        plt.savefig('eda/detection_mask.png')

        plt.figure(figsize=(16, 16))
        plt.title('Yaw values')
        plt.imshow(regr[:, :, -2])
        # plt.show()
        plt.savefig('eda/yaw_values.png')

    #############
    if False:  # disabled: flip-augmentation sanity check
        regr_model = kaggle.get_regr_model(train)

        for idx in range(2):
            fig, axes = plt.subplots(1, 2, figsize=(20, 20))

            for ax_i in range(2):
                img0 = imread(
                    opj(cfg.train_images, train['ImageId'].iloc[idx] + '.jpg'))
                if ax_i == 1:
                    img0 = img0[:, ::-1]
                img = kaggle.preprocess_image(img0, ax_i == 1)
                mask, regr = kaggle.get_mask_and_regr(
                    img0, train['PredictionString'][idx], ax_i == 1)
                regr = np.rollaxis(regr, 2, 0)
                coords = kaggle.extract_coords(
                    np.concatenate([mask[None], regr], 0), regr_model,
                    ax_i == 1)

                axes[ax_i].set_title('Flip = {}'.format(ax_i == 1))
                axes[ax_i].imshow(kaggle.visualize(img0, coords))
            # plt.show()
            plt.savefig(f'eda/flip_check_{idx}.png')  # one figure per image, both flip panels

    if False:  # disabled: CarDataset sample visualization
        dataset = dataset_factory.CarDataset(cfg.data.train)
        img, mask, regr = dataset[0]

        plt.figure(figsize=(16, 16))
        plt.imshow(np.rollaxis(img, 0, 3))
        # plt.show()
        plt.savefig(f'eda/img.png')

        plt.figure(figsize=(16, 16))
        plt.imshow(mask)
        # plt.show()
        plt.savefig(f'eda/mask.png')

        plt.figure(figsize=(16, 16))
        plt.imshow(regr[:, :, -2])
        # plt.show()
        plt.savefig(f'eda/regr.png')

    #########
    if True:  # active: training setup
        # initial -----------------------------------
        best = {
            'loss': float('inf'),
            'score': 0.0,
            'epoch': -1,
        }

        train_loader = dataset_factory.get_dataloader(cfg.data.train)
        valid_loader = dataset_factory.get_dataloader(cfg.data.valid)
        test_loader = dataset_factory.get_dataloader(cfg.data.test)
        # smoke-test the test loader by pulling a few batches
        for i, (img, mask, regr) in enumerate(tqdm(test_loader)):
            print(i)
            if i == 3:
                break

        model = model_factory.get_model(cfg)
        optimizer = optimizer_factory.get_optimizer(model, cfg)
        scheduler = scheduler_factory.get_scheduler(cfg, optimizer,
                                                    best['epoch'])
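kaggle.str2coords parses a PredictionString, which in this competition is a flat, space-separated list of seven numbers per car. A minimal sketch, assuming the column order used by the public baseline kernels (model id, then pose, then position):

import numpy as np

def str2coords(s, names=('id', 'yaw', 'pitch', 'roll', 'x', 'y', 'z')):
    # "id yaw pitch roll x y z id yaw ..." -> one dict per car
    values = np.array(s.split(), dtype=float).reshape(-1, len(names))
    return [dict(zip(names, row)) for row in values]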
Example #5
def run(config_file):
    config = load_config(config_file)
    config.work_dir = '/home/koga/workspace/kaggle_bengali/result/' + config.work_dir
    os.makedirs(config.work_dir, exist_ok=True)
    os.makedirs(config.work_dir + "/checkpoints", exist_ok=True)
    print('working directory:', config.work_dir)
    logger = get_logger(config.work_dir + "/log.txt")

    all_transforms = {}
    all_transforms['train'] = Transform(
            size=config.data.image_size,
            affine=config.transforms.affine,
            autoaugment_ratio=config.transforms.autoaugment_ratio,
            threshold=config.transforms.threshold,
            sigma=config.transforms.sigma,
            blur_ratio=config.transforms.blur_ratio,
            noise_ratio=config.transforms.noise_ratio,
            cutout_ratio=config.transforms.cutout_ratio,
            grid_distortion_ratio=config.transforms.grid_distortion_ratio,
            random_brightness_ratio=config.transforms.random_brightness_ratio,
            piece_affine_ratio=config.transforms.piece_affine_ratio,
            ssr_ratio=config.transforms.ssr_ratio,
            grid_mask_ratio=config.transforms.grid_mask_ratio,
            augmix_ratio=config.transforms.augmix_ratio,
    )
    all_transforms['valid'] = Transform(size=config.data.image_size)

    dataloaders = {
        phase: make_loader(
            phase=phase,
            df_path=config.train.dfpath,
            batch_size=config.train.batch_size,
            num_workers=config.num_workers,
            idx_fold=config.data.params.idx,
            fold_csv=config.data.params.fold_csv,
            transforms=all_transforms[phase],
            # debug=config.debug
            crop=config.transforms.crop
        )
        for phase in ['train', 'valid']
    }
    model_root = MODEL_LIST['Resnet34_3model'](
        pretrained=config.model.pretrained, out_dim=168)
    model_vowel = MODEL_LIST['Resnet34_3model'](
        pretrained=config.model.pretrained, out_dim=11)
    model_const = MODEL_LIST['Resnet34_3model'](
        pretrained=config.model.pretrained, out_dim=7)

    model_root = model_root.to(device)
    model_vowel = model_vowel.to(device)
    model_const = model_const.to(device)

    model_list = [model_root, model_vowel, model_const]

    criterion = get_criterion(config)
    optimizer_root = get_optimizer(config, model_root)
    optimizer_vowel = get_optimizer(config, model_vowel)
    optimizer_const = get_optimizer(config, model_const)

    optimizer_list = [optimizer_root, optimizer_vowel, optimizer_const]

    scheduler_root = get_scheduler(optimizer_root, config)
    scheduler_vowel = get_scheduler(optimizer_vowel, config)
    scheduler_const = get_scheduler(optimizer_const, config)

    scheduler_list = [scheduler_root, scheduler_vowel, scheduler_const]

    accumulate_step = 1
    if config.train.accumulation_size > 0:
        accumulate_step = config.train.accumulation_size // config.train.batch_size

    best_valid_recall = 0.0
    if config.train.resume:
        print('resume checkpoints')
        checkpoint = torch.load("/home/koga/workspace/kaggle_bengali/result/" +
                                config.train.path)
        # restore all three heads; the checkpoint stores one state dict per model
        model_root.load_state_dict(fix_model_state_dict(checkpoint['checkpoint_root']))
        model_vowel.load_state_dict(fix_model_state_dict(checkpoint['checkpoint_vowel']))
        model_const.load_state_dict(fix_model_state_dict(checkpoint['checkpoint_const']))

    valid_recall = 0.0

    for epoch in range(1, config.train.num_epochs + 1):
        print(f'epoch {epoch} start')
        logger.info(f'epoch {epoch} start ')

        metric_train = do_train(model_list, dataloaders["train"], criterion,
                                optimizer_list, device, config, epoch,
                                accumulate_step)
        torch.cuda.empty_cache()
        metrics_eval = do_eval(model_list, dataloaders["valid"], criterion, device)
        torch.cuda.empty_cache()
        valid_recall = metrics_eval["valid_recall"]

        for scheduler in scheduler_list:
            scheduler.step(metrics_eval["valid_recall"])

        print(f'epoch: {epoch} ', metric_train, metrics_eval)
        logger.info(f'epoch: {epoch} {metric_train} {metrics_eval}')
        if valid_recall > best_valid_recall:
            print(f"save checkpoint: best_recall:{valid_recall}")
            logger.info(f"save checkpoint: best_recall:{valid_recall}")
            torch.save({
                'checkpoint_root': model_list[0].state_dict(),
                'checkpoint_vowel': model_list[1].state_dict(),
                'checkpoint_const': model_list[2].state_dict(),
                'epoch': epoch,
                'best_valid_recall': valid_recall,
                }, config.work_dir + "/checkpoints/" + f"{epoch}.pth")
            best_valid_recall = valid_recall

        torch.cuda.empty_cache()
        gc.collect()
Example #6
def do_train(cfg, model):
    # get criterion -----------------------------
    criterion = criterion_factory.get_criterion(cfg)

    # get optimization --------------------------
    optimizer = optimizer_factory.get_optimizer(model, cfg)

    # initial -----------------------------------
    best = {
        'loss': float('inf'),
        'score': 0.0,
        'epoch': -1,
    }

    # resume model ------------------------------
    if cfg.resume_from:
        log.info('\n')
        log.info(f're-load model from {cfg.resume_from}')
        detail = util.load_model(cfg.resume_from, model, optimizer, cfg.device)
        best.update({
            'loss': detail['loss'],
            'score': detail['score'],
            'epoch': detail['epoch'],
        })

    # scheduler ---------------------------------
    scheduler = scheduler_factory.get_scheduler(cfg, optimizer, best['epoch'])

    # fp16 --------------------------------------
    if cfg.apex:
        # amp.initialize returns the patched (model, optimizer) pair;
        # without the assignment the fp16 setup is silently discarded
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level='O1', verbosity=0)

    # setting dataset ---------------------------
    loader_train = dataset_factory.get_dataloader(cfg.data.train)
    loader_valid = dataset_factory.get_dataloader(cfg.data.valid)

    # start training ----------------------------
    start_time = datetime.now().strftime('%Y/%m/%d %H:%M:%S')
    log.info('\n')
    log.info(f'** start train [fold{cfg.fold}th] {start_time} **\n')
    log.info(
        'epoch    iter      rate     | smooth_loss/score | valid_loss/score | best_epoch/best_score |  min'
    )
    log.info(
        '-------------------------------------------------------------------------------------------------'
    )

    for epoch in range(best['epoch'] + 1, cfg.epoch):
        end = time.time()
        util.set_seed(epoch)

        ## train model --------------------------
        train_results = run_nn(cfg.data.train,
                               'train',
                               model,
                               loader_train,
                               criterion=criterion,
                               optimizer=optimizer,
                               apex=cfg.apex,
                               epoch=epoch)

        ## valid model --------------------------
        with torch.no_grad():
            val_results = run_nn(cfg.data.valid,
                                 'valid',
                                 model,
                                 loader_valid,
                                 criterion=criterion,
                                 epoch=epoch)

        detail = {
            'score': val_results['score'],
            'loss': val_results['loss'],
            'epoch': epoch,
        }

        if val_results['loss'] <= best['loss']:
            best.update(detail)
            util.save_model(model, optimizer, detail, cfg.fold[0],
                            os.path.join(cfg.workdir, 'checkpoint'))


        log.info('%5.1f   %5d    %0.6f   |  %0.4f  %0.4f  |  %0.4f  %6.4f |  %6.1f     %6.4f    | %3.1f min' % \
                (epoch+1, len(loader_train), util.get_lr(optimizer), train_results['loss'], train_results['score'], val_results['loss'], val_results['score'], best['epoch'], best['score'], (time.time() - end) / 60))

        # ReduceLROnPlateau expects the monitored metric; epoch-based
        # schedulers would call scheduler.step() with no argument instead
        scheduler.step(val_results['loss'])

        # early stopping-------------------------
        if cfg.early_stop:
            if epoch - best['epoch'] > cfg.early_stop:
                log.info(f'=================================> early stopping!')
                break
        time.sleep(0.01)
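When cfg.apex is set, run_nn presumably routes the backward pass through apex's loss scaling. A minimal sketch of that fp16 training step, assuming the standard apex O1 workflow (the helper name train_step is illustrative):

from apex import amp

def train_step(model, inputs, targets, criterion, optimizer, use_apex):
    loss = criterion(model(inputs), targets)
    optimizer.zero_grad()
    if use_apex:
        # scale the loss so fp16 gradients stay in a representable range
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
    optimizer.step()
    return loss.item()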
Example #7
def train(config_yml, working_root=str(this_file_dir / '..')):
    """Endpoint for training an image classification model.

    Args:
        config_yml (str): YAML configuration file.
        working_root (str, optional): Root directory from which data paths
            are resolved. Defaults to str(this_file_dir / '..').
    """

    with open(config_yml, 'r') as f:
        config = yaml.safe_load(f)

    # ====
    # Prepare data
    # ====
    train_loader = get_dataloader_through_dataset(config['data']['train'],
                                                  working_root)
    test_loader = get_dataloader_through_dataset(
        config['data']['eval'],
        working_root,
    )

    # Prepare the network
    net = get_model(config['model'])
    if config['model'].get('model_state_dict'):
        model_state_dict_path = Path(
            working_root) / config['model']['model_state_dict']
        load_weight(net, str(model_state_dict_path))

    # Define the optimizer
    optimizer = get_optimizer(net, config['optimizer'])

    # Define the loss function
    criterion = get_loss(config['loss'])

    # Object responsible for saving data (config, metrics, models, results)
    datasaver = DataSaver(config['output_data'])
    datasaver.save_config(config_yml)

    # ======
    # Main loop
    # ======
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if 'cuda' in config:
        device = 'cuda' if config['cuda'] else 'cpu'

    net.to(device)

    num_epochs = config['num_epochs']
    for epoch in range(num_epochs):
        print(epoch)
        metrics_dict = {}

        # train
        print('train phase')
        metrics = run_train(net, train_loader, criterion, optimizer, device)
        metrics_dict.update(metrics)

        # eval
        print('eval phase')
        metrics, result_detail = run_eval(net, test_loader, criterion, device)
        metrics_dict.update(metrics)

        # Record evaluation metrics
        datasaver.save_metrics(metrics_dict, epoch)
        datasaver.save_model(net, epoch)
        datasaver.save_result_detail(result_detail, epoch)
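For reference, the keys this endpoint reads imply a YAML config of roughly the following shape once loaded; every concrete leaf value below is an illustrative assumption, not the project's actual schema:

# dict produced by yaml.safe_load on the config file
config = {
    'data': {
        'train': {'dir': 'data/train', 'batch_size': 32},  # hypothetical keys
        'eval': {'dir': 'data/val', 'batch_size': 32},     # hypothetical keys
    },
    'model': {
        'name': 'resnet18',                      # hypothetical key
        'model_state_dict': 'weights/init.pth',  # optional warm start
    },
    'optimizer': {'name': 'adam', 'lr': 1e-3},   # hypothetical keys
    'loss': {'name': 'cross_entropy'},           # hypothetical key
    'output_data': {'dir': 'results/exp01'},     # consumed by DataSaver
    'cuda': True,       # optional: overrides automatic device selection
    'num_epochs': 30,   # read directly by the main loop
}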