Example #1
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                inputs_embeds=None, labels=None):
        """"""
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]
        if labels is not None:
            loss_fct = {
                None: CrossEntropyLoss(),
                "lsl": LabelSmoothingLoss(classes=self.num_labels, smoothing=.2),
                "distrib": CrossEntropyLoss(weight=self.class_weights.to(device) if self.class_weights else None),
                "batch": CrossEntropyLoss(weight=self.get_weights(labels, self.num_labels).to(device))
            }[self.loss_fct]
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)
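
None of these snippets includes the LabelSmoothingLoss class itself. A minimal sketch of the (classes, smoothing) constructor used here (the same signature appears in Examples #3, #4, #5, #7 and #9) could look like the following; the soft-target cross-entropy formulation is an assumption, and the real implementation in each repository may differ in details such as reduction or padding handling.

import torch
import torch.nn as nn


class LabelSmoothingLoss(nn.Module):
    """Cross entropy against a uniformly smoothed target distribution (sketch)."""

    def __init__(self, classes, smoothing=0.1, dim=-1):
        super().__init__()
        self.confidence = 1.0 - smoothing   # probability mass kept on the gold class
        self.smoothing = smoothing          # mass spread over the remaining classes
        self.classes = classes
        self.dim = dim

    def forward(self, pred, target):
        # pred: (batch, classes) raw logits, target: (batch,) class indices
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            true_dist = torch.full_like(pred, self.smoothing / (self.classes - 1))
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))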
Example #2
    def __init__(self,
                 embed_size,
                 hidden_size,
                 vocab,
                 dropout_rate=0.2,
                 input_feed=True,
                 label_smoothing=0.):
        super(NMT, self).__init__()

        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.dropout_rate = dropout_rate
        self.vocab = vocab
        self.input_feed = input_feed

        # initialize neural network layers...

        self.src_embed = nn.Embedding(len(vocab.src),
                                      embed_size,
                                      padding_idx=vocab.src['<pad>'])
        self.tgt_embed = nn.Embedding(len(vocab.tgt),
                                      embed_size,
                                      padding_idx=vocab.tgt['<pad>'])

        self.encoder_lstm = nn.LSTM(embed_size,
                                    hidden_size,
                                    bidirectional=True)
        decoder_lstm_input = embed_size + hidden_size if self.input_feed else embed_size
        self.decoder_lstm = nn.LSTMCell(decoder_lstm_input, hidden_size)

        # attention: dot product attention
        # project source encoding to decoder rnn's state space
        self.att_src_linear = nn.Linear(hidden_size * 2,
                                        hidden_size,
                                        bias=False)

        # transformation of decoder hidden states and context vectors before reading out target words
        # this produces the `attentional vector` in (Luong et al., 2015)
        self.att_vec_linear = nn.Linear(hidden_size * 2 + hidden_size,
                                        hidden_size,
                                        bias=False)

        # prediction layer of the target vocabulary
        self.readout = nn.Linear(hidden_size, len(vocab.tgt), bias=False)

        # dropout layer
        self.dropout = nn.Dropout(self.dropout_rate)

        # initialize the decoder's state and cells with encoder hidden states
        self.decoder_cell_init = nn.Linear(hidden_size * 2, hidden_size)

        self.label_smoothing = label_smoothing
        if label_smoothing > 0.:
            self.label_smoothing_loss = LabelSmoothingLoss(
                label_smoothing,
                tgt_vocab_size=len(vocab.tgt),
                padding_idx=vocab.tgt['<pad>'])
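
Note that this constructor calls LabelSmoothingLoss with a different signature than Example #1: the smoothing value first, then tgt_vocab_size and padding_idx. A sketch consistent with that signature, assuming an OpenNMT-style KL formulation that zeroes out padding positions, is shown below; the masking details are an assumption, not taken from this repository.

import torch
import torch.nn as nn
import torch.nn.functional as F


class LabelSmoothingLoss(nn.Module):
    """KL-divergence label smoothing that ignores padding tokens (sketch)."""

    def __init__(self, label_smoothing, tgt_vocab_size, padding_idx=0):
        super().__init__()
        self.padding_idx = padding_idx
        self.confidence = 1.0 - label_smoothing
        # smoothing mass is shared by all tokens except the gold one and <pad>
        self.smoothing_value = label_smoothing / (tgt_vocab_size - 2)

    def forward(self, output, target):
        # output: (batch, vocab) log-probabilities, target: (batch,) gold indices
        model_prob = torch.full_like(output, self.smoothing_value)
        model_prob.scatter_(1, target.unsqueeze(1), self.confidence)
        model_prob[:, self.padding_idx] = 0.0
        model_prob.masked_fill_((target == self.padding_idx).unsqueeze(1), 0.0)
        return F.kl_div(output, model_prob, reduction='sum')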
Example #3
def fetch_loss(args):
    if args.loss_fn == "SCE":
        return SCELoss()
    elif args.loss_fn == "CE":
        return nn.CrossEntropyLoss()
    elif args.loss_fn == "Label":
        return LabelSmoothingLoss(classes=args.num_classes,
                                  smoothing=args.label_smoothing_ratio)
    elif args.loss_fn == "BTLL":
        return bi_tempered_logistic_loss(
            t1=0.2, t2=1.0)  # Large parameter --> t1=0.2, t2=1.0
    else:
        raise NotImplementedError(f"Unknown loss_fn: {args.loss_fn}")
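
A brief usage sketch: only the attribute names that fetch_loss reads (loss_fn, num_classes, label_smoothing_ratio) matter here; the values below are hypothetical.

from argparse import Namespace

# Hypothetical arguments; in the real script they come from argparse.
args = Namespace(loss_fn="Label", num_classes=5, label_smoothing_ratio=0.1)
criterion = fetch_loss(args)   # -> LabelSmoothingLoss(classes=5, smoothing=0.1)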
Example #4
def fetch_multiloss(args):
    loss_ls = {}
    for loss_name in args.multi_loss_list:
        if loss_name == "SCE":
            loss_ls["SCE"] = SCELoss()
        elif loss_name == "CE":
            loss_ls["CE"] = nn.CrossEntropyLoss()
        elif loss_name == "Label":
            loss_ls["Label"] = LabelSmoothingLoss(
                classes=args.num_classes, smoothing=args.label_smoothing_ratio)
        elif loss_name == "BTLL":
            loss_ls["Label"] = bi_tempered_logistic_loss(
                t1=0.2, t2=1.0)  # Large parameter --> t1=0.2, t2=1.0

    return loss_ls
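
fetch_multiloss only builds the dictionary; combining the entries is left to the caller. One possible (assumed) pattern is to average the individual terms, as sketched below; the actual training loop is not part of the snippet.

# Hypothetical combination step, not part of the original code.
def compute_multiloss(loss_dict, logits, targets):
    return sum(fn(logits, targets) for fn in loss_dict.values()) / len(loss_dict)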
Example #5
def main():

    # fix seed for train reproduction
    seed_everything(args.SEED)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print("\n device", device)

    # TODO dataset loading
    train_df = pd.read_csv('/DATA/trainset-for_user.csv', header=None)
    train_df = train_df.dropna().reset_index(drop=True)
    test_df = pd.read_csv('/DATA/testset-for_user.csv', header=None)
    print('train_df shape : ', train_df.shape)

    train_df = create_str_feature(train_df)
    test_df = create_str_feature(test_df)

    train_df['patient_label'] = train_df['patient'] + '_' + train_df['label']
    train_df['count'] = train_df['patient_label'].map(
        train_df['patient_label'].value_counts())

    print(train_df.head())
    print(train_df.isnull().sum())
    from sklearn.model_selection import train_test_split

    train_df['image_path'] = [
        os.path.join('/DATA', train_df['patient'][i], train_df['image'][i])
        for i in range(train_df.shape[0])
    ]
    labels = train_df['label'].map({
        'Wake': 0,
        'N1': 1,
        'N2': 2,
        'N3': 3,
        'REM': 4
    }).values
    str_train_df = train_df[['time', 'user_count', 'user_max',
                             'user_min']].values
    str_test_df = test_df[['time', 'user_count', 'user_max',
                           'user_min']].values

    print('meta max value: ', str_train_df.max(), str_test_df.max(),
          'meta shape: ', str_train_df.shape, str_test_df.shape)

    skf_labels = train_df['patient'] + '_' + train_df['label']

    unique_idx = train_df[train_df['count'] == 1].index
    non_unique_idx = train_df[train_df['count'] > 1].index
    trn_idx, val_idx, trn_labels, val_labels = train_test_split(
        non_unique_idx,
        labels[non_unique_idx],
        test_size=0.05,
        random_state=0,
        shuffle=True,
        stratify=skf_labels[non_unique_idx])

    # valid set define
    trn_image_paths = train_df.loc[trn_idx, 'image_path'].values
    val_image_paths = train_df.loc[val_idx, 'image_path'].values

    # structured data define
    trn_str_data = str_train_df[trn_idx, :]
    val_str_data = str_train_df[val_idx, :]

    print('\n')
    print('95:5 train/valid split : ', len(trn_image_paths), len(trn_labels),
          len(val_image_paths), len(val_labels), trn_str_data.shape,
          val_str_data.shape)
    print('\n')
    print(trn_image_paths[:5], trn_labels[:5])
    print(val_image_paths[:5], val_labels[:5])

    valid_transforms = create_val_transforms(args, args.input_size)
    if args.DEBUG:
        valid_dataset = SleepDataset(args,
                                     val_image_paths[:100],
                                     val_str_data,
                                     val_labels[:100],
                                     valid_transforms,
                                     is_test=False)
    else:
        valid_dataset = SleepDataset(args,
                                     val_image_paths,
                                     val_str_data,
                                     val_labels,
                                     valid_transforms,
                                     is_test=False)
    valid_loader = DataLoader(dataset=valid_dataset,
                              batch_size=args.batch_size,
                              num_workers=args.num_workers,
                              shuffle=False,
                              pin_memory=True)

    trn_skf_labels = (train_df.loc[trn_idx, 'patient'] +
                      train_df.loc[trn_idx, 'label']).values
    print('skf labels head : ', trn_skf_labels[:5])

    if args.DEBUG:
        print('\n#################################### DEBUG MODE')
    else:
        print('\n################################### MAIN MODE')
        print(trn_image_paths.shape, trn_labels.shape, trn_skf_labels.shape)

    # train set define
    train_dataset_dict = {}
    skf = StratifiedKFold(n_splits=args.n_folds,
                          shuffle=True,
                          random_state=args.SEED)
    nsplits = [
        val_idx for _, val_idx in skf.split(trn_image_paths, trn_skf_labels)
    ]
    print(nsplits)
    #np.save('nsplits.npy', nsplits)

    #print('\nload nsplits')
    #nsplits = np.load('nsplits.npy', allow_pickle=True)
    #print(nsplits)

    for idx, val_idx in enumerate(nsplits):  #trn_skf_labels

        sub_img_paths = np.array(trn_image_paths)[val_idx]
        sub_labels = np.array(trn_labels)[val_idx]
        sub_meta = np.array(trn_str_data)[val_idx]
        if args.DEBUG:
            sub_img_paths = sub_img_paths[:200]
            sub_labels = sub_labels[:200]
            sub_meta = sub_meta[:200]

        if idx == 1 or idx == 6:
            sub_img_paths = np.concatenate(
                [sub_img_paths, train_df.loc[unique_idx, 'image_path'].values])
            sub_labels = np.concatenate([sub_labels, labels[unique_idx]])
            sub_meta = np.concatenate([sub_meta, str_train_df[unique_idx]])

        train_transforms = create_train_transforms(args, args.input_size)
        #train_dataset = SleepDataset(args, sub_img_paths, sub_labels, train_transforms, use_masking=True, is_test=False)
        train_dataset_dict[idx] = [
            args, sub_img_paths, sub_meta, sub_labels, train_transforms
        ]
        print(f'train dataset complete {idx}/{args.n_folds}, ')

    print("numberr of train datasets: ", len(train_dataset_dict))

    # define model
    model = build_model(args, device)

    # optimizer definition
    optimizer = build_optimizer(args, model)
    #scheduler = build_scheduler(args, optimizer, len(train_loader))
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 9)
    scheduler = GradualWarmupSchedulerV2(optimizer,
                                         multiplier=1,
                                         total_epoch=1,
                                         after_scheduler=scheduler_cosine)

    if args.label_smoothing:
        criterion = LabelSmoothingLoss(classes=args.num_classes,
                                       smoothing=args.label_smoothing_ratio)
    else:
        criterion = nn.CrossEntropyLoss()

    trn_cfg = {
        'train_datasets': train_dataset_dict,
        'valid_loader': valid_loader,
        'model': model,
        'criterion': criterion,
        'optimizer': optimizer,
        'scheduler': scheduler,
        'device': device,
        'fold_num': 0,
    }

    train(args, trn_cfg)
Example #6
    def __init__(self,
                 models_dict,
                 optimizer_task,
                 source_loader,
                 test_source_loader,
                 target_loader,
                 nadir_slack,
                 alpha,
                 patience,
                 factor,
                 label_smoothing,
                 warmup_its,
                 lr_threshold,
                 verbose=-1,
                 cp_name=None,
                 save_cp=False,
                 checkpoint_path=None,
                 checkpoint_epoch=None,
                 cuda=True,
                 logging=False,
                 ablation='no',
                 train_mode='hv'):
        if checkpoint_path is None:
            # Save to current directory
            self.checkpoint_path = os.getcwd()
        else:
            self.checkpoint_path = checkpoint_path
            if not os.path.isdir(self.checkpoint_path):
                os.mkdir(self.checkpoint_path)

        self.save_epoch_fmt_task = os.path.join(
            self.checkpoint_path, 'task' +
            cp_name) if cp_name else os.path.join(self.checkpoint_path,
                                                  'task_checkpoint_{}ep.pt')
        self.save_epoch_fmt_domain = os.path.join(
            self.checkpoint_path, 'Domain_{}' +
            cp_name) if cp_name else os.path.join(self.checkpoint_path,
                                                  'Domain_{}.pt')

        self.cuda_mode = cuda
        self.feature_extractor = models_dict['feature_extractor']
        self.task_classifier = models_dict['task_classifier']
        self.domain_discriminator_list = models_dict[
            'domain_discriminator_list']
        self.optimizer_task = optimizer_task
        self.source_loader = source_loader
        self.test_source_loader = test_source_loader
        self.target_loader = target_loader
        self.history = {
            'loss_task': [],
            'hypervolume': [],
            'loss_domain': [],
            'accuracy_source': [],
            'accuracy_target': []
        }
        self.cur_epoch = 0
        self.total_iter = 0
        self.nadir_slack = nadir_slack
        self.alpha = alpha
        self.ablation = ablation
        self.train_mode = train_mode
        self.device = next(self.feature_extractor.parameters()).device

        its_per_epoch = len(source_loader.dataset) // source_loader.batch_size
        if len(source_loader.dataset) % source_loader.batch_size > 0:
            its_per_epoch += 1
        patience = patience * (1 + its_per_epoch)
        self.after_scheduler_task = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer_task,
            factor=factor,
            patience=patience,
            verbose=True if verbose > 0 else False,
            threshold=lr_threshold,
            min_lr=1e-7)
        self.after_scheduler_disc_list = [
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                disc.optimizer,
                factor=factor,
                patience=patience,
                verbose=True if verbose > 0 else False,
                threshold=lr_threshold,
                min_lr=1e-7) for disc in self.domain_discriminator_list
        ]
        self.verbose = verbose
        self.save_cp = save_cp

        self.scheduler_task = GradualWarmupScheduler(
            self.optimizer_task,
            total_epoch=warmup_its,
            after_scheduler=self.after_scheduler_task)
        self.scheduler_disc_list = [
            GradualWarmupScheduler(self.domain_discriminator_list[i].optimizer,
                                   total_epoch=warmup_its,
                                   after_scheduler=sch_disc)
            for i, sch_disc in enumerate(self.after_scheduler_disc_list)
        ]

        if checkpoint_epoch is not None:
            self.load_checkpoint(checkpoint_epoch)

        self.logging = logging
        if self.logging:
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter()

        if label_smoothing > 0.0:
            self.ce_criterion = LabelSmoothingLoss(label_smoothing,
                                                   lbl_set_size=7)
        else:
            self.ce_criterion = torch.nn.CrossEntropyLoss()  # alternatively torch.nn.NLLLoss()

        #loss_domain_discriminator
        weight = torch.tensor([2.0 / 3.0, 1.0 / 3.0]).to(self.device)
        #d_cr=torch.nn.CrossEntropyLoss(weight=weight)
        self.d_cr = torch.nn.NLLLoss(weight=weight)
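
Since the domain criterion is torch.nn.NLLLoss, the domain discriminators are presumably expected to emit log-probabilities rather than raw logits. Their definition is not shown, but a head consistent with that choice would end in a LogSoftmax, for example:

import torch.nn as nn

# Hypothetical discriminator head; the real models_dict['domain_discriminator_list']
# entries are not part of this snippet. NLLLoss expects log-probabilities as input.
domain_head = nn.Sequential(nn.Linear(256, 2), nn.LogSoftmax(dim=1))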
Example #7
def main():
    os.makedirs(SAVEPATH, exist_ok=True)
    print('save path:', SAVEPATH)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('device:', device)

    print('weight_decay:', WEIGHTDECAY)
    print('momentum:', MOMENTUM)
    print('batch_size:', BATCHSIZE)
    print('lr:', LR)
    print('epoch:', EPOCHS)
    print('Label smoothing:', LABELSMOOTH)
    print('Stochastic Weight Averaging:', SWA)
    if SWA:
        print('Swa lr:', SWA_LR)
        print('Swa start epoch:', SWA_START)
    print('Cutout augmentation:', CUTOUT)
    if CUTOUT:
        print('Cutout size:', CUTOUTSIZE)
    print('Activation:', ACTIVATION)

    # get model
    model = get_seresnet_cifar(activation=ACTIVATION)

    # get loss function
    if LABELSMOOTH:
        criterion = LabelSmoothingLoss(classes=10, smoothing=0.1)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=LR,
                                momentum=MOMENTUM,
                                weight_decay=WEIGHTDECAY,
                                nesterov=True)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,
                                                           T_max=EPOCHS,
                                                           eta_min=0)

    model = model.to(device)
    criterion = criterion.to(device)

    # Check the number of parameters in your model
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print(f"Number of parameters: {pytorch_total_params}")
    if pytorch_total_params > 2000000:
        print('Your model has more than 2 million parameters.')
        return

    if SWA:
        # apply swa
        swa_model = AveragedModel(model)
        swa_scheduler = SWALR(optimizer, swa_lr=SWA_LR)
        swa_total_params = sum(p.numel() for p in swa_model.parameters())
        print(f"Swa parameters: {swa_total_params}")

    # cinic mean, std
    normalize = transforms.Normalize(mean=[0.47889522, 0.47227842, 0.43047404],
                                     std=[0.24205776, 0.23828046, 0.25874835])

    if CUTOUT:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize,
            Cutout(size=CUTOUTSIZE)
        ])
    else:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize
        ])

    train_dataset = torchvision.datasets.ImageFolder('/content/train',
                                                     transform=train_transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=BATCHSIZE,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)

    # colab reload
    start_epoch = 0
    if os.path.isfile(os.path.join(SAVEPATH, 'latest_checkpoint.pth')):
        checkpoint = torch.load(os.path.join(SAVEPATH,
                                             'latest_checkpoint.pth'))
        start_epoch = checkpoint['epoch']
        scheduler.load_state_dict(checkpoint['scheduler'])
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if SWA:
            swa_scheduler.load_state_dict(checkpoint['swa_scheduler'])
            swa_model.load_state_dict(checkpoint['swa_model'])
        print(start_epoch, 'load parameter')

    for epoch in range(start_epoch, EPOCHS):
        print("\n----- epoch: {}, lr: {} -----".format(
            epoch, optimizer.param_groups[0]["lr"]))

        # train for one epoch
        start_time = time.time()
        train(train_loader, epoch, model, optimizer, criterion, device)
        elapsed_time = time.time() - start_time
        print('==> {:.2f} seconds to train this epoch\n'.format(elapsed_time))

        # learning rate scheduling
        if SWA and epoch > SWA_START:
            swa_model.update_parameters(model)
            swa_scheduler.step()
        else:
            scheduler.step()

        if SWA:
            checkpoint = {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'swa_model': swa_model.state_dict(),
                'swa_scheduler': swa_scheduler.state_dict()
            }
        else:
            checkpoint = {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            }
        torch.save(checkpoint, os.path.join(SAVEPATH, 'latest_checkpoint.pth'))
        if epoch % 10 == 0:
            torch.save(checkpoint,
                       os.path.join(SAVEPATH, '%d_checkpoint.pth' % epoch))
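
The snippet ends without the usual SWA finalisation step. If one were to complete it, the standard torch.optim.swa_utils pattern (assumed here, not present in the original) recomputes the batch-norm statistics for the averaged weights before evaluating swa_model:

from torch.optim.swa_utils import update_bn

# Assumed finalisation: refresh BatchNorm running statistics with the averaged
# weights, then evaluate swa_model instead of model.
update_bn(train_loader, swa_model, device=device)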
Example #8
# load pretrained weights if possible
pkl_path = None

try:
    model.load_state_dict(torch.load(pkl_path, map_location=device))
    #model2.load_state_dict(torch.load(pkl_path, map_location=device))
    print("\n--------model restored--------\n")
    #print("\n--------model2 restored--------\n")
except Exception:
    print("\n--------model not restored--------\n")

# loss function
#loss_fn = MOD_CrossEntropyLoss()
loss_fn = LabelSmoothingLoss(classes=1000, batch_size=config.batch_size)

# parameters
lr = config.lr
optimizer = optim.Adam(model.parameters(), weight_decay=0.0, lr=lr)
#optimizer_smoothing = optim.Adam(model2.parameters(), betas=[.9, .999], weight_decay=0.0, lr=lr)

scheduler = optim.lr_scheduler.StepLR(optimizer, 8, gamma=0.1, last_epoch=-1)
n_epochs = config.num_epochs
log_interval = 100

# DataLoader
# Train Dataset & Loader
print("Data Loading ...")
trainset = Dataset(config.traindata_dir)
trainloader = create_loader(dataset=trainset,
Example #9
def training(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    #===================================#
    #============Data Load==============#
    #===================================#

    train_dat = pd.read_csv(os.path.join(args.data_path, 'news_train.csv'))
    train_dat_num = int(len(train_dat) * (1-args.valid_percent))

    print('Data Load & Setting!')
    with open(os.path.join(args.save_path, 'preprocessed.pkl'), 'rb') as f:
        data_ = pickle.load(f)
        src_vocab_num_dict = dict()

        total_train_text_indices_spm = data_['total_train_text_indices_spm']
        total_valid_text_indices_spm = data_['total_valid_text_indices_spm']
        total_train_text_indices_khaiii = data_['total_train_text_indices_khaiii']
        total_valid_text_indices_khaiii = data_['total_valid_text_indices_khaiii']
        total_train_text_indices_konlpy = data_['total_train_text_indices_konlpy']
        total_valid_text_indices_konlpy = data_['total_valid_text_indices_konlpy']
        train_content_indices_spm = data_['train_content_indices_spm']
        valid_content_indices_spm = data_['valid_content_indices_spm']
        train_content_indices_khaiii = data_['train_content_indices_khaiii']
        valid_content_indices_khaiii = data_['valid_content_indices_khaiii']
        train_content_indices_konlpy = data_['train_content_indices_konlpy']
        valid_content_indices_konlpy = data_['valid_content_indices_konlpy']
        train_date_list = data_['train_date_list']
        valid_date_list = data_['valid_date_list']
        train_ord_list = data_['train_ord_list']
        valid_ord_list = data_['valid_ord_list']
        train_id_list = data_['train_id_list']
        valid_id_list = data_['valid_id_list']
        train_info_list = data_['train_info_list']
        valid_info_list = data_['valid_info_list']
        word2id_spm = data_['word2id_spm']
        word2id_khaiii = data_['word2id_khaiii']
        word2id_konlpy = data_['word2id_konlpy']

        src_vocab_num_dict['spm'] = len(word2id_spm.keys())
        src_vocab_num_dict['khaiii'] = len(word2id_khaiii.keys())
        src_vocab_num_dict['konlpy'] = len(word2id_konlpy.keys())
        del data_

    dataset_dict = {
        'train': CustomDataset(total_train_text_indices_spm, total_train_text_indices_khaiii, 
                               total_train_text_indices_konlpy,
                               train_content_indices_spm, train_content_indices_khaiii, 
                               train_content_indices_konlpy, train_date_list, 
                               train_ord_list, train_id_list, train_info_list,
                               isTrain=True, min_len=args.min_len, max_len=args.max_len),
        'valid': CustomDataset(total_valid_text_indices_spm, total_valid_text_indices_khaiii, 
                               total_valid_text_indices_konlpy,
                               valid_content_indices_spm, valid_content_indices_khaiii, 
                               valid_content_indices_konlpy, valid_date_list, 
                               valid_ord_list, valid_id_list, valid_info_list,
                               isTrain=True, min_len=args.min_len, max_len=args.max_len),
    }
    dataloader_dict = {
        'train': DataLoader(dataset_dict['train'], collate_fn=PadCollate(), drop_last=True,
                            batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True, pin_memory=True),
        'valid': DataLoader(dataset_dict['valid'], collate_fn=PadCollate(), drop_last=True,
                            batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True, pin_memory=True)
    }
    print(f'Total number of training samples - {len(dataset_dict["train"])}, iterations per epoch - {len(dataloader_dict["train"])}')
    print(f'{train_dat_num - len(dataset_dict["train"])} samples were excluded.')

    #===================================#
    #===========Model Setting===========#
    #===================================#

    print("Build model")
    model = Total_model(args.model_type, src_vocab_num_dict, trg_num=2, pad_idx=args.pad_idx, bos_idx=args.bos_idx,
                        eos_idx=args.eos_idx, max_len=args.max_len, d_model=args.d_model,
                        d_embedding=args.d_embedding, n_head=args.n_head, d_k=args.d_k, d_v=args.d_v,
                        dim_feedforward=args.dim_feedforward, dropout=args.dropout,
                        bilinear=args.bilinear, num_transformer_layer=args.num_transformer_layer,
                        num_rnn_layer=args.num_rnn_layer, device=device)
    if args.Ralamb:
        optimizer = Ralamb(params=filter(lambda p: p.requires_grad, model.parameters()), 
                           lr=args.max_lr, weight_decay=args.w_decay)
    else:
        optimizer = optim.SGD(model.parameters(), lr=args.max_lr, momentum=args.momentum,
                              weight_decay=args.w_decay)
    # optimizer = optim_lib.Lamb(params=model.parameters(), 
    #                        lr=args.max_lr, weight_decay=args.w_decay)

    if args.n_warmup_epochs != 0:
        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.n_warmup_epochs*len(dataloader_dict['train']), 
                                        t_total=len(dataloader_dict['train'])*args.num_epoch)
    else:
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, 
                                      patience=len(dataloader_dict['train'])/1.5)
    criterion = LabelSmoothingLoss(classes=2, smoothing=args.label_smoothing)
    model.to(device)

    #===================================#
    #===========Model Training==========#
    #===================================#

    best_val_loss = None

    if not os.path.exists(args.model_path):
        os.mkdir(args.model_path)

    for e in range(args.num_epoch):
        start_time_e = time.time()
        print(f'Model Fitting: [{e+1}/{args.num_epoch}]')
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
                freq = 0
            if phase == 'valid':
                model.eval()
                val_loss = 0
                val_acc = 0
                false_id_list, false_logit_list = list(), list()
            for i, (total_src_spm, total_src_khaiii, total_src_konlpy, src_spm, src_khaiii, src_konlpy, date, order, id_, trg) in enumerate(dataloader_dict[phase]):

                # Optimizer setting
                optimizer.zero_grad()

                # Source, Target sentence setting
                total_src_spm = total_src_spm.to(device)
                total_src_khaiii = total_src_khaiii.to(device)
                total_src_konlpy = total_src_konlpy.to(device)
                src_spm = src_spm.to(device)
                src_khaiii = src_khaiii.to(device)
                src_konlpy = src_konlpy.to(device)
                trg = trg.to(device)

                # Model / Calculate loss
                with torch.set_grad_enabled(phase == 'train'):
                    predicted_logit = model(total_src_spm, total_src_khaiii, total_src_konlpy, src_spm, src_khaiii, src_konlpy)

                    # If phase train, then backward loss and step optimizer and scheduler
                    if phase == 'train':
                        loss = criterion(predicted_logit, trg)
                        loss.backward()
                        clip_grad_norm_(model.parameters(), args.grad_clip)
                        optimizer.step()
                        if args.n_warmup_epochs != 0:
                            scheduler.step()
                        else:
                            scheduler.step(loss)
                        # Print loss value only training
                        if freq == args.print_freq or freq == 0 or i == len(dataloader_dict['train']) - 1:
                            total_loss = loss.item()
                            _, predicted = predicted_logit.max(dim=1)
                            accuracy = sum(predicted == trg).item() / predicted.size(0)
                            print("[Epoch:%d][%d/%d] train_loss:%5.3f | Accuracy:%2.3f | lr:%1.6f | spend_time:%5.2fmin"
                                    % (e+1, i, len(dataloader_dict['train']), total_loss, accuracy, 
                                    optimizer.param_groups[0]['lr'], (time.time() - start_time_e) / 60))
                            freq = 0
                        freq += 1
                    if phase == 'valid':
                        loss = F.cross_entropy(predicted_logit, trg)
                        val_loss += loss.item()
                        _, predicted = predicted_logit.max(dim=1)
                        # Setting
                        predicted_matching = (predicted == trg)
                        logit_clone = F.softmax(predicted_logit.cpu().clone(), dim=1).numpy()
                        # Calculate
                        accuracy = sum(predicted_matching).item() / predicted.size(0)
                        false_id_list.extend([id_[i] for i, x in enumerate(predicted_matching) if not x])
                        false_logit_list.extend(logit_clone[[i for i, x in enumerate(predicted_matching) if not x]])
                        val_acc += accuracy

            # Finishing iteration
            if phase == 'valid':
                val_loss /= len(dataloader_dict['valid'])
                val_acc /= len(dataloader_dict['valid'])
                print("[Epoch:%d] val_loss:%5.3f | Accuracy:%5.2f | spend_time:%5.2fmin"
                        % (e+1, val_loss, val_acc, (time.time() - start_time_e) / 60))
                if not best_val_loss or val_loss < best_val_loss:
                    print("[!] saving model...")
                    if not os.path.exists(args.save_path):
                        os.mkdir(args.save_path)
                    torch.save(model.state_dict(), 
                               os.path.join(args.model_path, f'model_saved.pt'))
                    best_val_loss = val_loss
                    wrong_id_list = false_id_list
                    wrong_logit_list = false_logit_list

    #===================================#
    #============Result save============#
    #===================================#

    # 1) Path setting
    if not os.path.exists(args.results_path):
        os.mkdir(args.results_path)

    if not os.path.isfile(os.path.join(args.results_path, 'results.csv')):
        column_list_results = ['date_time', 'best_val_loss', 'tokenizer', 'valid_percent', 
                               'vocab_size', 'num_epoch', 'batch_size', 'max_len', 'n_warmup_epochs', 
                               'max_lr', 'momentum', 'w_decay', 'dropout', 'grad_clip', 'model_type', 
                               'bilinear', 'num_transformer_layer', 'num_rnn_layer', 'd_model', 
                               'd_embedding', 'd_k', 'd_v', 'n_head', 'dim_feedforward', 'label_smoothing']
        pd.DataFrame(columns=column_list_results).to_csv(os.path.join(args.results_path, 'results.csv'), index=False)

    if not os.path.isfile(os.path.join(args.results_path, 'wrong_list.csv')):
        column_list_wrong = ['date_time', 'id_', 'title', 'content', '0', '1', 'info']
        pd.DataFrame(columns=column_list_wrong).to_csv(os.path.join(args.results_path, 'wrong_list.csv'), index=False)

    results_dat = pd.read_csv(os.path.join(args.results_path, 'results.csv'))
    wrong_dat_total = pd.read_csv(os.path.join(args.results_path, 'wrong_list.csv'))

    # 2) Model setting save
    new_row = {
        'date_time':datetime.datetime.today().strftime('%m/%d/%H:%M'),
        'best_val_loss': best_val_loss,
        'tokenizer': args.sentencepiece_tokenizer,
        'valid_percent': args.valid_percent,
        'vocab_size': args.vocab_size,
        'num_epoch': args.num_epoch,
        'batch_size': args.batch_size,
        'max_len': args.max_len,
        'n_warmup_epochs': args.n_warmup_epochs,
        'max_lr': args.max_lr,
        'momentum': args.momentum,
        'w_decay': args.w_decay,
        'dropout': args.dropout,
        'grad_clip': args.grad_clip,
        'model_type': args.model_type,
        'bilinear': args.bilinear,
        'num_transformer_layer': args.num_transformer_layer,
        'num_rnn_layer': args.num_rnn_layer,
        'd_model': args.d_model,
        'd_embedding': args.d_embedding,
        'd_k': args.d_k,
        'd_v': args.d_v,
        'n_head': args.n_head,
        'dim_feedforward': args.dim_feedforward,
        'label_smoothing': args.label_smoothing
    }
    results_dat = pd.concat([results_dat, pd.DataFrame([new_row])], ignore_index=True)
    results_dat.to_csv(os.path.join(args.results_path, 'results.csv'), index=False)

    # 3) Wrong ID list save
    train_dat['id_'] = train_dat['n_id'] + '_' + train_dat['ord'].astype(str)

    wrong_dat = pd.DataFrame(np.stack(wrong_logit_list))
    wrong_dat['date_time'] = [datetime.datetime.today().strftime('%m/%d/%H:%M') for _ in range(len(wrong_dat))]
    wrong_dat['id_'] = wrong_id_list
    wrong_dat = wrong_dat.merge(train_dat[['id_', 'title', 'content', 'info']], on='id_')
    wrong_dat = wrong_dat[['date_time', 'id_', 'title', 'content', 0, 1, 'info']]

    wrong_dat_total = pd.concat([wrong_dat_total, wrong_dat], axis=0)
    wrong_dat_total.to_csv(os.path.join(args.results_path, 'wrong_list.csv'), index=False)