コード例 #1
0
ファイル: model.py プロジェクト: zeta1999/keraTorch
 def lr_find(self, x, y, bs):
     """Run a learning-rate sweep over ``(x, y)`` and plot the outcome.

     Args:
         x: Inputs, forwarded unchanged to ``create_db``.
         y: Targets, forwarded unchanged to ``create_db``.
         bs: Batch size, forwarded to ``create_db`` as ``bs``.
     """
     # A throwaway Learner is built around this wrapper's model and loss,
     # so the sweep never goes through any learner stored on ``self``.
     finder = Learner(create_db(x, y, bs=bs), self.model, loss_func=self.loss)
     finder.lr_find()
     # presumably IPython's clear_output, used to drop the sweep's progress
     # output before the plot is drawn -- confirm against the file's imports
     clear_output()
     finder.recorder.plot(suggestion=True)
コード例 #2
0
                      collate_fn=dlc.gdf_col,
                      pin_memory=False,
                      num_workers=0)

# Benchmark script fragment: each stage below is timed by resetting `start`
# before it runs and printing the elapsed `time() - start` once it finishes.

# Wrap the training/validation data (t_data / v_data, built above this
# fragment) in a DataBunch placed on the "cuda" device.
databunch = DataBunch(t_data, v_data, collate_fn=dlc.gdf_col, device="cuda")
t_final = time() - start
print(t_final)

# --- Stage: model construction ---
print("Creating model")
start = time()
model = TabularModel(emb_szs=embeddings,
                     n_cont=len(cont_names),
                     out_sz=2,
                     layers=[512, 256])
learn = Learner(databunch, model, metrics=[accuracy])
# The loss is assigned after construction rather than passed to Learner().
learn.loss_func = torch.nn.CrossEntropyLoss()
t_final = time() - start
print(t_final)

# --- Stage: learning-rate search ---
print("Finding learning rate")
start = time()
learn.lr_find()
learn.recorder.plot(show_moms=True, suggestion=True)
# NOTE(review): the rate is hard-coded here instead of being read back from
# the recorder -- presumably chosen by eye from an earlier plot; confirm.
learning_rate = 1.32e-2
epochs = 1
t_final = time() - start
print(t_final)

# --- Stage: training ---
print("Running Training")
start = time()
learn.fit_one_cycle(epochs, learning_rate)
t_final = time() - start
print(t_final)
コード例 #3
0
ファイル: train.py プロジェクト: lewfish/mlx
def train(config_path, opts):
    """Run the experiment described by a config file.

    Depending on flags on the loaded config, this fits the model
    (``overfit_mode`` or the default path), runs a learning-rate search and
    exits (``lr_find_mode``), or loads previously saved weights
    (``predict_mode``). Except in ``lr_find_mode``, it finishes by writing
    debug prediction plots and, when ``cfg.output_uri`` starts with
    ``s3://``, syncing the output directory to that URI.

    Args:
        config_path: Passed to ``load_config`` as the config location.
        opts: Passed to ``load_config`` alongside ``config_path``.

    Raises:
        SystemExit: After the learning-rate search when ``cfg.lr_find_mode``
            is set (and ``cfg.predict_mode`` / ``cfg.overfit_mode`` are not).
    """
    # Use a context manager so the scratch directory is removed on every exit
    # path (normal return, exception, or the SystemExit raised in lr_find
    # mode) instead of relying on the TemporaryDirectory finalizer.
    with tempfile.TemporaryDirectory() as tmp_dir:
        cfg = load_config(config_path, opts)
        print(cfg)
        _train_in_dir(cfg, tmp_dir)


def _train_in_dir(cfg, tmp_dir):
    """Run the experiment for a loaded ``cfg`` using ``tmp_dir`` as scratch space."""
    # Setup data
    databunch, full_databunch = build_databunch(cfg, tmp_dir)
    output_dir = setup_output_dir(cfg, tmp_dir)
    print(full_databunch)

    plotter = build_plotter(cfg)
    if not cfg.lr_find_mode and not cfg.predict_mode:
        plotter.plot_data(databunch, output_dir)

    # Setup model
    num_labels = databunch.c
    model = build_model(cfg, num_labels)
    metrics = [CocoMetric(num_labels)]
    learn = Learner(databunch, model, path=output_dir, metrics=metrics)
    # Replace fastai's module-level loss_batch with the one in scope here.
    fastai.basic_train.loss_batch = loss_batch
    best_model_path = join(output_dir, 'best_model.pth')
    last_model_path = join(output_dir, 'last_model.pth')

    # Train model
    callbacks = [
        MyCSVLogger(learn, filename='log'),
        SubLossMetric(learn, model.subloss_names)
    ]

    if cfg.output_uri.startswith('s3://'):
        callbacks.append(
            SyncCallback(output_dir, cfg.output_uri, cfg.solver.sync_interval))

    if cfg.model.init_weights:
        # Load onto whatever device the model's parameters already live on.
        device = next(model.parameters()).device
        model.load_state_dict(
            torch.load(cfg.model.init_weights, map_location=device))

    if not cfg.predict_mode:
        if cfg.overfit_mode:
            learn.fit_one_cycle(cfg.solver.num_epochs, cfg.solver.lr, callbacks=callbacks)
            torch.save(learn.model.state_dict(), best_model_path)
            learn.model.eval()
            print('Validating on training set...')
            learn.validate(full_databunch.train_dl, metrics=metrics)
        else:
            # NOTE(review): tb_logger is configured but never appended to
            # `callbacks` -- confirm TensorboardLogger attaches itself to the
            # learner some other way.
            tb_logger = TensorboardLogger(learn, 'run')
            tb_logger.set_extra_args(
                model.subloss_names, cfg.overfit_mode)

            extra_callbacks = [
                MySaveModelCallback(
                    learn, best_model_path, monitor='coco_metric', every='improvement'),
                MySaveModelCallback(learn, last_model_path, every='epoch'),
                TrackEpochCallback(learn),
            ]
            callbacks.extend(extra_callbacks)
            if cfg.lr_find_mode:
                learn.lr_find()
                # The figure returned because of return_fig=True is not used.
                learn.recorder.plot(suggestion=True, return_fig=True)
                lr = learn.recorder.min_grad_lr
                print('lr_find() found lr: {}'.format(lr))
                # `exit()` is an interactive helper injected by the `site`
                # module and may be missing (e.g. `python -S`); raising
                # SystemExit directly has the same effect everywhere.
                raise SystemExit

            learn.fit_one_cycle(cfg.solver.num_epochs, cfg.solver.lr, callbacks=callbacks)
            print('Validating on full validation set...')
            learn.validate(full_databunch.valid_dl, metrics=metrics)
    else:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model.load_state_dict(
            torch.load(best_model_path, map_location=device))
        model.eval()

    print('Plotting predictions...')
    # Overfit runs are inspected on the training set; everything else
    # (including predict mode) is plotted on the validation set.
    plot_dataset = databunch.train_ds if cfg.overfit_mode else databunch.valid_ds
    plotter.make_debug_plots(plot_dataset, model, databunch.classes, output_dir)
    if cfg.output_uri.startswith('s3://'):
        sync_to_dir(output_dir, cfg.output_uri)