Example #1
def main(args):
    set_seed(args.seed)

    # load data
    loaders = (get_data_loader(args=args,
                               data_path=args.data_path,
                               bert_path=args.bert_path,
                               batch_size=args.batch_size,
                               num_workers=args.num_workers,
                               split=split) for split in ['train', 'dev'])
    trn_loader, dev_loader = loaders

    # initialize model
    model = MultimodalTransformer(
        n_layers=args.n_layers,
        n_heads=args.n_heads,
        n_classes=args.n_classes,
        only_audio=args.only_audio,
        only_text=args.only_text,
        d_audio_orig=args.n_mfcc,
        d_text_orig=768,  # BERT hidden size
        d_model=args.d_model,
        attn_dropout=args.attn_dropout,
        relu_dropout=args.relu_dropout,
        emb_dropout=args.emb_dropout,
        res_dropout=args.res_dropout,
        out_dropout=args.out_dropout,
        attn_mask=args.attn_mask).to(args.device)

    # warmup scheduling
    args.total_steps = round(len(trn_loader) * args.epochs)
    args.warmup_steps = round(args.total_steps * args.warmup_percent)

    # optimizer & scheduler
    optimizer, scheduler = get_optimizer_and_scheduler(args, model)

    logging.info('training starts')
    model.zero_grad()
    args.global_step = 0
    for epoch in tqdm(range(1, args.epochs + 1), desc='epochs'):

        # training and evaluation steps
        train(args, model, trn_loader, optimizer, scheduler)
        loss, f1 = evaluate(model, dev_loader, args.device)

        # save model
        model_name = "epoch{}-loss{:.4f}-f1{:.4f}.bin".format(epoch, loss, f1)
        model_path = os.path.join(args.save_path, model_name)
        torch.save(model.state_dict(), model_path)

    logging.info('training ended')
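Note: every example on this page calls a project-specific get_optimizer_and_scheduler helper whose definition is not shown. Below is a minimal sketch of what such a helper could look like for Example #1, assuming AdamW plus a linear warmup schedule driven by args.lr, args.warmup_steps and args.total_steps; it is an illustration, not the original implementation (other examples instead pass (model, t_total, lr, warmup_steps) positionally).

import torch
from transformers import get_linear_schedule_with_warmup

def get_optimizer_and_scheduler(args, model):
    # Hypothetical sketch: AdamW with linear warmup/decay over args.total_steps.
    # args.lr, args.warmup_steps and args.total_steps are assumed to exist on args.
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=args.total_steps)
    return optimizer, scheduler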
Example #2
def main():
    parser = ArgumentParser()

    # task configuration
    parser.add_argument('-device', default=0, type=int)
    parser.add_argument('-output_name', default='', type=str)
    parser.add_argument('-saved_model_path', default='', type=str) # for k-fold ensemble prediction, set this to the output path of the corresponding k_fold models
    parser.add_argument('-type', default='train', type=str)
    parser.add_argument('-k_fold', default=-1, type=int) # -1 means k-fold is not used; otherwise, the index of the fold to use
    parser.add_argument('-merge_classification', default='avg', type=str) # count prediction: 'vote' uses majority voting, 'avg' averages probabilities
    parser.add_argument('-merge_with_bert_sort', default='yes', type=str) # whether to merge the similarities computed by the previous BERT model
    parser.add_argument('-k_fold_cache', default='no', type=str) # whether to reuse the previous k_fold cache
    parser.add_argument('-generate_candidates', default='', type=str) # whether to generate candidates during prediction and save them under output_name
    parser.add_argument('-seed', default=123456, type=int) # random seed
    parser.add_argument('-cls_position', default='zero', type=str) # whether the two added [CLS] tokens use position 0
    parser.add_argument('-pretrained_model_path', default='/home/liangming/nas/lm_params/chinese_L-12_H-768_A-12/', type=str) # path to the BERT weights

    # training hyperparameters
    parser.add_argument('-train_batch_size', default=64, type=int)
    parser.add_argument('-val_batch_size', default=256, type=int)
    parser.add_argument('-lr', default=2e-5, type=float)
    parser.add_argument('-epoch_num', default=20, type=int)

    parser.add_argument('-max_len', default=64, type=int)
    parser.add_argument('-dropout', default=0.3, type=float)
    parser.add_argument('-print_loss_step', default=2, type=int)
    parser.add_argument('-hit_list', default=[2, 5, 7, 10], type=list)

    args = parser.parse_args()
    # assert args.train_batch_size % args.neg_num == 0, 'train_batch_size should be a multiple of neg_num'

    # timestamp format
    DATE_FORMAT = "%Y-%m-%d-%H:%M:%S"
    # output directory; create it if it does not exist
    if args.output_name == '':
        output_path = os.path.join('./output/rerank_keywords_output', time.strftime(DATE_FORMAT,time.localtime(time.time())))
    else:
        output_path = os.path.join('./output/rerank_keywords_output', args.output_name)
        # if os.path.exists(output_path):
            # raise Exception('the output path {} already exists'.format(output_path))
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    # configure tensorboard
    tensor_board_log_path = os.path.join(output_path, 'tensor_board_log{}'.format('' if args.k_fold == -1 else args.k_fold))
    writer = SummaryWriter(tensor_board_log_path)

    # set up the logger
    logger = Logger(output_path,'main{}'.format('' if args.k_fold == -1 else args.k_fold)).logger

    # set the random seed
    logger.info('set seed to {}'.format(args.seed))
    set_seed(args)
    
    # print args
    print_args(args, logger)

    # load data
    logger.info('#' * 20 + 'loading data and model' + '#' * 20)
    data_path = os.path.join(project_path, 'candidates')
    # data_path = os.path.join(project_path, 'tf_idf_candidates')
    train_list, val_list, test_list, code_to_name, name_to_code, standard_name_list = read_rerank_data(data_path, logger, args)

    #load model
    # pretrained_model_path = '/home/liangming/nas/lm_params/chinese_L-12_H-768_A-12/'
    pretrained_model_path = args.pretrained_model_path
    bert_config, bert_tokenizer, bert_model = get_pretrained_model(pretrained_model_path, logger)

    # build datasets and dataloaders
    logger.info('create dataloader')
    train_dataset = RerankKeywordDataset(train_list, bert_tokenizer, args, logger)
    train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=False, collate_fn=train_dataset.collate_fn)

    val_dataset = RerankKeywordDataset(val_list, bert_tokenizer, args, logger)
    val_dataloader = DataLoader(val_dataset, batch_size=args.val_batch_size, shuffle=False, collate_fn=val_dataset.collate_fn)

    test_dataset = RerankKeywordDataset(test_list, bert_tokenizer, args, logger)
    test_dataloader = DataLoader(test_dataset, batch_size=args.val_batch_size, shuffle=False, collate_fn=test_dataset.collate_fn)

    # create the model
    logger.info('create model')
    model = BertKeywordsClassification(bert_model, bert_config, args)
    model = model.to(args.device)

    # configure the optimizer and scheduler
    t_total = len(train_dataloader) * args.epoch_num
    optimizer, _ = get_optimizer_and_scheduler(model, t_total, args.lr, 0)

    if args.type == 'train':
        train(model, train_dataloader, val_dataloader, test_dataloader, optimizer, writer, args, logger, output_path, standard_name_list)

    elif args.type == 'evaluate':
        if args.saved_model_path == '':
            raise Exception('saved_model_path must not be empty')

        # single model (no k-fold)
        if args.k_fold == -1:
            logger.info('loading saved model')
            checkpoint = torch.load(args.saved_model_path, map_location='cpu')
            model.load_state_dict(checkpoint)
            model = model.to(args.device)
            # generate up-to-date embeddings for the ICD standard terms
            evaluate(model, test_dataloader, args, logger, writer, standard_name_list, is_test=True)

        else:
            evaluate_k_fold(model, test_dataloader, args, logger, writer, standard_name_list)
Example #3
def launch(env_params,
           model_params,
           adapt_span_params,
           optim_params,
           data_params,
           trainer_params):
    # ENVIRONMENT (device, distributed, etc.)
    set_up_env(env_params)
    device = env_params['device']
    distributed = env_params['distributed']

    if not distributed or env_params['rank'] == 0:
        print('model_params:\t', model_params)
        print('optim_params:\t', optim_params)
        print('data_params:\t', data_params)
        print('trainer_params:\t', trainer_params)
        print('adapt_span_params:\t', adapt_span_params)

    # DATA
    train_data, val_data, test_data = get_train_val_test_data(
        data_params=data_params,
        env_params=env_params,
        batch_size=trainer_params['batch_size'],
        device=device)

    # MODEL
    model = TransformerSeq(
        vocab_size=data_params['vocab_size'], **model_params,
        adapt_span_params=adapt_span_params)
    if distributed:
        local_rank = env_params['local_rank']
        model = model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank)
    else:
        model = torch.nn.DataParallel(model)
        model = model.to(device)

    # OPTIMIZER AND SCHEDULER
    optimizer, scheduler = get_optimizer_and_scheduler(
        model=model, optim_params=optim_params)

    log_dir = os.path.join('logs', trainer_params['checkpoint_path'].split('/')[-1][:-3])
    # os.makedirs(log_dir, exist_ok=True)

    # create logger
    logger = Logger(logdir=log_dir)

    # resume training from last checkpoint if exists
    iter_init = load_checkpoint(
        trainer_params['checkpoint_path'], model, optimizer, scheduler,
        logger, distributed)

    if trainer_params['full_eval_mode']:
        # evaluate the model on test data
        with torch.no_grad():
            loss_val = full_eval(model, optimizer, scheduler, val_data,
                                 model_params['block_size'],
                                 model_params['hidden_size'])
            loss_test = full_eval(model, optimizer, scheduler, test_data,
                                  model_params['block_size'],
                                  model_params['hidden_size'])
            if distributed:
                # collect results into rank0
                stats = torch.tensor(
                    [loss_val, loss_test]).to(device)
                torch.distributed.reduce(stats, 0)
                if env_params['rank'] == 0:
                    loss_val = stats[0] / env_params['world_size']
                    loss_test = stats[1] / env_params['world_size']
                else:
                    return

            print('val: {:.3f}bpc'.format(loss_val / math.log(2)))
            print('test: {:.3f}bpc'.format(loss_test / math.log(2)))
        return

    # position of current batch
    data_pos = [0] * 2
    # initialize caches for train and valid
    hid_cache = [[
        torch.zeros(
            train_data.size(0),
            layer.attn.attn.get_cache_size(),
            model_params['hidden_size']).to(device)
        for layer in model.module.layers] for _ in range(2)]

    nb_batches_per_iter = trainer_params['nb_batches_per_iter']
    for iter_no in range(iter_init, trainer_params['nb_iter']):
        t_sta = time.time()
        loss_train, data_pos[0], hid_cache[0] = train_iteration(
            model, optimizer, scheduler, train_data, nb_batches_per_iter,
            model_params['block_size'], False, data_pos[0], hid_cache[0],
            trainer_params['batch_split'])
        elapsed = 1000 * (time.time() - t_sta) / nb_batches_per_iter
        with torch.no_grad():
            loss_val, data_pos[1], hid_cache[1] = train_iteration(
                model, optimizer, scheduler, val_data, nb_batches_per_iter,
                model_params['block_size'], True, data_pos[1], hid_cache[1],
                trainer_params['batch_split'])

        if distributed:
            # collect results into rank0
            stats = torch.tensor(
                [loss_train, loss_val]).to(device)
            torch.distributed.reduce(stats, 0)
            if env_params['rank'] == 0:
                loss_train = stats[0] / env_params['world_size']
                loss_val = stats[1] / env_params['world_size']
            else:
                continue

        logger.log_iter(iter_no, nb_batches_per_iter, loss_train,
                        loss_val, elapsed, model)
        save_checkpoint(trainer_params['checkpoint_path'],
                        iter_no, model, optimizer, scheduler, logger)
Example #4
def launch(task_params, env_params, model_params,
           optim_params, data_params, trainer_params):
    main_params = {}

    # print params
    if not env_params['distributed'] or env_params['rank'] == 0:
        print('env_params:\t', env_params)
        print('model_params:\t', model_params)
        print('optim_params:\t', optim_params)
        print('data_params:\t', data_params)
        print('trainer_params:\t', trainer_params)

    # computation env
    set_up_env(env_params)
    device = env_params['device']

    logger = Logger()

    for task in task_params:
        print(task)

        task_config = task_params[task]
        model_params["block_size"] = task_config["block_size"]
        trainer_params["batch_size"] = task_config["batch_size"]

        print('task_params:\t', task_config)

        # data
        data_path = data_params["data_path"]
        tokenizer = CharBPETokenizer(join(data_path, "tokenizer-vocab.json"),
                                     join(data_path, "tokenizer-merges.txt"),
                                     unk_token="[UNK]")
        train_data, val_data, num_labels = load_glue(tokenizer, task, task_config)

        # model
        pad_idx = tokenizer.token_to_id("[PAD]")
        model = GenDisc(vocab_size=data_params['vocab_size'],
                        batch_size=trainer_params["batch_size"],
                        model_params=model_params, pad_idx=pad_idx)
        model = model.to(device)

        # optimizer, scheduler, logger and resume from checkpoint
        optim_params = task_config["optim_params"]
        optimizer, scheduler = get_optimizer_and_scheduler(
            model=model, optim_params=optim_params)

        # reload checkpoint
        main_params["iter_init"] = load_checkpoint(
            trainer_params['checkpoint_path'], trainer_params['last_iter'],
            model, optimizer, scheduler,
            logger, parallel=False)

        asct = AsctSequenceClassification(task_config, model_params,
                                          model, num_labels)
        asct = asct.to(device)

        # store main params
        main_params["model"] = asct
        main_params["device"] = device
        main_params["optimizer"] = optimizer
        main_params["scheduler"] = scheduler
        main_params["logger"] = logger

        train_glue(train_data, val_data, main_params, trainer_params,
                   env_params, task_config, task)
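        # NOTE: returning here means only the first task in task_params is trained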
        return
Example #5
def optimize(trial, args):

    setattr(args, 'hidden_dim',
            int(trial.suggest_categorical('d_model', [128, 256, 512])))
    setattr(args, 'depth',
            int(trial.suggest_discrete_uniform('n_enc', 2, 6, 1)))
    setattr(args, 'n_layers',
            int(trial.suggest_discrete_uniform('n_layers', 1, 3, 1)))
    setattr(args, 'lr', trial.suggest_loguniform('lr', 1e-5, 1e-2))
    setattr(args, 'batch_size',
            int(trial.suggest_categorical('batch_size', [16, 32, 64, 128])))

    setattr(args, 'log_dir',
            os.path.join(args.hyperopt_dir, str(trial._trial_id)))

    torch.manual_seed(0)
    train_logger = create_logger('train', args.log_dir)

    train_logger.info('Arguments are...')
    for arg in vars(args):
        train_logger.info(f'{arg}: {getattr(args, arg)}')

    # construct loader and set device
    train_loader, val_loader = construct_loader(args)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # build model
    model_parameters = {
        'node_dim': train_loader.dataset.num_node_features,
        'edge_dim': train_loader.dataset.num_edge_features,
        'hidden_dim': args.hidden_dim,
        'depth': args.depth,
        'n_layers': args.n_layers
    }
    model = G2C(**model_parameters).to(device)

    # multi gpu training
    if torch.cuda.device_count() > 1:
        train_logger.info(
            f'Using {torch.cuda.device_count()} GPUs for training...')
        model = torch.nn.DataParallel(model)

    # get optimizer and scheduler
    optimizer, scheduler = get_optimizer_and_scheduler(
        args, model, len(train_loader.dataset))
    loss = torch.nn.MSELoss(reduction='sum')

    # record parameters
    train_logger.info(
        f'\nModel parameters are:\n{dict_to_str(model_parameters)}\n')
    save_yaml_file(os.path.join(args.log_dir, 'model_parameters.yml'),
                   model_parameters)
    train_logger.info(f'Optimizer parameters are:\n{optimizer}\n')
    train_logger.info('Scheduler state dict is:')
    if scheduler:
        for key, value in scheduler.state_dict().items():
            train_logger.info(f'{key}: {value}')
        train_logger.info('')

    best_val_loss = math.inf
    best_epoch = 0

    model.to(device)
    train_logger.info("Starting training...")
    for epoch in range(1, args.n_epochs):
        train_loss = train(model, train_loader, optimizer, loss, device,
                           scheduler, train_logger if args.verbose else None)
        train_logger.info("Epoch {}: Training Loss {}".format(
            epoch, train_loss))

        val_loss = test(model, val_loader, loss, device, args.log_dir, epoch)
        train_logger.info("Epoch {}: Validation Loss {}".format(
            epoch, val_loss))
        if scheduler and not isinstance(scheduler, NoamLR):
            scheduler.step(val_loss)

        if val_loss <= best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            torch.save(model.state_dict(),
                       os.path.join(args.log_dir, f'epoch_{epoch}_state_dict'))
    train_logger.info("Best Validation Loss {} on Epoch {}".format(
        best_val_loss, best_epoch))

    train_logger.handlers = []
    return best_val_loss
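Example #5 is written as an Optuna objective. A minimal sketch of how it could be driven by a study follows; the study settings and the way args is obtained are assumptions, not part of the original code.

import optuna

# Assumed wiring: args is the parsed argparse namespace of the surrounding script.
study = optuna.create_study(direction='minimize')
study.optimize(lambda trial: optimize(trial, args), n_trials=20)
print('best value:', study.best_value, 'best params:', study.best_params)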
Example #6
model_parameters = {
    'node_dim': train_loader.dataset.num_node_features,
    'edge_dim': train_loader.dataset.num_edge_features,
    'hidden_dim': args.hidden_dim,
    'depth': args.depth,
    'n_layers': args.n_layers
}
model = G2C(**model_parameters).to(device)

# multi gpu training
if torch.cuda.device_count() > 1:
    logger.info(f'Using {torch.cuda.device_count()} GPUs for training...')
    model = torch.nn.DataParallel(model)

# get optimizer and scheduler
optimizer, scheduler = get_optimizer_and_scheduler(args, model,
                                                   len(train_loader.dataset))

# record parameters
logger.info(f'\nModel parameters are:\n{dict_to_str(model_parameters)}\n')
yaml_file_name = os.path.join(log_dir, 'model_parameters.yml')
save_yaml_file(yaml_file_name, model_parameters)
logger.info(f'Optimizer parameters are:\n{optimizer}\n')
logger.info('Scheduler state dict is:')
if scheduler:
    for key, value in scheduler.state_dict().items():
        logger.info(f'{key}: {value}')
    logger.info('')

loss = torch.nn.MSELoss(reduction='sum')
# alternative loss: MAE
# loss = torch.nn.L1Loss(reduction='sum')  # MAE
Example #7
def main():
    parser = ArgumentParser()

    # task configuration
    parser.add_argument('-neg_sample', default='online', type=str)
    parser.add_argument('-neg_num', default=4, type=int)
    parser.add_argument('-device', default=0, type=int)
    parser.add_argument('-output_name', default='', type=str)
    parser.add_argument('-saved_model_path', default='', type=str) # for k-fold ensemble prediction, set this to the output path of the corresponding k_fold models
    parser.add_argument('-type', default='train', type=str)
    parser.add_argument('-k_fold', default=-1, type=int) # -1 means k-fold is not used; otherwise, the index of the fold to use
    parser.add_argument('-merge_classification', default='vote', type=str) # count prediction: 'vote' uses majority voting, 'avg' averages probabilities
    parser.add_argument('-merge_sort', default='avg', type=str) # k-fold fusion method: 'vote' uses majority voting, 'avg' averages probabilities
    parser.add_argument('-generate_candidates', default='no', type=str) # during prediction, whether to generate candidates and save them under output_name; 'no' disables, any other value is used as the save name
    parser.add_argument('-k_fold_cache', default='no', type=str) # during k-fold prediction, whether to reuse the previous cache; if it exists, three list files are kept under output_name so the five models do not have to be re-run
    parser.add_argument('-add_keywords', default='no', type=str) # do not add keyword information here
    parser.add_argument('-seed', default=123456, type=int) # random seed
    parser.add_argument('-loss_type', default='union', type=str) # loss type: 'class' updates only the classification parameters, 'sim' only the triplet loss, 'union' updates both
    parser.add_argument('-hidden_layers', default=12, type=int) # number of BERT layers
    parser.add_argument('-pretrained_model_path', default='/home/liangming/nas/lm_params/chinese_L-12_H-768_A-12/', type=str) # path to the BERT weights

    # training hyperparameters
    parser.add_argument('-train_batch_size', default=64, type=int)
    parser.add_argument('-val_batch_size', default=256, type=int)
    parser.add_argument('-lr', default=2e-5, type=float)
    parser.add_argument('-epoch_num', default=20, type=int)

    parser.add_argument('-max_len', default=64, type=int)
    parser.add_argument('-margin', default=1, type=float)
    parser.add_argument('-distance', default='eu', type=str)
    parser.add_argument('-label_nums', default=3, type=int)
    parser.add_argument('-dropout', default=0.3, type=float)
    parser.add_argument('-pool', default='avg', type=str)
    parser.add_argument('-print_loss_step', default=2, type=int)
    parser.add_argument('-hit_list', default=[2, 5, 7, 10], type=list)


    args = parser.parse_args()
    assert args.train_batch_size % args.neg_num == 0, 'train_batch_size should be a multiple of neg_num'

    # timestamp format
    DATE_FORMAT = "%Y-%m-%d-%H:%M:%S"
    # output directory; create it if it does not exist
    if args.output_name == '':
        output_path = os.path.join('./output/mto_output', time.strftime(DATE_FORMAT,time.localtime(time.time())))
    else:
        output_path = os.path.join('./output/mto_output', args.output_name)
        # if os.path.exists(output_path):
            # raise Exception('the output path {} already exists'.format(output_path))
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    # configure tensorboard
    tensor_board_log_path = os.path.join(output_path, 'tensor_board_log{}'.format('' if args.k_fold == -1 else args.k_fold))
    writer = SummaryWriter(tensor_board_log_path)

    # set up the logger
    logger = Logger(output_path,'main{}'.format('' if args.k_fold == -1 else args.k_fold)).logger
    
    # print args
    print_args(args, logger)

    # set the random seed
    logger.info('set seed to {}'.format(args.seed))
    set_seed(args)

    # load data
    logger.info('#' * 20 + 'loading data and model' + '#' * 20)
    data_path = os.path.join(project_path, 'data')
    train_list, val_list, test_list, code_to_name, name_to_code, standard_name_list = read_data(data_path, logger, args)

    #load model
    # pretrained_model_path = '/home/liangming/nas/lm_params/chinese_L-12_H-768_A-12/'
    pretrained_model_path = args.pretrained_model_path
    bert_config, bert_tokenizer, bert_model = get_pretrained_model(pretrained_model_path, logger, args)

    # build datasets and dataloaders
    logger.info('create dataloader')
    train_dataset = TripleDataset(train_list, standard_name_list, code_to_name, name_to_code, bert_tokenizer, args, logger)
    train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=False, collate_fn=train_dataset.collate_fn)

    val_dataset = TermDataset(val_list, bert_tokenizer, args, False)
    val_dataloader = DataLoader(val_dataset, batch_size=args.val_batch_size, shuffle=False, collate_fn=val_dataset.collate_fn)

    test_dataset = TermDataset(test_list, bert_tokenizer, args, False)
    test_dataloader = DataLoader(test_dataset, batch_size=args.val_batch_size, shuffle=False, collate_fn=test_dataset.collate_fn)

    icd_dataset = TermDataset(standard_name_list, bert_tokenizer, args, True)
    icd_dataloader = DataLoader(icd_dataset, batch_size=args.val_batch_size, shuffle=False, collate_fn=icd_dataset.collate_fn)

    train_term_dataset = TermDataset(train_list, bert_tokenizer, args, False)
    train_term_dataloader = DataLoader(train_term_dataset, batch_size=args.val_batch_size, shuffle=False, collate_fn=train_term_dataset.collate_fn)


    # create the model
    logger.info('create model')
    model = SiameseClassificationModel(bert_model, bert_config, args)
    model = model.to(args.device)

    # configure the optimizer and scheduler
    t_total = len(train_dataloader) * args.epoch_num
    optimizer, _ = get_optimizer_and_scheduler(model, t_total, args.lr, 0)

    if args.type == 'train':
        train(model, train_dataloader, val_dataloader, test_dataloader, icd_dataloader, train_term_dataloader, optimizer, writer, args, logger, \
                standard_name_list, train_list, output_path, bert_tokenizer, name_to_code, code_to_name)

    elif args.type == 'evaluate':
        if args.saved_model_path == '':
            raise Exception('saved_model_path must not be empty')

        # single model (no k-fold)
        if args.k_fold == -1:
            logger.info('loading saved model')
            checkpoint = torch.load(args.saved_model_path, map_location='cpu')
            model.load_state_dict(checkpoint)
            model = model.to(args.device)
            # generate up-to-date embeddings for the ICD standard terms
            logger.info('generate icd embedding')
            icd_embedding = get_model_embedding(model, icd_dataloader, True)
            start_time = time.time()
            evaluate(model, test_dataloader, icd_embedding, args, logger, writer, standard_name_list, True)
            end_time = time.time()
            logger.info('total length is {}, predict time is {}, per mention time is {}'.format(len(test_list), end_time - start_time, (end_time - start_time) / len(test_list)))

        else:
            evaluate_k_fold(model, test_dataloader, icd_dataloader, args, logger, writer, standard_name_list)
Example #8
def launch(env_params, model_params, optim_params, data_params,
           trainer_params):
    main_params = {}

    # print params
    print('env_params:\t', env_params)
    print('model_params:\t', model_params)
    print('optim_params:\t', optim_params)
    print('data_params:\t', data_params)
    print('trainer_params:\t', trainer_params)

    # computation env
    set_up_env(env_params)
    device = env_params['device']

    # data
    weights_embed, train_data, val_data, tokenizer = (get_data_imdb(
        trainer_params['batch_size'],
        device,
        entry_size=model_params["block_size"],
        vanilla=model_params["vanilla"],
        load=False))

    #tokenizer = get_tokenizer(**data_params)

    # model
    #if model_params["vanilla"]:
    #    model = Vanilla()
    model = AsctImdbClassification(batch_size=trainer_params["batch_size"],
                                   weights_embed=weights_embed,
                                   pad_idx=tokenizer["[PAD]"],
                                   model_params=model_params,
                                   vanilla=model_params["vanilla"])
    model = model.to(device)
    #model = torch.nn.DataParallel(model)

    # optimizer, scheduler, logger and resume from checkpoint
    optimizer, scheduler = get_optimizer_and_scheduler(
        model=model, optim_params=optim_params)
    #logger = Logger()
    #main_params["iter_init"] = load_checkpoint(
    #    trainer_params,
    #    trainer_params['last_iter'], model,
    #    optimizer, scheduler,
    #    logger, parallel=True)

    # store main params
    #main_params["model"] = model
    #main_params["device"] = device
    #main_params["optimizer"] = optimizer
    #main_params["scheduler"] = scheduler
    #main_params["logger"] = logger

    iter_start = trainer_params["last_iter"]
    if iter_start:
        load_checkpoint(trainer_params, iter_start, model, optimizer, False)

    # iter
    if model_params["vanilla"]:
        train_imdb_vanilla(model_params, trainer_params, model, optimizer,
                           train_data, val_data)
    else:
        train_imdb(model_params, trainer_params, model, optimizer, train_data,
                   val_data)