Example #1
    def validate(self,dev_data,threshold=0.1):
        val_loss=[]
        # import pdb; pdb.set_trace()
        for i, item in enumerate(dev_data):
            Qc,Qw,q_mask,Ec,Ew,e_mask,As,Ae = [i.to(self.device) for i in item]
            As_, Ae_ =  self.model([Qc,Qw,q_mask,Ec,Ew,e_mask])

            # compute per-position focal losses for the start/end predictions
            As_loss, Ae_loss = focal_loss(As, As_, self.device), focal_loss(Ae, Ae_, self.device)
            mask=e_mask==1
            loss=(As_loss.masked_select(mask).sum() + Ae_loss.masked_select(mask).sum()) /  e_mask.sum()
            if (i+1)%self.print_step==0 or i==len(dev_data)-1:
                logger.info("In Validation: Step / All Step : {} / {} \t Loss of every char : {}"\
                    .format(i+1,len(dev_data),loss.item()*100))
            val_loss.append(loss.item())
            
            
            As_,Ae_,As,Ae = [ i.masked_select(mask).cpu().numpy() for i in [As_,Ae_,As,Ae]]
            As_,Ae_ = np.where(As_>threshold,1,0), np.where(Ae_>threshold,1,0)
            As,Ae = As.astype(int),Ae.astype(int)
            
            acc,prec,recall,f1=binary_confusion_matrix_evaluate(As,As_)
            
            logger.info('START EVALUATION :\t Acc : {}\t Prec : {}\t Recall : {}\t F1-score : {}'\
                .format(acc,prec,recall,f1))
            acc,prec,recall,f1=binary_confusion_matrix_evaluate(Ae,Ae_)
            logger.info('END EVALUATION :\t Acc : {}\t Prec : {}\t Recall : {}\t F1-score : {}'\
                .format(acc,prec,recall,f1))
            # [ , seq_len]
        avg_loss = sum(val_loss) / len(val_loss)
        logger.info('In Validation, Average Loss : {}'.format(avg_loss * 100))
        if avg_loss < self._val_loss:
            logger.info('Updating best model on validation dataset')
            self._val_loss = avg_loss
            self.best_model = deepcopy(self.model)
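The focal_loss helper used above is not shown; since its output is filtered with masked_select and summed, it presumably returns an unreduced, per-position loss over sigmoid probabilities. A minimal sketch under those assumptions (signature and constants are illustrative, not the original implementation):

import torch

def focal_loss(y_true, y_pred, device, gamma=2.0, eps=1e-8):
    # hypothetical element-wise binary focal loss; y_pred holds probabilities in [0, 1]
    y_true = y_true.float().to(device)
    y_pred = y_pred.clamp(eps, 1.0 - eps)
    pt = torch.where(y_true == 1, y_pred, 1.0 - y_pred)  # probability of the true class
    return -((1.0 - pt) ** gamma) * torch.log(pt)        # no reduction: the caller masks and averages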
Example #2
    def train(self,train_data,dev_data,threshold=0.1):
        for epoch in range(self.epoches):
            self.model.train()
            for i,item in enumerate(train_data):
                self.optimizer.zero_grad()
                Qc,Qw,q_mask,Ec,Ew,e_mask,As,Ae = [i.to(self.device) for i in item]
                As_, Ae_ = self.model([Qc,Qw,q_mask,Ec,Ew,e_mask])
                As_loss=focal_loss(As,As_,self.device)
                Ae_loss=focal_loss(Ae,Ae_,self.device)
                # As_loss / Ae_loss shape: [batch_size, max_seq_len_e]
                
                mask=e_mask==1
                loss=(As_loss.masked_select(mask).sum()+Ae_loss.masked_select(mask).sum()) / e_mask.sum()
                loss.backward()
                self.optimizer.step()

                if (i+1)%self.print_step==0 or i==len(train_data)-1:
                    logger.info("In Training : Epoch : {} \t Step / All Step : {} / {} \t Loss of every char : {}"\
                        .format(epoch+1, i+1,len(train_data),loss.item()*100))

                #debug
                # if i==2000:
                #     break
            
            self.model.eval()
            with torch.no_grad():
                self.validate(dev_data)
Example #3
    def __init__(self, batch_norm_mode, depth, model_root_channel=8, img_size=256, batch_size=20, n_channel=1, n_class=2):

        self.drop_rate=tf.placeholder(tf.float32)
        self.training=tf.placeholder(tf.bool)

        self.batch_size=batch_size
        self.model_channel=model_root_channel
        self.batch_mode = batch_norm_mode
        self.depth_n = depth

        self.X=tf.placeholder(tf.float32, [None, img_size, img_size, n_channel], name='X')
        self.Y=tf.placeholder(tf.float32, [None, img_size, img_size, n_class], name='Y')

        self.logits=self.neural_net()

        self.foreground_predicted, self.background_predicted=tf.split(tf.nn.softmax(self.logits), [1, 1], 3)

        self.foreground_truth, self.background_truth=tf.split(self.Y, [1, 1], 3)

        with tf.name_scope('Loss'):
            # # Cross_Entropy
            # self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.Y))

            # # Dice_Loss
            # self.loss=utils.dice_loss(output=self.logits, target=self.Y)

            # # Focal_Loss
            self.loss=utils.focal_loss(output=self.logits, target=self.Y, use_class=True, gamma=2, smooth=1e-8)

        with tf.name_scope('Metrics'):
            self.accuracy=utils.mean_iou(self.foreground_predicted,self.foreground_truth)

        # TensorBoard summaries
        tf.summary.scalar('loss', self.loss)
        tf.summary.scalar('accuracy', self.accuracy)
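utils.focal_loss is only referenced here. Assuming output is raw logits and target the one-hot label map, a TF1-style sketch compatible with the call utils.focal_loss(output=..., target=..., use_class=True, gamma=2, smooth=1e-8) might look like this (an assumption, not the library's actual code):

import tensorflow as tf

def focal_loss(output, target, use_class=True, gamma=2, smooth=1e-8):
    # hypothetical softmax focal loss for TF1 graph mode
    probs = tf.clip_by_value(tf.nn.softmax(output), smooth, 1.0 - smooth)
    modulator = tf.pow(1.0 - probs, gamma)             # down-weight easy, well-classified pixels
    per_class = -target * modulator * tf.log(probs)
    if use_class:
        per_class = tf.reduce_sum(per_class, axis=-1)  # combine the per-class terms first
    return tf.reduce_mean(per_class)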
Example #4
    def setup_loss(self):
        if self.hparams.focal_loss > 0:
            self.gamma = tf.Variable(self.hparams.focal_loss, dtype=tf.float32, trainable=False)
            label_losses = focal_loss(self.target_labels, self.final_logits, self.gamma)
        else:
            label_losses = tf.losses.softmax_cross_entropy(onehot_labels=self.target_labels, logits=self.final_logits, reduction=tf.losses.Reduction.MEAN)
        self.losses = label_losses
Example #5
    def train(self,
              x_train,
              y_train,
              x_valid,
              y_valid,
              learning_rate,
              es_patience,
              es_min_delta,
              reduce_lr_patience,
              reduce_lr_factor,
              batch_size,
              epochs,
              loss_function,
              train_sample_weight=None,
              valid_sample_weight=None):
        """model training"""

        # loss function setting
        if loss_function == 'categorical_crossentropy':
            self.model.compile(loss='categorical_crossentropy',
                               optimizer=Adam(lr=learning_rate),
                               metrics=['accuracy'])
        elif loss_function == 'focal_loss':
            self.model.compile(loss=focal_loss(alpha=2),
                               optimizer=Adam(lr=learning_rate),
                               metrics=['accuracy'])

        # set early stopping criteria and reduce learning rate
        es = EarlyStopping(monitor='val_loss',
                           patience=es_patience,
                           min_delta=es_min_delta,
                           mode='min')
        reduce_lr = ReduceLROnPlateau(patience=reduce_lr_patience,
                                      verbose=1,
                                      factor=reduce_lr_factor)

        # format validation data into a tuple
        if valid_sample_weight is not None:
            validation_data_tuple = (x_valid, y_valid, valid_sample_weight)
        else:
            validation_data_tuple = (x_valid, y_valid)

        # train GCNN model
        self.model.fit(x_train,
                       y_train,
                       batch_size=batch_size,
                       shuffle=True,
                       epochs=epochs,
                       sample_weight=train_sample_weight,
                       validation_data=validation_data_tuple,
                       callbacks=[es, reduce_lr])
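focal_loss(alpha=2) is passed directly to model.compile, so it is presumably a factory returning a Keras-compatible closure; elsewhere in these examples the same model is reloaded with custom_objects={'focal_loss_fixed': focal_loss()}. A hedged sketch consistent with both usages (defaults and math are assumptions):

from tensorflow.keras import backend as K

def focal_loss(alpha=2, gamma=2.0):
    # hypothetical factory: returns a closure usable as a Keras loss
    def focal_loss_fixed(y_true, y_pred):
        eps = K.epsilon()
        y_pred = K.clip(y_pred, eps, 1.0 - eps)
        cross_entropy = -y_true * K.log(y_pred)
        return K.sum(alpha * K.pow(1.0 - y_pred, gamma) * cross_entropy, axis=-1)
    return focal_loss_fixed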
Example #6
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device, dtype=torch.long)
        optimizer.zero_grad()
        output = model(data)
        loss = focal_loss(output, target, gamma=2, ignore_index=255)
        # alternatives: focal_loss(output, target, gamma=2) or F.cross_entropy(output, target, ignore_index=255)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
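Here focal_loss receives raw logits, integer class targets, a gamma and an ignore_index, so it plausibly wraps F.cross_entropy. A minimal sketch under that assumption (the real helper may differ):

import torch
import torch.nn.functional as F

def focal_loss(output, target, gamma=2, ignore_index=255):
    # hypothetical multi-class focal loss built on per-pixel cross-entropy
    ce = F.cross_entropy(output, target, ignore_index=ignore_index, reduction='none')
    pt = torch.exp(-ce)                        # probability assigned to the true class
    focal = ((1.0 - pt) ** gamma) * ce
    valid = (target != ignore_index).float()   # exclude ignored pixels from the average
    return (focal * valid).sum() / valid.sum().clamp(min=1.0)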
Example #7
    def __init__(self, rel_list):
        self.rel_list = rel_list

        self.device = torch.device(
            'cuda:1' if torch.cuda.is_available() else 'cpu')
        self.bert = BertModel.from_pretrained(config.BertPath)

        self.model = CBT(self.bert, len(self.rel_list)).to(self.device)
        self.epoches = 100
        self.lr = 1e-5

        self.best_model = CBT(self.bert, len(self.rel_list)).to(self.device)
        self.best_loss = 1e12
        self.print_step = 15
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        self.loss = lambda y_true, y_pred: focal_loss(y_true, y_pred, self.device)
Example #8
def train(params):
    from utils import reload_data
    sample_rate, seed = 0.7, 2019
    batch_size = params.batch_size
    category = params.category + 2
    interval = int(math.ceil(100. / params.category))
    lr = params.learning_rate
    data_dir, file_ptn = params.dataset, params.source
    dataframes = reload_data(data_dir, file_ptn)
    trainset, testset = train_test_split(dataframes,
                                         train_size=sample_rate,
                                         test_size=1 - sample_rate,
                                         random_state=seed)
    train_gen = preprocessing(trainset,
                              dropout=params.dropout,
                              category=category,
                              interval=interval)
    validation_gen = preprocessing(testset,
                                   is_training=False,
                                   category=category,
                                   interval=interval)
    print(trainset.groupby(["age"])["age"].agg("count"))

    print(testset.groupby(["age"]).agg(["count"]))
    age_dist = [
        trainset["age"][(trainset.age >= x - 10)
                        & (trainset.age <= x)].count()
        for x in range(10, 101, 10)
    ]
    age_dist = [age_dist[0]] + age_dist + [age_dist[-1]]
    print(age_dist)

    if params.pretrain_path and os.path.exists(params.pretrain_path):
        models = load_model(params.pretrain_path,
                            custom_objects={
                                "pool2d": pool2d,
                                "ReLU": ReLU,
                                "BatchNormalization": BatchNormalization,
                                "tf": tf,
                                "focal_loss_fixed": focal_loss(age_dist)
                            })
    else:
        models = build_net(category,
                           using_SE=params.se_net,
                           using_white_norm=params.white_norm)
    adam = Adam(lr=lr)
    #cate_weight = K.variable(params.weight_factor)

    models.compile(
        optimizer=adam,
        loss=["mae", focal_loss(age_dist)],  # "kullback_leibler_divergence"
        metrics={
            "age": "mae",
            "W1": "mae"
        },
        loss_weights=[1, params.weight_factor])
    W2 = models.get_layer("age")

    print(models.summary())

    #thres_callback = ThresCallback(cate_weight, models.get_layer("age_mean_absolute_error"), 10, 10)

    def get_weights(epoch, loggs):
        print(epoch, K.get_value(models.optimizer.lr), W2.get_weights())

    callbacks = [
        ModelCheckpoint(params.save_path,
                        monitor='val_age_mean_absolute_error',
                        verbose=1,
                        save_best_only=True,
                        mode='min'),
        ModelCheckpoint("train_" + params.save_path,
                        monitor='age_mean_absolute_error',
                        verbose=1,
                        save_best_only=True,
                        mode='min'),
        TensorBoard(log_dir=params.log_dir,
                    batch_size=batch_size,
                    write_images=True,
                    update_freq='epoch'),
        #EarlyStopping(monitor='val_age_mean_absolute_error', patience=10, verbose=0, mode='min'),
        #LearningRateScheduler(lambda epoch: lr - 0.0001 * epoch // 10),
        ReduceLROnPlateau(monitor='val_age_mean_absolute_error',
                          factor=0.1,
                          patience=10,
                          min_lr=0.00001),
        LambdaCallback(on_epoch_end=get_weights)
    ]
    history = models.fit_generator(train_gen,
                                   steps_per_epoch=len(trainset) / batch_size,
                                   epochs=160,
                                   callbacks=callbacks,
                                   validation_data=validation_gen,
                                   validation_steps=len(testset) / batch_size * 3)
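In this example the factory receives per-bin counts (age_dist) rather than a scalar alpha, and the resulting closure is reloaded under the name focal_loss_fixed, which suggests the class weights are derived from the age distribution. A hedged sketch of such a factory (the weighting scheme is an assumption):

import numpy as np
from tensorflow.keras import backend as K

def focal_loss(class_counts, gamma=2.0):
    # hypothetical: rarer age bins receive larger alpha weights
    counts = np.asarray(class_counts, dtype='float32') + 1.0
    alpha = K.constant((1.0 / counts) / np.sum(1.0 / counts))
    def focal_loss_fixed(y_true, y_pred):
        eps = K.epsilon()
        y_pred = K.clip(y_pred, eps, 1.0 - eps)
        return K.sum(-alpha * y_true * K.pow(1.0 - y_pred, gamma) * K.log(y_pred), axis=-1)
    return focal_loss_fixed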
Example #9
def train(params):
    from utils import reload_data
    if params.fp16:
        os.environ['TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE'] = '1'

    sample_rate, seed, batch_size, category, interval = 0.8, 2019, params.batch_size, params.category + 2, int(math.ceil(100. / params.category))
    lr = params.learning_rate
    data_dir, file_ptn = params.dataset, params.source
    dataframes = reload_data(data_dir, file_ptn)
    trainset, testset = train_test_split(dataframes, train_size=sample_rate, test_size=1-sample_rate, random_state=seed)
    print(testset.gender)
    train_gen = preprocessing(trainset, dropout=params.dropout, category=category, interval=interval)
    validation_gen = preprocessing(testset, dropout=0, is_training=False, category=category, interval=interval)
    print(trainset.groupby(["age"])["age"].agg("count"))

    print(testset.groupby(["age"]).agg(["count"]))
    age_dist = [trainset["age"][(trainset.age >= x -10) & (trainset.age <= x)].count() for x in range(10, 10 * params.category + 1, 10)]
    age_dist = [age_dist[0]] + age_dist + [age_dist[-1]]
    gender_dist = [trainset["gender"][trainset.gender == 0].count(), trainset["gender"][trainset.gender == 1].count()]
    print(age_dist, gender_dist)

    models = build_net3(category, using_SE=params.se_net, using_white_norm=params.white_norm)
    if params.pretrain_path:
        #models = load_model(params.pretrain_path, custom_objects={"pool2d": pool2d, "ReLU": ReLU, "BatchNormalization": BatchNormalization, "tf": tf, "focal_loss_fixed": focal_loss(age_dist)})
        ret = models.load_weights(params.pretrain_path)
        model_refresh_without_nan(models)

    #optim = SGD(lr=lr, momentum=0.9)
    if params.freeze:
        converter = tf.lite.TFLiteConverter.from_keras_model(models)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        #converter.target_spec.supported_types = [tf.compat.v1.lite.constants.FLOAT16]
        tflite_model = converter.convert()
        open("profile.tflite", "wb").write(tflite_model) 
        return


    epoch_nums = 160
    optim = Adam(lr=lr)
    if params.fp16:
        optim = tf.train.experimental.enable_mixed_precision_graph_rewrite(optim)

    print("-----outputs-----", models.outputs)

    models.compile(
        optimizer=optim,
        loss=["mae", focal_loss(age_dist), "categorical_crossentropy"],  # "kullback_leibler_divergence"
        #loss=["mae", focal_loss(age_dist), focal_loss(gender_dist)],  # "kullback_leibler_divergence"
        metrics={"age": "mae", "gender": "acc", "W1": "mae"},
        loss_weights=[1, 10, 30],
        #loss_weights=[1, 10, 100],
        experimental_run_tf_function=False
    )
    W2 = models.get_layer("age")

    callbacks = [
        ModelCheckpoint(params.save_path, monitor='val_age_mae', verbose=1, save_best_only=True, save_weights_only=True, mode='min'),
        ModelCheckpoint(params.save_path, monitor='val_gender_acc', verbose=1, save_best_only=True, save_weights_only=True),
        TensorBoard(log_dir=params.log_dir, batch_size=batch_size, write_images=True, update_freq='epoch'),
        #ReduceLROnPlateau(monitor='val_age_mae', factor=0.1, patience=10, min_lr=0.00001),
        CosineAnnealingScheduler(epoch_nums, lr, lr / 100)
    ]
    if not params.test:
        history = models.fit(train_gen, steps_per_epoch=len(trainset) / batch_size, epochs=epoch_nums, callbacks=callbacks, validation_data=validation_gen, validation_steps=len(testset) / batch_size, workers=1)
    else:
        models.evaluate(validation_gen, steps=len(testset) / batch_size)
Example #10
def train(args):
    sub_dir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(args.log_dir, sub_dir)
    model_dir = os.path.join(args.model_dir, sub_dir)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    train_logger = utils.Logger('train', file_name=os.path.join(
        log_dir, 'train.log'), control_log=False)
    test_logger = utils.Logger(
        'test', file_name=os.path.join(log_dir, 'test.log'))

    utils.save_arguments(args, os.path.join(log_dir, 'arguments.txt'))

    # data
    split_dataset(args.dataset_path, args.train_ratio, args.pos_ratio)
    base_dir = os.path.dirname(args.dataset_path)
    train_dataset = load_data(os.path.join(base_dir, 'train.txt'))
    test_dataset = load_data(os.path.join(base_dir, 'test.txt'))

    dataset_size = train_dataset.num_examples
    train_logger.info('dataset size: %s' % dataset_size)

    tf.reset_default_graph()
    with tf.Graph().as_default(), tf.device('/gpu:3'):
        tf.set_random_seed(10)
        x = tf.placeholder(tf.float32, shape=[None, args.seq_len, 1], name='input')
        y = tf.placeholder(tf.int64, shape=[None], name='label')
        one_hot_y = tf.one_hot(y, depth=2, dtype=tf.int64)
        is_training = tf.placeholder(tf.bool, name='training')

        net = Network(args.seq_len, args.tcn_channels, args.tcn_kernel_size, 
            args.tcn_dropout, args.embedding_size, args.num_classes, args.weight_decay)
        prelogits, logits, embeddings = net(x, is_training)

        # accuracy
        with tf.variable_scope('metrics'):
            tpr_op, fpr_op, g_mean_op, accuracy_op, f1_op = calc_accuracy(logits, y)

        # loss
        with tf.variable_scope('loss'):
            focal_loss_op = utils.focal_loss(y, logits, 5.0)
            cross_entropy_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
                logits=logits, labels=one_hot_y), name='cross_entropy')
            center_loss_op, centers, centers_update_op = utils.center_loss(
                prelogits, y, args.num_classes, args.center_loss_alpha)
            regularization_loss_op = tf.reduce_sum(tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES), name='l2_loss')
            loss_op = center_loss_op * args.center_loss_factor + \
                cross_entropy_op + regularization_loss_op  # + L_loss_op * args.mwdn_loss_factor[0] + H_loss_op * args.mwdn_loss_factor[1]

        # optimizer
        with tf.variable_scope('optimizer'), tf.control_dependencies([centers_update_op]):
            global_step = tf.Variable(0, trainable=False, name='global_step')
            # optimizer_op = tf.train.MomentumOptimizer(learning_rate_op, 0.9, name='optimizer')
            optimizer = tf.train.AdamOptimizer(args.lr, name='optimizer')
            train_op = optimizer.minimize(loss_op, global_step)

        # summary
        tf.summary.scalar('total_loss', loss_op)
        tf.summary.scalar('l2_loss', regularization_loss_op)
        tf.summary.scalar('cross_entropy', cross_entropy_op)
        tf.summary.scalar('center_loss', center_loss_op)
        tf.summary.scalar('focal_loss', focal_loss_op)
        #tf.summary.scalar('l_loss', L_loss_op)
        #tf.summary.scalar('h_loss', H_loss_op)
        tf.summary.scalar('accuracy', accuracy_op)
        tf.summary.scalar('tpr', tpr_op)
        tf.summary.scalar('fpr', fpr_op)
        tf.summary.scalar('g_mean', g_mean_op)
        tf.summary.scalar('f1', f1_op)
        summary_op = tf.summary.merge_all()

        saver = tf.train.Saver(max_to_keep=100)
        config = tf.ConfigProto(allow_soft_placement=True,
                                gpu_options=tf.GPUOptions(allow_growth=True))
        with tf.Session(config=config) as sess:
            writer = tf.summary.FileWriter(log_dir, sess.graph)
            tf.global_variables_initializer().run()

            if args.pretrained_model:
                ckpt = tf.train.get_checkpoint_state(args.pretrained_model)
                saver.restore(sess, ckpt.model_checkpoint_path)

            steps_per_epoch = np.ceil(dataset_size / args.batch_size).astype(int)
            batch_num_seq = range(steps_per_epoch)
            best_test_accuracy = 0.0
            best_test_fpr = 1.0
            try:
                for epoch in range(1, args.max_epochs+1):
                    batch_num_seq = tqdm(
                        batch_num_seq, desc='Epoch: {:d}'.format(epoch), ascii=True)
                    for step in batch_num_seq:
                        feature, label = train_dataset.next_batch(args.batch_size)
                        feature = np.reshape(feature, (args.batch_size, args.seq_len, 1))
                        if step % args.display == 0:
                            tensor_list = [train_op, summary_op, global_step, accuracy_op, tpr_op,
                                fpr_op, g_mean_op, cross_entropy_op, center_loss_op, focal_loss_op]
                            _, summary, train_step, accuracy, tpr, fpr, g_mean, cross_entropy, center_loss, focal_loss = sess.run(
                                tensor_list, feed_dict={x: feature, y: label, is_training: True})
                            train_logger.info('Train Step: %d, accuracy: %.3f%%, tpr: %.3f%%, fpr: %.3f%%, g_mean: %.3f%%, cross_entropy: %.4f, center_loss: %.4f, focal_loss: %.4f'
                                % (train_step, accuracy*100, tpr*100, fpr*100, g_mean*100, cross_entropy, center_loss, focal_loss))
                        else:
                            _, summary, train_step = sess.run([train_op, summary_op, global_step], feed_dict={
                                                              x: feature, y: label, is_training: True})
                        writer.add_summary(summary, global_step=train_step)
                        writer.flush()

                    # evaluate
                    num_batches = int(np.ceil(test_dataset.num_examples / args.batch_size))
                    accuracy_array = np.zeros((num_batches,), np.float32)
                    tpr_array = np.zeros((num_batches,), np.float32)
                    fpr_array = np.zeros((num_batches,), np.float32)
                    g_mean_array = np.zeros((num_batches,), np.float32)
                    for i in range(num_batches):
                        feature, label = test_dataset.next_batch(args.batch_size)
                        feature = np.reshape(feature, (args.batch_size, args.seq_len, 1))
                        tensor_list = [accuracy_op, tpr_op, fpr_op, g_mean_op]
                        feed_dict = {x: feature, y: label, is_training: False}
                        accuracy_array[i], tpr_array[i], fpr_array[i], g_mean_array[i] = sess.run(
                            tensor_list, feed_dict=feed_dict)
                    test_logger.info('Validation Epoch: %d, train_step: %d, accuracy: %.3f%%, tpr: %.3f%%, fpr: %.3f%%, g_mean: %.3f%%'
                        % (epoch, train_step, np.mean(accuracy_array)*100, np.mean(tpr_array)*100, np.mean(fpr_array)*100, np.mean(g_mean_array)*100))

                    test_accuracy = np.mean(accuracy_array)
                    test_fpr = np.mean(fpr_array)
                    if test_accuracy > best_test_accuracy:
                        best_test_accuracy = test_accuracy
                        saver.save(sess, os.path.join(model_dir, 'arc_fault'),
                                   global_step=train_step)
                    elif test_accuracy == best_test_accuracy and test_fpr < best_test_fpr:
                        best_test_fpr = test_fpr
                        saver.save(sess, os.path.join(model_dir, 'arc_fault'),
                                   global_step=train_step)

            except Exception as e:
                train_logger.error(e)
            writer.close()
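The arc-fault trainer above computes utils.focal_loss(y, logits, 5.0) from sparse integer labels and a scalar gamma, although only the center, cross-entropy and L2 terms are added to loss_op while the focal term is just summarized. A TF1 sketch matching that call signature, assuming y holds class indices (not the library's actual code):

import tensorflow as tf

def focal_loss(labels, logits, gamma=2.0):
    # hypothetical sparse-label focal loss for TF1 graph mode
    num_classes = logits.get_shape().as_list()[-1]
    onehot = tf.one_hot(labels, depth=num_classes, dtype=tf.float32)
    probs = tf.clip_by_value(tf.nn.softmax(logits), 1e-8, 1.0)
    pt = tf.reduce_sum(onehot * probs, axis=-1)        # probability of the true class
    return tf.reduce_mean(-tf.pow(1.0 - pt, gamma) * tf.log(pt))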
Example #11
def train_model(train_subjs,
                annots,
                all_images,
                label_mappings,
                save_path,
                out_type,
                weights=None,
                finetune_path=None,
                lr=0.001,
                lr_factor=None,
                a=0.5,
                k=1,
                num_epochs=100,
                batch_size=32,
                patience=10,
                augment=None,
                dropout_prob=None,
                labeling='hard',
                focal_params=None):

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    fold_subjs, num_ims_fold = utils.fold_split(annots, train_subjs, k)
    phases = ['train', 'val'] if k > 1 else ['train']

    datasets = {}
    data_loaders = {}

    appear = (out_type in ['appear', 'both'])
    grade = (out_type in ['grade', 'both'])

    # don't weight losses if not running multi-task classification
    if out_type == 'appear': a = 1
    elif out_type == 'grade': a = 0

    # metrics to track
    logger = Metric_Logger(appear, grade)

    # for each fold in k-fold CV
    for fold in range(k):

        print('\nStarting fold', fold + 1)

        logger.init_fold(fold)
        fold_path = os.path.join(save_path, 'fold_' + str(fold + 1))
        if not os.path.exists(fold_path): os.mkdir(fold_path)
        num_epochs_no_imp = 0
        stop_flag = False  # whether to terminate early
        imp_on_lr = False  # whether any improvement was seen with the current LR

        # create data loaders
        datasets['train'] = Dataset(np.concatenate(np.delete(fold_subjs,
                                                             fold)),
                                    annots,
                                    all_images,
                                    label_mappings,
                                    augment=augment,
                                    target_type=labeling)
        data_loaders['train'] = torch.utils.data.DataLoader(
            datasets['train'], batch_size=batch_size, shuffle=True)
        if k > 1:
            datasets['val'] = Dataset(fold_subjs[fold], annots, all_images,
                                      label_mappings)
            data_loaders['val'] = torch.utils.data.DataLoader(
                datasets['val'], batch_size=batch_size, shuffle=False)

        # setup model and optimizer
        if finetune_path:
            finetune_fold_path = glob.glob(
                os.path.join(finetune_path, 'fold_' + str(fold + 1),
                             '*best*'))[0]
            checkpoint = torch.load(finetune_fold_path)
            model = Network(dropout_prob=dropout_prob)
            model.load_state_dict(checkpoint['state_dict'])
            model.unfreeze()
            trainable_params = [
                param for param in model.parameters() if param.requires_grad
            ]
            if appear: trainable_params += list(model.appear_fc.parameters())
            if grade: trainable_params += list(model.grade_fc.parameters())
            optimizer = torch.optim.Adam(trainable_params, lr=1e-4)
            best_val_loss = utils.get_best_val_loss(finetune_path, fold + 1)
        else:
            model = Network(weights=weights, dropout_prob=dropout_prob)
            model.freeze()
            trainable_params = []
            if appear: trainable_params += list(model.appear_fc.parameters())
            if grade: trainable_params += list(model.grade_fc.parameters())
            optimizer = torch.optim.Adam(trainable_params, lr=lr)
            best_val_loss = float('inf')
        model = model.to(device)

        # iterate over epochs
        for epoch in range(num_epochs):

            logger.init_epoch(epoch)

            for phase in phases:

                # set to train or eval mode
                if phase == 'train':
                    model.train()
                    #model.apply(set_bn_eval) # batchnorm in eval mode
                else:
                    model.eval()

                # make buffers for output probabilities and targets, for epoch metric calculations
                logger.init_phase(phase, len(datasets[phase]))

                # iterate over batches
                for batch in data_loaders[phase]:
                    images, appear_targets, appear_probs, grade_targets, grade_probs = batch

                    images = images.to(device)
                    if appear:
                        appear_targets = appear_targets.to(device)
                        appear_probs = appear_probs.to(device)
                    if grade:
                        grade_targets = grade_targets.to(device)
                        grade_probs = grade_probs.to(device)

                    optimizer.zero_grad()

                    # inference and gradient step
                    with torch.set_grad_enabled(phase == 'train'):
                        appear_out, grade_out = model(images, out_type)

                        if labeling in ['hard', 'sample']:
                            if focal_params:
                                loss_appear = utils.focal_loss(
                                    appear_out, appear_targets, focal_params,
                                    device, 'appear') if appear else 0
                                loss_grade = utils.focal_loss(
                                    grade_out, grade_targets, focal_params,
                                    device, 'grade') if grade else 0
                            else:
                                loss_appear = torch.nn.functional.cross_entropy(
                                    appear_out,
                                    appear_targets) if appear else 0
                                loss_grade = torch.nn.functional.cross_entropy(
                                    grade_out, grade_targets,
                                    ignore_index=-1) if grade else 0
                        elif labeling == 'soft':
                            loss_appear = utils.CE_loss_distr(
                                appear_out, appear_probs) if appear else 0
                            loss_grade = utils.CE_loss_distr(
                                grade_out, grade_probs,
                                ignore_index=-1) if grade else 0
                        loss = a * loss_appear + (1 - a) * loss_grade
                        logger.save_losses(loss_appear, loss_grade, loss)

                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    # store batch metrics
                    logger.batch_metrics(appear_out, appear_targets, grade_out,
                                         grade_targets)

                # store epoch metrics and get average total loss for phase epoch
                avg_loss = logger.phase_metrics()

                # model saving, LR stepping, and early stopping
                if phase == 'val':
                    state = {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }
                    if epoch > 0:
                        os.remove(
                            glob.glob(os.path.join(fold_path,
                                                   'model_latest*'))[0])
                    torch.save(
                        state,
                        os.path.join(
                            fold_path,
                            'model_latest_ep{}_loss{:.2f}.pt'.format(
                                epoch + 1, avg_loss)))
                    if avg_loss < best_val_loss:
                        logger.save_best(fold_path)
                        prev_best = glob.glob(
                            os.path.join(fold_path, 'model_best*'))
                        if prev_best: os.remove(prev_best[0])
                        torch.save(
                            state,
                            os.path.join(
                                fold_path,
                                'model_best_ep{}_loss{:.2f}.pt'.format(
                                    epoch + 1, avg_loss)))
                        best_val_loss = avg_loss
                        num_epochs_no_imp = 0
                        imp_on_lr = True
                    else:
                        num_epochs_no_imp += 1
                        if num_epochs_no_imp >= patience:
                            if imp_on_lr and lr_factor:
                                for g in optimizer.param_groups:
                                    g['lr'] *= lr_factor
                                num_epochs_no_imp = 0
                                imp_on_lr = False
                                print("\nSTEPPING DOWN LR")
                            else:
                                stop_flag = True

            # save and print metrics
            logger.epoch_metrics(fold_path)
            if stop_flag:
                print("STOPPING EARLY\n")
                break
def main(_):
    """main program"""

    main_start = time.time()
    # set random seed to get reproducible results
    seed(3000)
    set_random_seed(3000)

    # disable TensorFlow default messages
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    tf.logging.set_verbosity(tf.logging.ERROR)

    # validate the flags
    validate_flags_or_throw(FLAGS)

    # user defined dictionary for better Chinese token segmentation
    if FLAGS.jieba_user_dict is not None:
        jieba.load_userdict(FLAGS.jieba_user_dict)

    # create directory if it doesn't exist
    if not FLAGS.do_single_predict:
        if not os.path.isdir(FLAGS.output_dir):
            os.mkdir(FLAGS.output_dir)

    if FLAGS.do_word2vec:

        word2vec_flags = {
            "sg": FLAGS.word2vec_sg,
            "size": FLAGS.word2vec_size,
            "window": FLAGS.word2vec_window,
            "min_count": FLAGS.word2vec_min_count,
            "number_of_threads": FLAGS.num_threads,
            "iterations": FLAGS.word2vec_iterations,
        }

        with open(os.path.join(FLAGS.output_dir, "WORD2VEC_FLAGS.json"),
                  'w') as json_file:
            json.dump(word2vec_flags, json_file, indent=4)

        train_corpus = load_data_file(FLAGS.train_word2vec_file, FLAGS)
        print("Cleaning training corpus...")
        train_corpus["TEXT"] = text_cleaning(
            text_array=train_corpus["TEXT"],
            bert_vocab_file=FLAGS.bert_vocab_file,
            do_lower_case=False,
            word_tokenization=FLAGS.word_tokenization,
            num_threads=FLAGS.num_threads)

        word2vec_model, df_vectors = train_word2vec(
            text_array=train_corpus['TEXT'],
            sg=FLAGS.word2vec_sg,
            size=FLAGS.word2vec_size,
            window=FLAGS.word2vec_window,
            min_count=FLAGS.word2vec_min_count,
            num_threads=FLAGS.num_threads,
            iterations=FLAGS.word2vec_iterations)

        print('Saving model and embedding vectors...')
        word2vec_model.save(os.path.join(FLAGS.output_dir, 'Word2vec.model'))
        feather.write_dataframe(
            df_vectors,
            os.path.join(FLAGS.output_dir, 'EmbeddingVectors.feather'))

    else:

        emb_matrix = None
        pretrained_model = None

        # list of labels and several associate dictionaries
        class_table = pd.read_csv(FLAGS.class_file,
                                  dtype={
                                      'LABEL': str,
                                      'DESCRIPTION': str
                                  })
        labels = list(class_table["LABEL"])
        label_to_desc_map = dict(
            zip(class_table["LABEL"], class_table["DESCRIPTION"]))
        index_to_label_map = dict(enumerate(labels))
        if FLAGS.train_with_weights:
            if "WEIGHT" not in class_table.columns:
                raise ValueError(
                    "The column `WEIGHT` is missing in the {}.".format(
                        FLAGS.class_file))
            else:
                label_to_weight_map = dict(
                    zip(class_table["LABEL"], class_table["WEIGHT"]))

        if FLAGS.do_train:
            # embedding matrix
            emb_matrix = feather.read_dataframe(FLAGS.emb_matrix_file)
            vocab = list(emb_matrix.columns)
            vocab_dict = dict(zip(vocab, range(2, len(vocab) + 2)))

            emb_matrix = emb_matrix.T.values  # transpose the matrix then obtain the values
            emb_matrix = np.insert(
                emb_matrix,
                0,
                np.random.normal(size=[1, emb_matrix.shape[1]]),
                axis=0)  # embedding vectors for oov symbol
            emb_matrix = np.insert(
                emb_matrix,
                0,
                np.random.normal(size=[1, emb_matrix.shape[1]]),
                axis=0)  # embedding vectors for padding symbol

        else:
            # available in "do_eval", "do_predict", "do_single_predict" mode
            with open(FLAGS.emb_layer_vocab, "r") as f:
                vocab = [line.rstrip('\n') for line in f]
                vocab_dict = dict(zip(vocab, range(2, len(vocab) + 2)))

            with open(FLAGS.keras_model_flags, "r") as f:
                keras_model_flags = json.load(f)

            # bert_vocab_file is needed if word piece tokenization is used in training the model
            if keras_model_flags[
                    "word_tokenization"] == "word_piece" and FLAGS.bert_vocab_file is None:
                raise ValueError(
                    "The model uses Word Piece tokenization but `bert_vocab_file` is not supplied."
                )

            print("\nLoading trained model...")
            if keras_model_flags[
                    "loss_function"] == 'categorical_crossentropy':
                pretrained_model = load_model(FLAGS.keras_model_file)
            elif keras_model_flags["loss_function"] == 'focal_loss':
                pretrained_model = load_model(FLAGS.keras_model_file,
                                              custom_objects={
                                                  'focal_loss(alpha=2)':
                                                  focal_loss(alpha=2),
                                                  'focal_loss_fixed':
                                                  focal_loss()
                                              })

        # instantiate a GCNN model object
        model = GCNN(num_classes=len(labels),
                     max_len=FLAGS.max_len,
                     emb_matrix=emb_matrix,
                     num_filter_1=FLAGS.num_filter_1,
                     kernel_size_1=FLAGS.kernel_size_1,
                     num_stride_1=FLAGS.num_stride_1,
                     num_filter_2=FLAGS.num_filter_2,
                     kernel_size_2=FLAGS.kernel_size_2,
                     num_stride_2=FLAGS.num_stride_2,
                     hidden_dims_1=FLAGS.hidden_dims_1,
                     hidden_dims_2=FLAGS.hidden_dims_2,
                     hidden_dims_3=FLAGS.hidden_dims_3,
                     hidden_dims_4=FLAGS.hidden_dims_4,
                     dropout_rate=FLAGS.dropout_rate,
                     pretrained_model=pretrained_model)

    if FLAGS.do_train:

        # save flags and hyperparameters to a json file
        print("Saving model flags and hyperparameters to {}".format(
            os.path.join(FLAGS.output_dir, FLAGS.model_name + "_FLAGS.json")))

        neural_network_flags = {
            "model_name": FLAGS.model_name,
            "num_classes": len(labels),
            "max_len": FLAGS.max_len,
            "batch_size": FLAGS.batch_size,
            "num_filter_1": FLAGS.num_filter_1,
            "kernel_size_1": FLAGS.kernel_size_1,
            "num_stride_1": FLAGS.num_stride_1,
            "num_filter_2": FLAGS.num_filter_2,
            "kernel_size_2": FLAGS.kernel_size_2,
            "num_stride_2": FLAGS.num_stride_2,
            "hidden_dims_1": FLAGS.hidden_dims_1,
            "hidden_dims_2": FLAGS.hidden_dims_2,
            "hidden_dims_3": FLAGS.hidden_dims_3,
            "hidden_dims_4": FLAGS.hidden_dims_4,
            "epochs": FLAGS.epochs,
            "learning_rate": FLAGS.learning_rate,
            "es_patience": FLAGS.es_patience,
            "es_min_delta": FLAGS.es_min_delta,
            "reduce_lr_patience": FLAGS.reduce_lr_patience,
            "reduce_lr_factor": FLAGS.reduce_lr_factor,
            "dropout_rate": FLAGS.dropout_rate,
            "loss_function": FLAGS.loss_function,
            "train_with_weights": FLAGS.train_with_weights,
            "word_tokenization": FLAGS.word_tokenization
        }

        with open(
                os.path.join(FLAGS.output_dir,
                             FLAGS.model_name + "_flags.json"),
                'w') as json_file:
            json.dump(neural_network_flags, json_file, indent=4)

        with open(
                os.path.join(FLAGS.output_dir,
                             FLAGS.model_name + "_vocab.txt"), "w") as f:
            f.writelines('\n'.join(vocab))

        # import data and text tokenization
        print("Importing data...")
        train = load_data_file(os.path.join(FLAGS.train_file), FLAGS)
        valid = load_data_file(os.path.join(FLAGS.valid_file), FLAGS)

        # separate records with valid and invalid labels into 2 dataframes
        train, invalid_label_train = label_validation(
            train, "LABEL", labels, num_threads=FLAGS.num_threads)
        valid, invalid_label_valid = label_validation(
            valid, "LABEL", labels, num_threads=FLAGS.num_threads)

        # save the records with invalid labels
        if invalid_label_train.shape[0] > 0:
            print("{} train samples have invalid labels.".format(
                invalid_label_train.shape[0]))
            feather.write_dataframe(
                invalid_label_train,
                os.path.join(FLAGS.output_dir, "invalid_label_train.feather"))
        if invalid_label_valid.shape[0] > 0:
            print("{} validation samples have invalid labels.".format(
                invalid_label_train.shape[0]))
            feather.write_dataframe(
                invalid_label_valid,
                os.path.join(FLAGS.output_dir, "invalid_label_valid.feather"))

        print("Cleaning TRAINING data...")
        train["TEXT"] = text_cleaning(
            text_array=train["TEXT"],
            bert_vocab_file=FLAGS.bert_vocab_file,
            do_lower_case=False,
            word_tokenization=FLAGS.word_tokenization,
            num_threads=FLAGS.num_threads)

        print("Cleaning VALIDATION data...")
        valid["TEXT"] = text_cleaning(
            text_array=valid["TEXT"],
            bert_vocab_file=FLAGS.bert_vocab_file,
            do_lower_case=False,
            word_tokenization=FLAGS.word_tokenization,
            num_threads=FLAGS.num_threads)

        # Convert tokenized text to numerical matrices for Neural Network training
        print("Transforming data...")
        x_train, y_train, weight_train, informative_train, \
        uninformative_train = get_wxy(des_lists=train["TEXT"],
                                      label_lists=train["LABEL"],
                                      vocab_dict=vocab_dict,
                                      max_len=FLAGS.max_len,
                                      num_threads=FLAGS.num_threads,
                                      labels=labels,
                                      label_to_weight_map=label_to_weight_map)

        x_valid, y_valid, weight_valid, informative_valid, \
        uninformative_valid = get_wxy(des_lists=valid["TEXT"],
                                      label_lists=valid["LABEL"],
                                      vocab_dict=vocab_dict,
                                      max_len=FLAGS.max_len,
                                      num_threads=FLAGS.num_threads,
                                      labels=labels,
                                      label_to_weight_map=label_to_weight_map)

        # save records with uninformative description
        if sum(uninformative_train) > 0:
            print(
                "{} train samples are considered uninformative and are not included in model training"
                .format(sum(uninformative_train)))
            feather.write_dataframe(
                train.loc[uninformative_train, :],
                os.path.join(FLAGS.output_dir, "uninformative_train.feather"))

        if sum(uninformative_valid) > 0:
            print(
                "{} validation samples are considered uninformative and are not included in model training"
                .format(sum(uninformative_valid)))
            feather.write_dataframe(
                valid.loc[uninformative_valid, :],
                os.path.join(FLAGS.output_dir, "uninformative_valid.feather"))

        print("Building Neural Network...")
        # Build the GCNN model
        model.build()
        print("Structure of the Neural Network:")
        model.show_summary()
        print("Training begins...")
        model.train(x_train, y_train, x_valid, y_valid, FLAGS.learning_rate,
                    FLAGS.es_patience, FLAGS.es_min_delta,
                    FLAGS.reduce_lr_patience, FLAGS.reduce_lr_factor,
                    FLAGS.batch_size, FLAGS.epochs, FLAGS.loss_function,
                    weight_train, weight_valid)

        # Evaluating model performance on training and validation data
        print("Evaluating model performance...")
        print("TRAINING DATA:")
        model.evaluate(x_train, y_train)
        print("VALIDATION DATA:")
        model.evaluate(x_valid, y_valid)

        # save GCNN model
        print("Training completed! Saving model to {}".format(
            os.path.join(FLAGS.output_dir, FLAGS.model_name + ".h5")))
        model.save(os.path.join(FLAGS.output_dir,
                                FLAGS.model_name + ".h5"))  # save model

    if FLAGS.do_eval:
        print("Begin evaluating model performance...")

        # import data and tokenize text
        print("Importing data...")
        data_eval = load_data_file(FLAGS.eval_file, FLAGS)

        # separate records with valid and invalid labels into 2 dataframes
        data_eval, invalid_label_data_eval = label_validation(
            data_eval, "LABEL", labels, num_threads=FLAGS.num_threads)

        # save the records with invalid labels
        if invalid_label_data_eval.shape[0] > 0:
            print("{} evaluation samples have invalid labels.".format(
                invalid_label_data_eval.shape[0]))
            feather.write_dataframe(
                invalid_label_data_eval,
                os.path.join(FLAGS.output_dir, "invalid_label_eval.feather"))

        print("Cleaning data...")
        cleaned_text_array = text_cleaning(
            text_array=data_eval["TEXT"],
            bert_vocab_file=FLAGS.bert_vocab_file,
            do_lower_case=False,
            word_tokenization=keras_model_flags["word_tokenization"],
            num_threads=FLAGS.num_threads)

        # import data and text tokenization
        print("Transforming data...")
        x_eval, y_eval, _, informative_eval, uninformative_eval = get_wxy(
            des_lists=cleaned_text_array,
            label_lists=data_eval["LABEL"],
            vocab_dict=vocab_dict,
            max_len=keras_model_flags["max_len"],
            labels=labels,
            num_threads=FLAGS.num_threads)

        # save records with uninformative description
        if sum(uninformative_eval) > 0:
            print(
                "{} evaluation samples are considered uninformative and are not included in model evaluation."
                .format(sum(uninformative_eval)))
            feather.write_dataframe(
                data_eval.loc[uninformative_eval, :],
                os.path.join(FLAGS.output_dir, "uninformative_eval.feather"))

        print("Prediction in progress...")
        pred_probits_eval, pred_labels_eval = model.predict(
            x_eval, index_to_label_map)
        print("Prediction completed.")

        data_eval = data_eval.loc[informative_eval, :]
        data_eval["PREDICTION"] = pred_labels_eval
        data_eval["PRED_PROBIT"] = np.max(pred_probits_eval, axis=1)
        data_eval = data_eval.join(
            other=pd.DataFrame(pred_probits_eval, columns=labels))

        feather.write_dataframe(
            data_eval, os.path.join(FLAGS.output_dir, "evaluation.feather"))

        # compute precision, recall and f-1 score
        data_report = classification_report(data_eval["LABEL"],
                                            pred_labels_eval,
                                            output_dict=True)
        data_report = pd.DataFrame(data_report).transpose()
        data_report.to_csv(os.path.join(FLAGS.output_dir, "eval_metrics.csv"))

        # compute distribution and classification rates of the highest prediction probability
        distr_cr = get_cr_distr(data_eval,
                                col_label="LABEL",
                                col_pred="PREDICTION",
                                col_probit="PRED_PROBIT")
        distr_cr.to_csv(os.path.join(FLAGS.output_dir,
                                     "distribution_stat.csv"))

        print("\nPerformance Metrics:")
        print(data_report)
        print("\nDistribution of classification rate: ")
        print(distr_cr)

    if FLAGS.do_predict:
        print("Prediction on input dataset...")

        # import data and tokenize text
        print("Importing data...")
        data_prediction = load_data_file(FLAGS.predict_file, FLAGS)

        print("Cleaning input data...")
        cleaned_text_array = text_cleaning(
            text_array=data_prediction["TEXT"],
            bert_vocab_file=FLAGS.bert_vocab_file,
            do_lower_case=False,
            word_tokenization=keras_model_flags["word_tokenization"],
            num_threads=FLAGS.num_threads)

        # Convert tokenized text to numerical matrices for Neural Network training
        print("Transforming data...")
        x_test, _, _, informative_test, uninformative_predict = get_wxy(
            des_lists=cleaned_text_array,
            vocab_dict=vocab_dict,
            max_len=keras_model_flags["max_len"],
            labels=labels,
            num_threads=FLAGS.num_threads)

        if sum(uninformative_predict) > 0:
            print(
                "{} prediction samples are considered uninformative and are not included in prediction."
                .format(sum(uninformative_predict)))

            feather.write_dataframe(
                data_prediction.loc[uninformative_predict, :],
                os.path.join(FLAGS.output_dir,
                             "uninformative_predict.feather"))

        # prediction
        print("Prediction in progress...")
        pred_probits, pred_labels = model.predict(x_test, index_to_label_map)
        print("Prediction completed.")

        data_prediction = data_prediction.loc[informative_test, :]
        data_prediction["PREDICTION"] = pred_labels
        data_prediction["PRED_PROBIT"] = np.max(pred_probits, axis=1)
        data_prediction = data_prediction.join(
            other=pd.DataFrame(pred_probits, columns=labels))

        feather.write_dataframe(
            data_prediction,
            os.path.join(FLAGS.output_dir, "prediction.feather"))

    if FLAGS.do_single_predict:

        print("Time taken for loading the model: {:.2f}".format(time.time() -
                                                                main_start))
        do_next = True  # whether to run another prediction

        while do_next:  # while loop to enable users to do multiple queries
            input_text = input("\nInput text here: ")
            start = time.time()
            input_text = text_cleaning(
                text_array=[input_text],
                bert_vocab_file=FLAGS.bert_vocab_file,
                do_lower_case=False,
                word_tokenization=keras_model_flags["word_tokenization"],
                num_threads=FLAGS.num_threads)
            x_data, _, _, _, uninformative_data = get_wxy(
                des_lists=input_text,
                vocab_dict=vocab_dict,
                max_len=keras_model_flags["max_len"],
                labels=labels)

            if uninformative_data[0]:
                print(
                    "\nThe text is either uninformative or never learnt by model before..."
                )
            else:
                print("Prediction in progress...")
                pred_probits, pred_labels = model.predict(
                    x_data, index_to_label_map)

                # sort the pairs of label and prediction probability in descending order
                output = dict(zip(labels, pred_probits[0]))
                sorted_output = sorted(output.items(),
                                       key=lambda kv: kv[1],
                                       reverse=True)

                print("\nPREDICTIONS:")
                # print the top 3 predictions and probabilities
                index = 1
                for (code, probit) in sorted_output[0:3]:
                    code_name = label_to_desc_map[code]
                    print(str(index) + ". " + code_name)
                    print("   {:.2%}".format(probit))
                    index += 1

            print("Time taken for prediction: {:.2f}".format(time.time() -
                                                             start))
            del input_text

            # ask user whether he/she would like to continue prediction
            reply = input("\nDo you want to start a new prediction (Y/N)? ")
            while reply != "Y" and reply != "N":
                reply = input("Please input either Y or N to proceed: ")

            # end the program if the user answer "N".
            if reply == "N":
                do_next = False
                print("****************End of program****************")