Example #1
def load_imagenet_data(train_dir, test_dir, aug_batch_size, no_aug_batch_size):
    """
    Loads the ImageNet dataset into train and test data loaders. There are two loaders for the train dataset: one
    with augmentations and one without, for testing purposes.
    :param train_dir: The name of the directory where the train dataset is located.
    :param test_dir: The name of the directory where the test dataset is located.
    :param aug_batch_size: The batch size to use when loading data with augmentations (training phase).
    :param no_aug_batch_size: The batch size to use when loading data without augmentations (testing phase).
    :return: Data loaders for the train and test datasets of ImageNet (2 train, 1 test).
    """
    aug_transform, no_aug_transform = get_transforms()

    # Create a data loader for the train dataset with augmentations.
    train_dataset = ImageNetMini(root_dir=train_dir, transform=aug_transform, augment=True)
    train_loader = create_data_loader(train_dataset, is_train=True, batch_size=aug_batch_size)

    # Create a data loader for the train dataset without augmentations.
    raw_train_dataset = ImageNetMini(root_dir=train_dir, transform=no_aug_transform, augment=False)
    raw_train_loader = create_data_loader(raw_train_dataset, is_train=True, batch_size=no_aug_batch_size)

    # Create a data loader for the test dataset (without augmentations).
    test_dataset = ImageNetMini(root_dir=test_dir, transform=no_aug_transform, augment=False)
    test_loader = create_data_loader(test_dataset, is_train=False, batch_size=no_aug_batch_size)

    return train_loader, raw_train_loader, test_loader
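A note on the helpers: examples #1 and #3 rely on a create_data_loader(dataset, is_train, batch_size) function (and #1 on an ImageNetMini dataset class) that are not reproduced on this page. Below is a minimal sketch of what such a helper could look like, assuming it is a thin wrapper around torch.utils.data.DataLoader; the original repository's version may choose different worker or memory options.

import torch
from torch.utils.data import DataLoader, Dataset


def create_data_loader(dataset: Dataset, is_train: bool, batch_size: int) -> DataLoader:
    """Hypothetical helper: wrap a dataset in a DataLoader.

    Shuffling and last-batch dropping are enabled only for training,
    which matches how the examples above call it.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,                      # shuffle only the training data
        drop_last=is_train,                    # keep evaluation complete and deterministic
        num_workers=2,                         # arbitrary choice for illustration
        pin_memory=torch.cuda.is_available(),  # speeds up host-to-GPU copies
    )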
Example #2
def train_and_evaluate(config):
    token_makers = create_by_factory(TokenMakersFactory, config.token)
    tokenizers = token_makers["tokenizers"]
    del token_makers["tokenizers"]

    config.data_reader.tokenizers = tokenizers
    if nsml.IS_ON_NSML:
        config.data_reader.train_file_path = os.path.join(
            DATASET_PATH, "train", "train_data",
            config.data_reader.train_file_path)
        config.data_reader.valid_file_path = os.path.join(
            DATASET_PATH, "train", "train_data",
            config.data_reader.valid_file_path)

    data_reader = create_by_factory(DataReaderFactory, config.data_reader)
    datas, helpers = data_reader.read()

    # Vocab & Indexing
    text_handler = TextHandler(token_makers, lazy_indexing=True)
    texts = data_reader.filter_texts(datas)

    token_counters = text_handler.make_token_counters(texts)
    text_handler.build_vocabs(token_counters)
    text_handler.index(datas, data_reader.text_columns)

    # Iterator
    datasets = data_reader.convert_to_dataset(datas, helpers=helpers)
    train_loader = create_data_loader(datasets["train"],
                                      batch_size=config.iterator.batch_size,
                                      shuffle=True,
                                      cuda_device_id=device)
    valid_loader = create_data_loader(datasets["valid"],
                                      batch_size=config.iterator.batch_size,
                                      shuffle=False,
                                      cuda_device_id=device)

    # Model & Optimizer
    model = create_model(token_makers,
                         ModelFactory,
                         config.model,
                         device,
                         helpers=helpers)
    model_parameters = [
        param for param in model.parameters() if param.requires_grad
    ]

    optimizer = get_optimizer_by_name("adam")(model_parameters)

    if IS_ON_NSML:
        bind_nsml(model, optimizer=optimizer)

    # Trainer
    trainer_config = vars(config.trainer)
    trainer_config["model"] = model
    trainer = Trainer(**trainer_config)
    trainer.train_and_evaluate(train_loader, valid_loader, optimizer)
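Example #2 obtains its optimizer through get_optimizer_by_name("adam") and then calls the returned class on the trainable parameters. The helper itself is not shown; a plausible sketch, assuming it is a simple registry of torch.optim classes (the framework behind this example may register more optimizers and default hyperparameters):

import torch.optim as optim

# Hypothetical registry mapping names to optimizer classes.
_OPTIMIZERS = {
    "adam": optim.Adam,
    "adamw": optim.AdamW,
    "sgd": optim.SGD,
}


def get_optimizer_by_name(name: str):
    """Return the optimizer class registered under `name` (case-insensitive)."""
    return _OPTIMIZERS[name.lower()]


# Usage, as in the example above:
# optimizer = get_optimizer_by_name("adam")(model_parameters)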
Example #3
def load_cifar10_data(batch_size):
    """
    Downloads the CIFAR10 dataset and loads it into train and test data loaders.
    :param batch_size: The batch size to use when loading data from the dataset.
    :return: Data loaders for the train and test datasets of CIFAR10.
    """
    # Download the train and test datasets.
    train_data, test_data = download_dataset()

    # Create data loaders for the train and test datasets.
    train_loader = create_data_loader(train_data,
                                      is_train=True,
                                      batch_size=batch_size)
    test_loader = create_data_loader(test_data,
                                     is_train=False,
                                     batch_size=batch_size)
    return train_loader, test_loader
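Example #3 delegates the download itself to a download_dataset() helper. A minimal sketch, assuming torchvision's built-in CIFAR-10 dataset and a plain normalization transform; the original may add augmentations or use a different data root:

import torchvision
import torchvision.transforms as transforms


def download_dataset(data_dir: str = "./data"):
    """Hypothetical helper: download CIFAR-10 and return (train, test) datasets."""
    transform = transforms.Compose([
        transforms.ToTensor(),
        # Commonly used CIFAR-10 channel statistics (mean, std).
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2470, 0.2435, 0.2616)),
    ])
    train_data = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transform)
    test_data = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=transform)
    return train_data, test_data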
Example #4
def infer(args):
    paddle.set_device(args.device)
    set_seed(args.seed)

    model = UnifiedTransformerLMHeadModel.from_pretrained(
        args.model_name_or_path)
    tokenizer = UnifiedTransformerTokenizer.from_pretrained(
        args.model_name_or_path)

    test_ds = load_dataset('duconv', splits='test_1')
    test_ds, test_data_loader = create_data_loader(test_ds, tokenizer, args,
                                                   'test')

    model.eval()
    total_time = 0.0
    start_time = time.time()
    pred_responses = []
    for step, inputs in enumerate(test_data_loader, 1):
        input_ids, token_type_ids, position_ids, attention_mask, seq_len = inputs
        output = model.generate(input_ids=input_ids,
                                token_type_ids=token_type_ids,
                                position_ids=position_ids,
                                attention_mask=attention_mask,
                                seq_len=seq_len,
                                max_length=args.max_dec_len,
                                min_length=args.min_dec_len,
                                decode_strategy=args.decode_strategy,
                                temperature=args.temperature,
                                top_k=args.top_k,
                                top_p=args.top_p,
                                num_beams=args.num_beams,
                                length_penalty=args.length_penalty,
                                early_stopping=args.early_stopping,
                                num_return_sequences=args.num_return_sequences,
                                use_fp16_decoding=args.use_fp16_decoding,
                                use_faster=args.faster)

        total_time += (time.time() - start_time)
        if step % args.logging_steps == 0:
            print('step %d - %.3fs/step' %
                  (step, total_time / args.logging_steps))
            total_time = 0.0

        ids, scores = output
        results = select_response(ids, scores, tokenizer, args.max_dec_len,
                                  args.num_return_sequences)
        pred_responses.extend(results)

        start_time = time.time()

    with open(args.output_path, 'w', encoding='utf-8') as fout:
        for response in pred_responses:
            fout.write(response + '\n')
    print('\nSaved inference results to: %s' % args.output_path)

    target_responses = [example['response'] for example in test_ds]
    calc_bleu_and_distinct(pred_responses, target_responses)
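Example #4 ends by scoring the generated responses with calc_bleu_and_distinct, whose implementation is not shown. As an illustration of the distinct-N part of such a metric, here is a standalone sketch (the script's own implementation may tokenize and aggregate differently):

def distinct_n(tokenized_responses, n=2):
    """Ratio of unique n-grams to total n-grams across a list of token lists.

    Higher values mean more diverse generations. This is an illustrative
    re-implementation, not the code used by the example above.
    """
    ngrams = set()
    total = 0
    for tokens in tokenized_responses:
        for i in range(len(tokens) - n + 1):
            ngrams.add(tuple(tokens[i:i + n]))
            total += 1
    return len(ngrams) / total if total > 0 else 0.0


# Example usage with whitespace-tokenized predictions:
# distinct_2 = distinct_n([r.split() for r in pred_responses], n=2)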
Example #5
def preprocess(options):
    # parse the input args
    dataset = options['dataset']
    model_path = options['model_path']
    batch_size = options['batch_size']
    DTYPE = torch.FloatTensor
    if options['cuda']:
        DTYPE = torch.cuda.FloatTensor

    # prepare the paths for storing models
    model_path = os.path.join(model_path, "tfn.pt")
    print("Temp location for saving model: {}".format(model_path))

    # define fields
    text_field = 'CMU_MOSI_TimestampedWordVectors_1.1'
    visual_field = 'CMU_MOSI_VisualFacet_4.1'
    acoustic_field = 'CMU_MOSI_COVAREP'
    label_field = 'CMU_MOSI_Opinion_Labels'

    # DEBUG ONLY
    recalc = not (os.path.exists('vars/dump') and os.path.isfile('vars/dump'))

    if recalc:
        # prepare the datasets
        print("Currently using {} dataset.".format(dataset))
        DATASET = utils.download()
        dataset = utils.load(visual_field, acoustic_field, text_field)
        utils.align(text_field, dataset)
        utils.annotate(dataset, label_field)
        splits = utils.get_splits(DATASET)
        if not os.path.exists('./vars'):
            os.makedirs('./vars')
        f = open('./vars/dump', 'wb+')
        pickle.dump([splits, dataset], f)
        f.close()
    else:
        f = open('./vars/dump', 'rb')
        splits, dataset = pickle.load(f)
        f.close()

    input_dims = utils.get_dims_from_dataset(dataset, text_field,
                                             acoustic_field, visual_field)
    train, dev, test = utils.split(splits, dataset, label_field, visual_field,
                                   acoustic_field, text_field, batch_size)
    train_loader, dev_loader, test_loader = utils.create_data_loader(
        train, dev, test, batch_size, DTYPE)
    return train_loader, dev_loader, test_loader, input_dims
Example #6
    VGG = VGG.to(device)
    
    model.train()

    
    ### create dataset
    train_dataset = datasets.MultiFramesDataset(opts, "train")

    
    ### start training
    while model.epoch < opts.epoch_max:

        model.epoch += 1

        ### re-generate train data loader for every epoch
        data_loader = utils.create_data_loader(train_dataset, opts, "train")

        ### update learning rate
        current_lr = utils.learning_rate_decay(opts, model.epoch)

        for param_group in optimizer.param_groups:
            param_group['lr'] = current_lr
        
        ## submodule
        flow_warping = Resample2d().to(device)
        downsampler = nn.AvgPool2d((2, 2), stride=2).to(device)


        ### criterion and loss recorder
        if opts.loss == 'L2':
            criterion = nn.MSELoss(size_average=True)
Example #7
def train(args):
    paddle.set_device(args.device)
    world_size = dist.get_world_size()
    if world_size > 1:
        dist.init_parallel_env()

    set_seed(args.seed)

    model = UnifiedTransformerLMHeadModel.from_pretrained(
        args.model_name_or_path)
    tokenizer = UnifiedTransformerTokenizer.from_pretrained(
        args.model_name_or_path)

    if world_size > 1:
        model = paddle.DataParallel(model)

    train_ds, dev_ds = load_dataset('duconv', splits=('train', 'dev'))
    train_ds, train_data_loader = create_data_loader(train_ds, tokenizer, args,
                                                     'train')
    dev_ds, dev_data_loader = create_data_loader(dev_ds, tokenizer, args,
                                                 'dev')

    lr_scheduler = NoamDecay(1 / (args.warmup_steps * (args.lr**2)),
                             args.warmup_steps)
    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = AdamW(learning_rate=lr_scheduler,
                      parameters=model.parameters(),
                      weight_decay=args.weight_decay,
                      apply_decay_param_fun=lambda x: x in decay_params,
                      grad_clip=nn.ClipGradByGlobalNorm(args.max_grad_norm))

    step = 0
    total_time = 0.0
    best_ppl = 1e9
    for epoch in range(args.epochs):
        print('\nEpoch %d/%d' % (epoch + 1, args.epochs))
        batch_start_time = time.time()
        for inputs in train_data_loader:
            step += 1
            labels = inputs[-1]

            logits = model(*inputs[:-1])
            loss = F.cross_entropy(logits, labels)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()

            total_time += (time.time() - batch_start_time)
            if step % args.logging_steps == 0:
                ppl = paddle.exp(loss)
                print(
                    'step %d - loss: %.4f - ppl: %.4f - lr: %.7f - %.3fs/step'
                    % (step, loss, ppl, optimizer.get_lr(),
                       total_time / args.logging_steps))
                total_time = 0.0
            if step % args.save_steps == 0:
                ppl = evaluation(model, dev_data_loader)
                if dist.get_rank() == 0:
                    save_ckpt(model, tokenizer, args.save_dir, step)
                    if ppl < best_ppl:
                        best_ppl = ppl
                        save_ckpt(model, tokenizer, args.save_dir, 'best')
                        print('Saved step {} as best model.\n'.format(step))
            batch_start_time = time.time()
    print('\nTraining completed.')
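The loop in example #7 periodically calls evaluation(model, dev_data_loader) and compares the returned perplexity against the best one so far. A hedged sketch of what that helper presumably does, assuming dev batches have the same layout as the training batches (model inputs followed by labels):

import math

import paddle
import paddle.nn.functional as F


@paddle.no_grad()
def evaluation(model, data_loader):
    """Hypothetical dev-set evaluation: mean cross-entropy turned into perplexity."""
    model.eval()
    losses = []
    for inputs in data_loader:
        labels = inputs[-1]
        logits = model(*inputs[:-1])
        loss = F.cross_entropy(logits, labels)
        losses.append(float(loss))
    model.train()
    ppl = math.exp(sum(losses) / len(losses))
    print('validation - ppl: %.4f' % ppl)
    return ppl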
Example #8
            successful_download = True
            print('Successfully downloaded after {} retries.'.format(retries))

        except:
            retries = retries + 1
            random_sleep = random.randint(1, 30)
            print('Retry #{}.  Sleeping for {} seconds'.format(
                retries, random_sleep))
            time.sleep(random_sleep)

    if not tokenizer or not model or not config:
        print('Not properly initialized...')

    ###### CREATE DATA LOADERS
    train_data_loader, df_train = create_data_loader(args.train_data,
                                                     tokenizer,
                                                     args.max_seq_len,
                                                     args.train_batch_size)
    val_data_loader, df_val = create_data_loader(args.validation_data,
                                                 tokenizer, args.max_seq_len,
                                                 args.validation_batch_size)

    logger.debug("Processes {}/{} ({:.0f}%) of train data".format(
        len(train_data_loader.sampler), len(train_data_loader.dataset), 100. *
        len(train_data_loader.sampler) / len(train_data_loader.dataset)))

    logger.debug("Processes {}/{} ({:.0f}%) of test data".format(
        len(val_data_loader.sampler), len(val_data_loader.dataset),
        100. * len(val_data_loader.sampler) / len(val_data_loader.dataset)))

    # model_dir = os.environ['SM_MODEL_DIR']
    print('model_dir: {}'.format(args.model_dir))
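Example #8 begins in the middle of a download-with-retry loop, so the surrounding control flow is missing from the excerpt. A hedged reconstruction of the kind of loop the fragment belongs to, assuming Hugging Face from_pretrained downloads; the model id, retry limit, and exact classes are placeholders, not the original script's values:

import random
import time

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

MODEL_NAME = 'bert-base-uncased'  # hypothetical checkpoint; the original is not shown

config = tokenizer = model = None
successful_download = False
retries = 0

while not successful_download and retries < 5:
    try:
        config = AutoConfig.from_pretrained(MODEL_NAME)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, config=config)

        successful_download = True
        print('Successfully downloaded after {} retries.'.format(retries))

    except Exception:
        # Back off for a random interval before retrying, as in the fragment above.
        retries = retries + 1
        random_sleep = random.randint(1, 30)
        print('Retry #{}.  Sleeping for {} seconds'.format(retries, random_sleep))
        time.sleep(random_sleep)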
Example #9
def re_train_and_evaluate(config):
    NSML_SESSION = 'team_6/19_tcls_qa/258'  # NOTE: need to hard code
    NSML_CHECKPOINT = '1'  # NOTE: need to hard code

    assert NSML_CHECKPOINT is not None, "You must insert NSML Session's checkpoint for submit"
    assert NSML_SESSION is not None, "You must insert NSML Session's name for submit"

    token_makers = create_by_factory(TokenMakersFactory, config.token)
    tokenizers = token_makers["tokenizers"]
    del token_makers["tokenizers"]

    config.data_reader.tokenizers = tokenizers
    if nsml.IS_ON_NSML:
        config.data_reader.train_file_path = os.path.join(
            DATASET_PATH, "train", "train_data",
            config.data_reader.train_file_path)
        config.data_reader.valid_file_path = os.path.join(
            DATASET_PATH, "train", "train_data",
            config.data_reader.valid_file_path)

    data_reader = create_by_factory(DataReaderFactory, config.data_reader)
    datas, helpers = data_reader.read()

    # Vocab & Indexing
    text_handler = TextHandler(token_makers, lazy_indexing=True)
    texts = data_reader.filter_texts(datas)

    token_counters = text_handler.make_token_counters(texts)
    text_handler.build_vocabs(token_counters)
    text_handler.index(datas, data_reader.text_columns)

    def bind_load_vocabs(config, token_makers):
        CHECKPOINT_FNAME = "checkpoint.bin"

        def load(dir_path):
            checkpoint_path = os.path.join(dir_path, CHECKPOINT_FNAME)
            checkpoint = torch.load(checkpoint_path)

            vocabs = {}
            token_config = config.token
            for token_name in token_config.names:
                token = getattr(token_config, token_name, {})
                vocab_config = getattr(token, "vocab", {})

                texts = checkpoint["vocab_texts"][token_name]
                if type(vocab_config) != dict:
                    vocab_config = vars(vocab_config)
                vocabs[token_name] = Vocab(token_name,
                                           **vocab_config).from_texts(texts)

            for token_name, token_maker in token_makers.items():
                token_maker.set_vocab(vocabs[token_name])
            return token_makers

        nsml.bind(load=load)

    bind_load_vocabs(config, token_makers)
    nsml.load(checkpoint=NSML_CHECKPOINT, session=NSML_SESSION)

    # Raw to Tensor Function
    text_handler = TextHandler(token_makers, lazy_indexing=False)
    raw_to_tensor_fn = text_handler.raw_to_tensor_fn(
        data_reader,
        cuda_device=device,
    )

    # Iterator
    datasets = data_reader.convert_to_dataset(datas, helpers=helpers)
    train_loader = create_data_loader(datasets["train"],
                                      batch_size=config.iterator.batch_size,
                                      shuffle=True,
                                      cuda_device_id=device)
    valid_loader = create_data_loader(datasets["valid"],
                                      batch_size=config.iterator.batch_size,
                                      shuffle=False,
                                      cuda_device_id=device)

    # Model & Optimizer
    model = create_model(token_makers,
                         ModelFactory,
                         config.model,
                         device,
                         helpers=helpers)
    model_parameters = [
        param for param in model.parameters() if param.requires_grad
    ]

    optimizer = get_optimizer_by_name("adam")(model_parameters)

    def bind_load_model(config, model, **kwargs):
        CHECKPOINT_FNAME = "checkpoint.bin"

        def load(dir_path):
            checkpoint_path = os.path.join(dir_path, CHECKPOINT_FNAME)
            checkpoint = torch.load(checkpoint_path)

            model.load_state_dict(checkpoint["weights"])
            model.config = checkpoint["config"]
            model.metrics = checkpoint["metrics"]
            model.init_params = checkpoint["init_params"]
            model.predict_helper = checkpoint["predict_helper"]
            model.train_counter = TrainCounter(display_unit=1000)
            # model.vocabs = load_vocabs(checkpoint)

            if "optimizer" in kwargs:
                kwargs["optimizer"].load_state_dict(checkpoint["optimizer"][0])

            print(f"Model reload checkpoints...! {checkpoint_path}")

        nsml.bind(load=load)

    bind_load_model(config, model, optimizer=optimizer)
    nsml.load(checkpoint=NSML_CHECKPOINT, session=NSML_SESSION)

    if IS_ON_NSML:
        bind_nsml(model, optimizer=optimizer)

    # Trainer
    trainer_config = vars(config.trainer)
    trainer_config["model"] = model
    trainer = Trainer(**trainer_config)
    trainer.train_and_evaluate(train_loader, valid_loader, optimizer)
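The load functions bound in example #9 read a checkpoint.bin with specific keys ("weights", "optimizer", "vocab_texts", and so on). For context, here is a hedged sketch of the matching save function that bind_nsml presumably registers; the real project's serialization, including how vocabularies are dumped to text, is not shown in the excerpt:

import os

import nsml
import torch


def bind_nsml(model, optimizer=None):
    """Hypothetical save binding that mirrors the checkpoint keys read above."""
    CHECKPOINT_FNAME = "checkpoint.bin"

    def save(dir_path):
        checkpoint = {
            "weights": model.state_dict(),
            "config": model.config,
            "metrics": model.metrics,
            "init_params": model.init_params,
            "predict_helper": model.predict_helper,
            # "vocab_texts": {...}  -- serialized vocabularies; the helper that
            # turns a Vocab back into text is not part of the excerpt.
        }
        if optimizer is not None:
            # Stored as a list so that load() can index checkpoint["optimizer"][0].
            checkpoint["optimizer"] = [optimizer.state_dict()]
        torch.save(checkpoint, os.path.join(dir_path, CHECKPOINT_FNAME))

    nsml.bind(save=save)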
Example #10
    epochs = args.epochs
    seed = args.seed
    max_len = args.max_len
    class_names = ['negative', 'neutral', 'positive']
    train_path = f'{args.data_folder}/train.csv'
    validation_path = f'{args.data_folder}/validation.csv'
    test_path = f'{args.data_folder}/test.csv'
    output_folder = args.output_folder
    

    if not os.path.exists(output_folder):
        os.makedirs(output_folder)
    
    tokenizer = BertTokenizer.from_pretrained(model_name)
    # CREATE DATA LOADERS
    train_data_loader, df_train = create_data_loader(train_path, tokenizer, max_len, batch_size)
    val_data_loader, df_val = create_data_loader(validation_path, tokenizer, max_len, batch_size)
    
    # INSTANTIATE MODEL
    model = SentimentClassifier(len(class_names), model_name)
    model = model.to(device)
    
    train_model(model,
                train_data_loader,
                df_train,
                val_data_loader, 
                df_val,
                epochs,
                learning_rate,
                device,
                output_folder)
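Examples #8 and #10 call a create_data_loader(csv_path, tokenizer, max_len, batch_size) that returns both a DataLoader and the underlying dataframe. A hedged sketch of that pattern, assuming a CSV with "text" and "label" columns and a standard BERT-style encoding; the column names, padding policy, and shuffling are assumptions, not the original helper:

import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset


class ReviewDataset(Dataset):
    """Hypothetical dataset: one tokenized text and one integer label per row."""

    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer.encode_plus(
            str(self.texts[idx]),
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'targets': torch.tensor(self.labels[idx], dtype=torch.long),
        }


def create_data_loader(csv_path, tokenizer, max_len, batch_size):
    df = pd.read_csv(csv_path)
    dataset = ReviewDataset(df.text.to_numpy(), df.label.to_numpy(), tokenizer, max_len)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True), df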
Example #11
def main(cfg):
    print(f'Training started: {cfg.model}')

    dataset_train, dataset_val, _ = create_dataset(cfg)
    data_loader_train = create_data_loader(dataset_train,
                                           cfg.batch_size,
                                           shuffle=True,
                                           num_workers=3)
    data_loader_dev = create_data_loader(dataset_val,
                                         cfg.batch_size,
                                         shuffle=False,
                                         num_workers=3)
    print(f'Data loader: {len(data_loader_train)}, {len(data_loader_dev)}')

    # weighted loss
    if cfg.use_class_weight:
        classes_weights = compute_class_weight(class_weight='balanced',
                                               classes=np.unique(dataset_train.titles),
                                               y=dataset_train.titles)
        classes_weights = to_device(classes_weights).float()
        print(f'Class weight: yes')
    else:
        classes_weights = None
        print(f'Class weight: no')

    # create model
    model = create_model(cfg, dataset_train, create_W_emb=True)

    model_parameters = get_trainable_parameters(model.parameters())
    optimizer = torch.optim.Adam(model_parameters,
                                 cfg.learning_rate,
                                 weight_decay=cfg.weight_decay,
                                 amsgrad=True)

    criterion = BiosLoss(cfg, classes_weights)

    def update_function(engine, batch):
        model.train()
        optimizer.zero_grad()

        inputs, targets = to_device(batch)

        logits = model(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model_parameters, cfg.max_grad_norm)
        optimizer.step()

        return loss.item()

    def inference_function(engine, batch):
        model.eval()
        with torch.no_grad():
            inputs, targets = to_device(batch)

            logits = model(inputs)

            return logits, targets

    trainer = Engine(update_function)
    evaluator = Engine(inference_function)

    metrics = [
        ('loss', Loss(criterion)),
        ('accuracy', Accuracy()),
    ]
    for name, metric in metrics:
        metric.attach(evaluator, name)

    best_val_loss = np.inf
    nb_epoch_no_improvement = 0

    @trainer.on(Events.EPOCH_COMPLETED)
    def loss_step(engine):
        criterion.epoch_complete()

    @trainer.on(Events.EPOCH_COMPLETED)
    def eval_model(engine):
        nonlocal best_val_loss, nb_epoch_no_improvement

        def log_progress(mode, metrics_values):
            metrics_str = ', '.join([
                f'{metric_name} {metrics_values[metric_name]:.3f}'
                for metric_name, _ in metrics
            ])
            print(f'{mode}: {metrics_str}', end=' | ')

        # evaluator.run(data_loader_train)
        # metrics_train = evaluator.state.metrics.copy()

        evaluator.run(data_loader_dev)
        metrics_dev = evaluator.state.metrics.copy()

        print(f'Epoch {engine.state.epoch:>3}', end=' | ')
        # log_progress('train', metrics_train)
        log_progress('dev', metrics_dev)

        if best_val_loss > metrics_dev[
                'loss'] or engine.state.epoch <= cfg.min_epochs:
            best_val_loss = metrics_dev['loss']
            nb_epoch_no_improvement = 0

            save_weights(model, os.path.join(CACHE_DIR,
                                             f'model_{cfg.model}.pt'))
            print('Model saved', end=' ')
        else:
            nb_epoch_no_improvement += 1

        if cfg.early_stopping_patience != 0 and nb_epoch_no_improvement > cfg.early_stopping_patience:
            trainer.terminate()

        print()

    trainer.run(data_loader_train, max_epochs=cfg.nb_epochs)

    print(f'Training finished')
Example #12
def run():

    # Load the data and do a little exploration
    data = utils.load_data(config.DATA_PATH)
    utils.data_exploration(data)  # Data exploration just prints

    # print(data.head())
    # print(data.polarity.values)
    # Create dataLoader

    data, _ = model_selection.train_test_split(data,
                                               test_size=0.995,
                                               random_state=42,
                                               stratify=data.polarity.values)

    train, valid = model_selection.train_test_split(
        data, test_size=0.5, random_state=42, stratify=data.polarity.values)

    train_data_loader = utils.create_data_loader(train)
    valid_data_loader = utils.create_data_loader(valid, is_train=False)

    # Build Model and send it to device
    model = BuildModel()
    model = model.to(config.DEVICE)

    # Set weight decay to 0 for the no_decay params,
    # and to config.WEIGHT_DECAY for all others
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]

    optimizer_parameters = [{
        'params': [
            tensor for name, tensor in param_optimizer
            if any(nd in name for nd in no_decay)
        ],
        'weight_decay': 0
    }, {
        'params': [
            tensor for name, tensor in param_optimizer
            if not any(nd in name for nd in no_decay)
        ],
        'weight_decay': config.WEIGHT_DECAY
    }]

    # Total number of optimization steps performed over the whole run
    num_training_steps = int(
        (train.shape[0] / config.TRAIN_BATCH_SIZE) * config.EPOCHS)

    optimizer = AdamW(optimizer_parameters, lr=3e-5)  # arbitrarily chosen LR

    # Scheduler that adapts the LR over the training steps.
    # The warm-up scheduler increases the LR during the first warm-up steps
    # so that training converges faster at the beginning.
    scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_training_steps=num_training_steps,
        num_warmup_steps=4)

    # model = nn.DataParallel(model)  # if using multiple GPUs

    best_accuracy = 0
    best_model_state = None

    for epoch in range(config.EPOCHS):
        engine.train_fn(train_data_loader, model, optimizer, scheduler)
        outputs, targets = engine.eval_fn(valid_data_loader, model)
        outputs = np.where(np.array(outputs) > 0.5, 1, 0)
        accuracy = metrics.accuracy_score(np.array(targets), outputs)
        print(f"Accuracy, Epoch {epoch} : {accuracy}")
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_model_state = model.state_dict()

    print(f"Best accuracy : {best_accuracy}")

    print("Saving Model...")
    torch.save(best_model_state, config.MODEL_PATH)
Example #13
def run():
    df = pd.read_csv("inputs/reviews.csv")
    df["sentiment"] = df.score.apply(rating_to_sentiment)
    df_train, df_rem = train_test_split(df,
                                        test_size=0.1,
                                        random_state=config.RANDOM_SEED)
    df_val, df_test = train_test_split(df_rem,
                                       test_size=0.5,
                                       random_state=config.RANDOM_SEED)
    train_data_loader = create_data_loader(df_train, config.TOKENIZER,
                                           config.MAX_LEN, config.BATCH_SIZE)
    val_data_loader = create_data_loader(df_val, config.TOKENIZER,
                                         config.MAX_LEN, config.BATCH_SIZE)
    test_data_loader = create_data_loader(df_test, config.TOKENIZER,
                                          config.MAX_LEN, config.BATCH_SIZE)

    # data = next(iter(val_data_loader))
    # input_ids = data["input_ids"].to(config.DEVICE)
    # attention_mask = data["attention_mask"].to(config.DEVICE)
    # bert_model = BertModel.from_pretrained(config.BERT_NAME)

    model = SentimentClassifier(num_classes=len(class_labels))
    if config.LOAD_MODEL:
        model.load_state_dict(torch.load("best_model_state.bin"))
    model = model.to(config.DEVICE)

    optimizer = AdamW(model.parameters(), lr=2e-5, correct_bias=False)
    total_steps = len(train_data_loader) * config.EPOCHS
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0,
                                                num_training_steps=total_steps)
    loss_fn = nn.CrossEntropyLoss().to(config.DEVICE)

    history = defaultdict(list)
    best_accuracy = 0

    for epoch in range(config.EPOCHS):
        print(f"Epoch {epoch + 1}/{config.EPOCHS}")
        print("-" * 10)

        train_acc, train_loss = train_fn(
            model,
            train_data_loader,
            loss_fn,
            optimizer,
            config.DEVICE,
            scheduler,
            len(df_train),
        )

        print(f"Train loss {train_loss} accuracy {train_acc}")

        val_acc, val_loss = eval_fn(model, val_data_loader, loss_fn,
                                    config.DEVICE, len(df_val))

        print(f"Val   loss {val_loss} accuracy {val_acc}")
        print()

        history["train_acc"].append(train_acc)
        history["train_loss"].append(train_loss)
        history["val_acc"].append(val_acc)
        history["val_loss"].append(val_loss)

        if val_acc > best_accuracy:
            torch.save(model.state_dict(), "best_model_state.bin")
            best_accuracy = val_acc
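Example #13 derives its sentiment labels with a rating_to_sentiment helper that is not shown. A plausible sketch for a 1-to-5 star score and the three classes used elsewhere on this page (negative, neutral, positive); the exact thresholds in the original may differ:

def rating_to_sentiment(score):
    """Hypothetical mapping: 1-2 stars -> 0 (negative), 3 -> 1 (neutral), 4-5 -> 2 (positive)."""
    if score <= 2:
        return 0
    if score == 3:
        return 1
    return 2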