Example #1
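A PyTorch Lightning training entry point for a TinyBERT-based retrieval model. It logs to Comet when the client library is installed, and otherwise falls back to a locally generated run key and Lightning's default logger.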
def main():

    parser = argparse.ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser.add_argument("--batch-size", type=int, default=2)

    args = parser.parse_args()

    datamodule = IRModule.load()
    datamodule.batch_size = args.batch_size

    model = Model.from_tinybert()

    if COMET_INSTALLED:
        comet_logger = CometLogger(
            api_key=os.environ.get("COMET_API_KEY"),
            experiment_name="mtg-search",
            log_graph=False,
            log_code=False,
            log_env_details=False,
            disabled=True,
        )
        comet_logger.log_hyperparams(asdict(model.config))
        key = comet_logger.experiment.get_key()
    else:
        key = uuid.uuid4().hex
        comet_logger = True  # to pass logger=True to Trainer

    model.config.key = key

    callbacks = [
        ModelCheckpoint(
            dirpath=MODELS_DIR,
            save_top_k=1,
            monitor="val_acc",
            filename=key,
        )
    ]

    trainer = Trainer.from_argparse_args(
        args,
        logger=comet_logger,
        callbacks=callbacks,
        num_sanity_val_steps=0,
        val_check_interval=10,
    )

    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)
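
The snippet assumes a module-level COMET_INSTALLED flag set at import time. A minimal sketch of how such a guard is commonly written (the try/except placement is an assumption, not from the source):

try:
    import comet_ml  # noqa: F401  # succeeds only if the Comet client is installed
    COMET_INSTALLED = True
except ImportError:
    COMET_INSTALLED = False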
Example #2
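Trains a ResNet regressor on images stored in HDF5 databases, drawing a random subset of the training set each epoch and logging hyperparameters and checkpoint assets to Comet.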
def run_resnet_synth(network_class: Type[AbstractResnet], num_input_layers, num_outputs,
                     comment, train_db_path, val_db_path, val_split, transform, out_transform, mean_error_func_dict,
                     hist_error_func_dict, text_error_func_dict, output_scaling=1e4, lr=1e-2,
                     resnet_type='18', train_cache_size=5500, val_cache_size=1000, batch_size=64, num_epochs=1000,
                     weight_decay=0, cosine_annealing_steps=10, loss_func=F.smooth_l1_loss, subsampler_size=640,
                     dtype=torch.float32, track_ideal_metrics=False, monitor_metric_name=None, parallel_plotter=None,
                     extra_params=None):
    if extra_params is None:
        extra_params = {}
    if None in [batch_size, num_epochs, resnet_type, train_db_path, val_db_path, val_split, comment]:
        raise ValueError('Config not fully initialized')
    torch_to_np_dtypes = {
        torch.float16: np.float16,
        torch.float32: np.float32,
        torch.float64: np.float64
    }
    np_dtype = torch_to_np_dtypes[dtype]
    transform = Functor(transform)
    logged_hyperparams = {'batch_size': batch_size, 'train_db': train_db_path.split('/')[-1],
                          'val_db': val_db_path.split('/')[-1], 'train-val_split_index': val_split,
                          'loss_func': loss_func.__name__, 'num_outputs': num_outputs,
                          'output_scaling': output_scaling, 'resnet_type': resnet_type, 'lr': lr,
                          'weight_decay': weight_decay, 'cosine_annealing_steps': cosine_annealing_steps,
                          'transform': transform}
    with h5py.File(train_db_path, 'r') as hf:
        db_size = hf['data']['images'].len()
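    # val_split is either an integer split point or an explicit collection of validation indices.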
    if isinstance(val_split, int):
        train_dset = ImageDataset(train_db_path,
                                  transform=transform, output_scaling=output_scaling, out_transform=out_transform,
                                  cache_size=train_cache_size,
                                  max_index=val_split, dtype=np_dtype)
        val_dset = ImageDataset(val_db_path,
                                transform=transform, output_scaling=output_scaling, out_transform=out_transform,
                                cache_size=val_cache_size,
                                min_index=val_split, dtype=np_dtype)
    else:
        train_split = set(range(db_size))
        train_split -= set(val_split)
        train_split = tuple(train_split)
        train_dset = ImageDataset(train_db_path,
                                  transform=transform, output_scaling=output_scaling, out_transform=out_transform,
                                  cache_size=train_cache_size,
                                  index_list=train_split, dtype=np_dtype)
        val_dset = ImageDataset(val_db_path,
                                transform=transform, output_scaling=output_scaling, out_transform=out_transform,
                                cache_size=val_cache_size,
                                index_list=val_split, dtype=np_dtype)
    train_loader = DataLoader(train_dset, batch_size, shuffle=False, num_workers=4,
                              sampler=SubsetChoiceSampler(subsampler_size, len(train_dset)))
    val_loader = DataLoader(val_dset, batch_size, shuffle=True, num_workers=4)
    model = network_class(num_input_layers, num_outputs, loss_func=loss_func, output_scaling=output_scaling,
                          reduce_error_func_dict=mean_error_func_dict,
                          hist_error_func_dict=hist_error_func_dict,
                          text_error_func_dict=text_error_func_dict,
                          resnet_type=resnet_type, learning_rate=lr,
                          cosine_annealing_steps=cosine_annealing_steps, weight_decay=weight_decay, dtype=dtype,
                          track_ideal_metrics=track_ideal_metrics, **extra_params)

    model.plotter = parallel_plotter
    logger = CometLogger(api_key=os.environ.get("COMET_API_KEY"),  # read the key from the environment instead of hardcoding it
                         project_name="AeroVision",
                         experiment_name=comment)
    logger.log_hyperparams(params=logged_hyperparams)
    # Start each run with a fresh, empty checkpoints folder.
    checkpoints_folder = f"./checkpoints/{comment}/"
    if os.path.isdir(checkpoints_folder):
        shutil.rmtree(checkpoints_folder)
    Path(checkpoints_folder).mkdir(parents=True, exist_ok=True)
    if monitor_metric_name:
        mcp = ModelCheckpoint(
            dirpath=checkpoints_folder,
            filename="{epoch}_{" + monitor_metric_name + ":.3e}",
            save_last=True,
            save_top_k=1,
            period=1,
            monitor=monitor_metric_name,
            verbose=True)
        callbacks = [mcp]
    else:
        callbacks = None

    trainer = pl.Trainer(gpus=1, max_epochs=num_epochs,
                         callbacks=callbacks,
                         num_sanity_val_steps=0,
                         profiler=True, logger=logger)
    trainer.fit(model, train_loader, val_loader)
    logger.experiment.log_asset_folder(checkpoints_folder)
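
SubsetChoiceSampler and Functor are custom classes from the source project. Judging from the call site (subsampler_size, len(train_dset)), the sampler draws a fixed-size random subset of indices each epoch; a minimal sketch under that assumption:

import torch
from torch.utils.data import Sampler

class SubsetChoiceSampler(Sampler):
    # Assumed behaviour: yield `subset_size` randomly chosen indices per epoch,
    # so each epoch trains on a different random subset of the dataset.
    def __init__(self, subset_size, dataset_size):
        self.subset_size = min(subset_size, dataset_size)
        self.dataset_size = dataset_size

    def __iter__(self):
        return iter(torch.randperm(self.dataset_size)[:self.subset_size].tolist())

    def __len__(self):
        return self.subset_size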
Example #3
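Fine-tunes a HuggingFace token-classification backbone (BERT, ALBERT, ELECTRA, RoBERTa, XLNet, MobileBERT, or SqueezeBERT) for span detection, with optional learning-rate search and an optional pseudo-labelling loop.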
def train(**params):
    params = EasyDict(params)
    seed_everything(params.seed)

    config = ConfigParser()
    config.read('config.ini')

    logger, callbacks = False, list()
    if params.logger:
        comet_config = EasyDict(config['cometml'])
        logger = CometLogger(api_key=comet_config.apikey,
                             project_name=comet_config.projectname,
                             workspace=comet_config.workspace)
        logger.experiment.set_code(filename='project/span_bert/train.py',
                                   overwrite=True)
        logger.log_hyperparams(params)
        logger.experiment.log_asset_folder('project/span_bert')
        callbacks.append(LearningRateMonitor(logging_interval='epoch'))

    model_checkpoint = ModelCheckpoint(
        filepath='checkpoints/{epoch:02d}-{val_loss:.4f}-{f1_spans_sentence:.4f}',
        save_weights_only=True,
        save_top_k=10,
        monitor='val_loss',
        mode='min',
        period=1)
    callbacks.append(model_checkpoint)

    model_data = {
        'bert':
        [BertForTokenClassification, BertTokenizerFast, 'bert-base-uncased'],
        'albert':
        [AlbertForTokenClassification, AlbertTokenizerFast, 'albert-base-v2'],
        'electra': [
            ElectraForTokenClassification, ElectraTokenizerFast,
            'google/electra-small-discriminator'
        ],
        'roberta':
        [RobertaForTokenClassification, RobertaTokenizerFast, 'roberta-base'],
        'xlnet':
        [XLNetForTokenClassification, XLNetTokenizerFast, 'xlnet-base-cased'],
        'mobilebert': [
            MobileBertForTokenClassification, MobileBertTokenizerFast,
            'google/mobilebert-uncased'
        ],
        'squeezebert': [
            SqueezeBertForTokenClassification, SqueezeBertTokenizerFast,
            'squeezebert/squeezebert-mnli-headless'
        ]
    }
    model_class, tokenizer_class, model_name = model_data[params.model]
    tokenizer = tokenizer_class.from_pretrained(model_name, do_lower_case=True)
    model_backbone = model_class.from_pretrained(model_name,
                                                 num_labels=2,
                                                 output_attentions=False,
                                                 output_hidden_states=False)

    data_module = DatasetModule(data_dir=params.data_path,
                                tokenizer=tokenizer,
                                batch_size=params.batch_size,
                                length=params.length,
                                augmentation=params.augmentation,
                                valintrain=params.valintrain)
    model = LitModule(model=model_backbone,
                      tokenizer=tokenizer,
                      freeze=params.freeze,
                      lr=params.lr)

    trainer = Trainer(logger=logger,
                      max_epochs=params.epochs,
                      callbacks=callbacks,
                      gpus=1,
                      deterministic=True,
                      fast_dev_run=params.fast_dev_run)

    if params.find_lr:
        lr_finder = trainer.tuner.lr_find(model,
                                          datamodule=data_module,
                                          min_lr=1e-5,
                                          max_lr=1e-4)
        model.learning_rate = lr_finder.suggestion()

    trainer.fit(model, datamodule=data_module)

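    # Optional pseudo-labelling: repeatedly label batches of external data with the current model and retrain on the mixture.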
    if params.pseudolabel:
        pseudolabel_data = pd.read_csv(
            'data/civil_comments/all_civil_data_512.csv')
        already_labeled = pd.DataFrame(
            columns=[*pseudolabel_data.columns, 'spans'])
        model = LitModule.load_from_checkpoint(
            checkpoint_path=model_checkpoint.best_model_path,
            model=model_backbone,
            tokenizer=tokenizer,
            freeze=params.freeze,
            lr=2e-5,
            scheduler=False)

        while len(pseudolabel_data) > 0:
            size_to_label = min(len(data_module.train_df),
                                len(pseudolabel_data))
            df_subset = pseudolabel_data.sample(size_to_label)
            pseudolabel_data = pseudolabel_data.drop(df_subset.index)

            df_subset['spans'] = model.predict_dataframe(
                df_subset, params.length)

            already_labeled = pd.concat([already_labeled, df_subset])
            data_module = DatasetModule(data_dir=params.data_path,
                                        tokenizer=tokenizer,
                                        batch_size=params.batch_size,
                                        length=params.length,
                                        augmentation=params.augmentation,
                                        valintrain=params.valintrain,
                                        injectdataset=already_labeled)

            trainer = Trainer(logger=logger,
                              max_epochs=params.epochs,
                              callbacks=callbacks,
                              gpus=1,
                              deterministic=True,
                              val_check_interval=0.5,
                              fast_dev_run=params.fast_dev_run)
            trainer.fit(model, datamodule=data_module)

    if params.logger:
        for absolute_path in model_checkpoint.best_k_models.keys():
            logger.experiment.log_model(
                Path(absolute_path).name, absolute_path)
        logger.log_metrics(
            {'best_model_score': model_checkpoint.best_model_score.tolist()})

        best_model = LitModule.load_from_checkpoint(
            checkpoint_path=model_checkpoint.best_model_path,
            model=model_backbone,
            tokenizer=tokenizer,
            freeze=params.freeze,
            lr=params.lr)

        predicted_df = best_model.predict_dataframe(data_module.test_df,
                                                    params.length)
        log_predicted_spans(predicted_df, logger)
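
Note that filepath and period are ModelCheckpoint arguments from older pytorch-lightning releases. On recent versions the equivalent callback is configured roughly as follows (a sketch using the current dirpath/filename/every_n_epochs arguments):

model_checkpoint = ModelCheckpoint(
    dirpath='checkpoints',
    filename='{epoch:02d}-{val_loss:.4f}-{f1_spans_sentence:.4f}',
    save_weights_only=True,
    save_top_k=10,
    monitor='val_loss',
    mode='min',
    every_n_epochs=1)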
Example #4
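A Hydra-driven entry point that trains one of several generative models (VAE, DCGAN, WGAN-GP, CycleGAN, SAGAN, ProGAN) on a configurable image dataset, reading Comet credentials from a .env file.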
def main(cfg: DictConfig):
    print(f'Training {cfg.train.model} Model')
    cur_dir = hydra.utils.get_original_cwd()
    os.chdir(cur_dir)

    # Data Augmentation  --------------------------------------------------------
    transform = ImageTransform(cfg) if cfg.train.model != 'progan' else None

    # DataModule  ---------------------------------------------------------------
    dm = None
    data_dir = './data'
    if cfg.train.data == 'celeba_hq':
        img_paths = glob.glob(os.path.join(data_dir, 'celeba_hq', '**/*.jpg'),
                              recursive=True)
        dm = SingleImageDataModule(img_paths, transform, cfg)

    elif cfg.train.data == 'afhq':
        img_paths = glob.glob(os.path.join(data_dir, 'afhq', '**/*.jpg'),
                              recursive=True)
        dm = SingleImageDataModule(img_paths, transform, cfg)

    elif cfg.train.data == 'ffhq':
        img_paths = glob.glob(os.path.join(data_dir, 'ffhq', '**/*.png'),
                              recursive=True)
        dm = SingleImageDataModule(img_paths, transform, cfg)

    # Model  --------------------------------------------------------------------
    nets = build_model(cfg.train.model, cfg)

    # Comet_ml  -----------------------------------------------------------------
    load_dotenv('.env')
    logger = CometLogger(api_key=os.environ['COMET_ML_API_KEY'],
                         project_name=os.environ['COMET_ML_PROJECT_NAME'],
                         experiment_name=f"{cfg.train.model}")

    logger.log_hyperparams(dict(cfg.train))

    # Lightning Module  ---------------------------------------------------------
    model = None
    checkpoint_path = 'checkpoints/'
    checkpoint_callback = ModelCheckpoint(dirpath=checkpoint_path,
                                          filename='{epoch:02d}',
                                          prefix=cfg.train.model,
                                          period=1)

    if cfg.train.model == 'vae':
        model = VAE_LightningSystem(nets[0], cfg)

    elif cfg.train.model == 'dcgan':
        model = DCGAN_LightningSystem(nets[0], nets[1], cfg, checkpoint_path)

    elif cfg.train.model == 'wgan_gp':
        logger.log_hyperparams(dict(cfg.wgan_gp))
        model = WGAN_GP_LightningSystem(nets[0], nets[1], cfg, checkpoint_path)

    elif cfg.train.model == 'cyclegan':
        logger.log_hyperparams(dict(cfg.cyclegan))
        data_dir = 'data/'
        base_img_paths = glob.glob(os.path.join(data_dir,
                                                cfg.cyclegan.base_imgs_dir,
                                                '**/*.jpg'),
                                   recursive=True)
        style_img_paths = glob.glob(os.path.join(data_dir,
                                                 cfg.cyclegan.style_imgs_dir,
                                                 '**/*.jpg'),
                                    recursive=True)
        dm = CycleGANDataModule(base_img_paths,
                                style_img_paths,
                                transform,
                                cfg,
                                phase='train',
                                seed=cfg.train.seed)
        model = CycleGAN_LightningSystem(nets[0], nets[1], nets[2], nets[3],
                                         transform, cfg, checkpoint_path)

    elif cfg.train.model == 'sagan':
        logger.log_hyperparams(dict(cfg.sagan))
        model = SAGAN_LightningSystem(nets[0], nets[1], cfg, checkpoint_path)

    elif cfg.train.model == 'progan':
        logger.log_hyperparams(dict(cfg.progan))
        model = PROGAN_LightningSystem(nets[0], nets[1], cfg, checkpoint_path)

    # Trainer  ---------------------------------------------------------
    trainer = Trainer(
        logger=logger,
        max_epochs=cfg.train.epoch,
        gpus=1,
        callbacks=[checkpoint_callback],
        # fast_dev_run=True,
        # resume_from_checkpoint='./checkpoints/sagan-epoch=11.ckpt'
    )

    # Train
    trainer.fit(model, datamodule=dm)
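
The Comet credentials are read from a .env file loaded with python-dotenv; an illustrative file with placeholder values might look like:

COMET_ML_API_KEY=<your-comet-api-key>
COMET_ML_PROJECT_NAME=<your-project-name>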
Example #5
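Script-level setup: seeds the run, optionally constructs a CometLogger behind a CLI flag, derives a per-run model directory from the logger version, and begins building the vocabulary and datasets.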
pl.seed_everything(args.seed)

############################################
# Comet.ml
############################################
comet_logger = CometLogger(
    api_key=os.environ.get('COMET_API_KEY'),
    workspace="test",  # Optional
    project_name='wncg',  # Optional
) if args.comet_logger else None

# save hyperparams to comet.ml
if comet_logger is not None:
    comet_logger.log_hyperparams(dict(args.__dict__))

# make a per-run model directory if it doesn't exist
model_dir = args.model_dir
if comet_logger is not None:
    model_dir = os.path.join(model_dir, comet_logger.version)
    os.makedirs(model_dir, exist_ok=True)

############################################
# Dataset and Vocabulary
############################################
# construct vocabularies
vocab = WNCGVocabulary(args.dataset_dir)
# define train/valid/test dataset
train_dataset = WNCGDataset(args.dataset_dir,