# Example no. 1
# 0
def vis(args):
    """Visualize learned source/target features of a trained model with t-SNE.

    Loads the checkpoint ``<check_epoch>.pkl`` of the experiment described by
    ``args.config``, runs the network over up to ``args.epochs`` batches of
    both the source (training) and target (validation) ECG datasets, projects
    the extracted features to 2-D with t-SNE and saves a class-colored
    scatter plot under ``./figures/<exp_id>/``.

    Args:
        args: parsed CLI namespace. Fields used here: ``config`` (path to the
            YAML/yacs config), ``target`` (optional target-dataset override),
            ``distance`` (optional t-SNE metric override, 'Cosine' or 'L2'),
            ``check_epoch`` (checkpoint epoch to load) and ``epochs``
            (number of batches to sample per domain).
    """
    # Saturated colors for source samples, light colors for target samples.
    colors_s = {0: 'red', 1: 'blue', 2: 'green', 3: 'black'}
    colors_t = {0: 'lightcoral', 1: 'lightskyblue', 2: 'lightgreen', 3: 'gray'}
    # Heartbeat classes: Normal, Ventricular, Supraventricular, Fusion.
    categories = {0: 'N', 1: 'V', 2: 'S', 3: 'F'}
    # Maps the config/CLI distance name onto the scikit-learn metric name.
    metrics = {'Cosine': 'cosine', 'L2': 'euclidean'}

    cfg_dir = args.config
    cfg = get_cfg_defaults()
    cfg.merge_from_file(cfg_dir)
    cfg.freeze()

    source = cfg.SETTING.TRAIN_DATASET
    target = cfg.SETTING.TEST_DATASET
    if args.target:  # CLI override of the evaluation dataset
        target = args.target

    batch_size = cfg.TRAIN.BATCH_SIZE
    distance = args.distance if args.distance else cfg.SETTING.DISTANCE

    # The experiment id is the config file name without its extension.
    exp_id = os.path.basename(cfg_dir).split('.')[0]
    save_path = os.path.join(cfg.SYSTEM.SAVE_PATH, exp_id)

    img_path = os.path.join('./figures', exp_id)
    if not os.path.exists(img_path):
        os.makedirs(img_path)

    check_epoch = args.check_epoch
    check_point_dir = osp.join(save_path, '{}.pkl'.format(check_epoch))

    data_dict_source = load_dataset_to_memory(source)
    data_dict_target = load_dataset_to_memory(target)
    source_records = build_training_records(source)
    target_records = build_validation_records(target)

    net = build_acnn_models(cfg.SETTING.NETWORK,
                            cfg.SETTING.ASPP_BN,
                            cfg.SETTING.ASPP_ACT,
                            cfg.SETTING.LEAD,
                            cfg.PARAMETERS.P,
                            cfg.SETTING.DILATIONS,
                            act_func=cfg.SETTING.ACT,
                            f_act_func=cfg.SETTING.F_ACT,
                            apply_residual=cfg.SETTING.RESIDUAL,
                            bank_num=cfg.SETTING.BANK_NUM)
    net.load_state_dict(torch.load(check_point_dir)['model_state_dict'])
    net.cuda()
    net.eval()

    dataset_source = MULTI_ECG_EVAL_DATASET(source,
                                            load_beat_with_rr,
                                            data_dict_source,
                                            test_records=source_records,
                                            beat_num=cfg.SETTING.BEAT_NUM,
                                            fixed_len=cfg.SETTING.FIXED_LEN,
                                            lead=cfg.SETTING.LEAD)
    dataloader_source = DataLoader(
        dataset_source,
        batch_size=batch_size,
        num_workers=cfg.SYSTEM.NUM_WORKERS,
        sampler=ImbalancedDatasetSampler(dataset_source))

    dataset_target = MULTI_ECG_EVAL_DATASET(target,
                                            load_beat_with_rr,
                                            data_dict_target,
                                            test_records=target_records,
                                            beat_num=cfg.SETTING.BEAT_NUM,
                                            fixed_len=cfg.SETTING.FIXED_LEN,
                                            lead=cfg.SETTING.LEAD)
    dataloader_target = DataLoader(
        dataset_target,
        batch_size=batch_size,
        num_workers=cfg.SYSTEM.NUM_WORKERS,
        sampler=ImbalancedDatasetSampler(dataset_target))

    with torch.no_grad():

        tsne = manifold.TSNE(n_components=2,
                             metric=metrics[distance],
                             perplexity=30,
                             early_exaggeration=4.0,
                             learning_rate=500.0,
                             n_iter=2000,
                             init='pca',
                             random_state=2389)

        def collect(dataloader):
            """Run ``net`` over up to ``args.epochs`` batches of *dataloader*.

            Returns five per-batch lists: features, logits, softmax
            probabilities (all numpy), raw input signals (numpy) and labels
            (tensors, concatenated by the caller).
            """
            feats, logits_all, probs_all, raws, labs = [], [], [], [], []
            for idb, (x_batch, y_batch) in enumerate(dataloader):
                x_cpu = x_batch.detach().numpy()
                feat_b, logits_b = net(x_batch.cuda())

                feats.append(feat_b.detach().cpu().numpy())
                logits_all.append(logits_b.detach().cpu().numpy())
                # log_softmax(...).exp() == softmax, kept for numerical parity
                # with the training-side code.
                probs_all.append(
                    F.log_softmax(logits_b, dim=1).exp().detach().cpu().numpy())
                raws.append(x_cpu)
                labs.append(y_batch)

                if idb == args.epochs - 1:  # cap the number of sampled batches
                    break
            return feats, logits_all, probs_all, raws, labs

        (features_source, source_logits, source_probs,
         raw_signals_source, labels_source) = collect(dataloader_source)
        (features_target, target_logits, target_probs,
         raw_signals_target, labels_target) = collect(dataloader_target)

        labels_source = np.concatenate(labels_source, axis=0)
        labels_target = np.concatenate(labels_target, axis=0)

        # NOTE(review): 'labels', the logits/probs and raw-signal buffers are
        # not consumed below in this file; retained for debugging/downstream
        # use — confirm before removing.
        labels = np.concatenate([labels_source, labels_target], axis=0)

        # Report the per-class sample counts actually drawn from each domain.
        count_source = {'N': 0, 'V': 0, 'S': 0, 'F': 0}
        count_target = {'N': 0, 'V': 0, 'S': 0, 'F': 0}
        keys = ['N', 'V', 'S', 'F']

        num_source = len(labels_source)
        num_target = len(labels_target)

        for lab in labels_source:
            count_source[keys[lab]] += 1
        for lab in labels_target:
            count_target[keys[lab]] += 1

        for k in keys:
            print('The number of {} in source: {}; in target: {}'.format(
                k, count_source[k], count_target[k]))

        features_source = np.concatenate(features_source, axis=0)
        features_target = np.concatenate(features_target, axis=0)

        # Embed source and target jointly so both live in the same 2-D space.
        features = np.concatenate([features_source, features_target], axis=0)
        feat_tsne = tsne.fit_transform(features)

        # Min-max normalize each embedding dimension to [0, 1].
        # (Assumes min != max per dimension — TODO confirm.)
        x_min, x_max = feat_tsne.min(0), feat_tsne.max(0)
        feat_norm = (feat_tsne - x_min) / (x_max - x_min)

        # Split the joint embedding back into the two domains.
        feat_norm_s = feat_norm[0:num_source]
        feat_norm_t = feat_norm[num_source:num_target + num_source]

        # Group features and normalized embeddings by class label.
        features_s_dict = {}
        feat_norm_s_dict = {}
        features_t_dict = {}
        feat_norm_t_dict = {}
        for l in range(4):
            l_indices = np.argwhere(labels_source == l).squeeze(axis=1)
            features_s_dict[l] = features_source[l_indices]
            feat_norm_s_dict[l] = feat_norm_s[l_indices]

            l_indices_t = np.argwhere(labels_target == l).squeeze(axis=1)

            features_t_dict[l] = features_target[l_indices_t]
            feat_norm_t_dict[l] = feat_norm_t[l_indices_t]

        # --- The class-specific view ---
        # mitdb targets are plotted with all 4 classes, every other target
        # with the first 3 only (presumably class F is absent there — TODO
        # confirm against the datasets).
        num_plot_cls = 4 if 'mitdb' in target else 3

        plt.figure(figsize=(20, 20))
        for l in range(num_plot_cls):
            plt.scatter(feat_norm_s_dict[l][:, 0],
                        feat_norm_s_dict[l][:, 1],
                        marker='o',
                        color=colors_s[l],
                        label='source {}'.format(categories[l]))
            plt.scatter(feat_norm_t_dict[l][:, 0],
                        feat_norm_t_dict[l][:, 1],
                        marker='X',
                        color=colors_t[l],
                        label='target {}'.format(categories[l]))
        plt.xticks([])
        plt.yticks([])
        plt.legend(loc='upper right', fontsize=30)
        img_save_path = osp.join(
            img_path, 'tsne_{}_{}_cls.png'.format(exp_id,
                                                  args.check_epoch))
        plt.savefig(img_save_path, bbox_inches='tight')
        plt.close()
        # --- The domain-specific view ---
# Example no. 2
# 0
def train(args):

    cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    cfg_dir = args.config
    cfg = get_cfg_defaults()
    cfg.merge_from_file(cfg_dir)
    cfg.freeze()
    '''Setting the random seed used in the experiment'''

    if cfg.SETTING.SEED != -1:
        torch.manual_seed(cfg.SETTING.SEED)
        torch.cuda.manual_seed(cfg.SETTING.SEED)

    source = cfg.SETTING.TRAIN_DATASET
    target = cfg.SETTING.TEST_DATASET

    batch_size = cfg.TRAIN.BATCH_SIZE
    pre_train_epochs = cfg.TRAIN.PRE_TRAIN_EPOCHS
    epochs = cfg.TRAIN.EPOCHS
    lr = cfg.TRAIN.LR
    decay_rate = cfg.TRAIN.DECAY_RATE
    decay_step = cfg.TRAIN.DECAY_STEP
    flag_intra = cfg.SETTING.INTRA_LOSS
    flag_inter = cfg.SETTING.INTER_LOSS
    flag_norm = cfg.SETTING.NORM_ALIGN
    optimizer_ = cfg.SETTING.OPTIMIZER

    w_l2 = cfg.PARAMETERS.W_L2
    w_cls = cfg.PARAMETERS.W_CLS
    w_norm = cfg.PARAMETERS.W_NORM
    w_cs = cfg.PARAMETERS.BETA1
    w_ct = cfg.PARAMETERS.BETA2
    w_cst = cfg.PARAMETERS.BETA
    w_mmd = cfg.PARAMETERS.BETA_MMD
    w_inter = cfg.PARAMETERS.BETA_INTER
    w_intra = cfg.PARAMETERS.BETA_INTRA
    thr_m = cfg.PARAMETERS.THR_M
    thrs_ = cfg.PARAMETERS.THRS
    entropy_w = 0.001

    emsemble_num = cfg.PARAMETERS.EMSEMBLE_NUM
    emsemble_step = cfg.PARAMETERS.EMSEMBLE_STEP

    lr_c = cfg.PARAMETERS.LR_C
    lr_cs = cfg.PARAMETERS.LR_C_S
    lr_ct = cfg.PARAMETERS.LR_C_T

    thrs = {}
    for l in range(len(thrs_)):
        thrs[l] = thrs_[l]

    exp_id = os.path.basename(cfg_dir).split('.')[0]

    save_path = os.path.join(cfg.SYSTEM.SAVE_PATH, exp_id)
    if not osp.exists(save_path):
        os.makedirs(save_path)

    check_epoch = args.check_epoch
    check_point_dir = osp.join(save_path, '{}.pkl'.format(check_epoch))
    flag_loading = True if osp.exists(check_point_dir) else False

    data_dict_source = load_dataset_to_memory(source)
    data_dict_target = load_dataset_to_memory(target) if (
        source != target) else data_dict_source

    transform = augmentation_transform_with_rr if cfg.SETTING.AUGMENTATION else None

    source_records = build_training_records(source)
    target_records = build_validation_records(target)

    dataset = UDA_DATASET(source,
                          target,
                          data_dict_source,
                          data_dict_target,
                          source_records,
                          target_records,
                          cfg.SETTING.UDA_NUM,
                          load_beat_with_rr,
                          transform=transform,
                          beat_num=cfg.SETTING.BEAT_NUM,
                          fixed_len=cfg.SETTING.FIXED_LEN,
                          lead=cfg.SETTING.LEAD)

    dset_val = MULTI_ECG_EVAL_DATASET(target,
                                      load_beat_with_rr,
                                      data_dict_target,
                                      test_records=target_records,
                                      beat_num=cfg.SETTING.BEAT_NUM,
                                      fixed_len=cfg.SETTING.FIXED_LEN,
                                      lead=cfg.SETTING.LEAD,
                                      unlabel_num=0)
    dloader_val = DataLoader(dset_val,
                             batch_size=batch_size,
                             num_workers=cfg.SYSTEM.NUM_WORKERS)

    if cfg.TRAIN.IMBALANCE_SAMPLE:
        dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                num_workers=cfg.SYSTEM.NUM_WORKERS,
                                sampler=UDAImbalancedDatasetSampler(dataset))
    else:
        dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                num_workers=cfg.SYSTEM.NUM_WORKERS,
                                shuffle=True)

    iter_num = int(len(dataset) / batch_size)

    net = build_acnn_models(cfg.SETTING.NETWORK,
                            cfg.SETTING.ASPP_BN,
                            cfg.SETTING.ASPP_ACT,
                            cfg.SETTING.LEAD,
                            cfg.PARAMETERS.P,
                            cfg.SETTING.DILATIONS,
                            act_func=cfg.SETTING.ACT,
                            f_act_func=cfg.SETTING.F_ACT,
                            apply_residual=cfg.SETTING.RESIDUAL,
                            bank_num=cfg.SETTING.BANK_NUM)
    # Initialization of the model
    net.apply(init_weights)

    teacher_net = build_acnn_models(cfg.SETTING.NETWORK,
                                    cfg.SETTING.ASPP_BN,
                                    cfg.SETTING.ASPP_ACT,
                                    cfg.SETTING.LEAD,
                                    cfg.PARAMETERS.P,
                                    cfg.SETTING.DILATIONS,
                                    act_func=cfg.SETTING.ACT,
                                    f_act_func=cfg.SETTING.F_ACT,
                                    apply_residual=cfg.SETTING.RESIDUAL,
                                    bank_num=cfg.SETTING.BANK_NUM)

    print("The network {} has {} parameters in total".format(
        cfg.SETTING.NETWORK, sum(x.numel() for x in net.parameters())))

    if flag_loading:
        net.load_state_dict(torch.load(check_point_dir)['model_state_dict'])
        print("The saved model is loaded.")
    net = net.cuda()

    criterion_cls_4 = loss_function(cfg.SETTING.LOSS,
                                    dataset=source,
                                    num_ew=cfg.PARAMETERS.N,
                                    T=cfg.PARAMETERS.T)
    criterion_dist = build_distance(cfg.SETTING.DISTANCE)

    optimizer_pre = get_optimizer(optimizer_, net.parameters(), lr, w_l2)
    scheduler_pre = optim.lr_scheduler.StepLR(optimizer_pre,
                                              step_size=decay_step,
                                              gamma=decay_rate)

    optimizer_main = get_optimizer(optimizer_, net.parameters(), lr * 0.1,
                                   w_l2)
    scheduler_main = optim.lr_scheduler.StepLR(optimizer_main,
                                               step_size=decay_step * 10,
                                               gamma=decay_rate)
    evaluator = Eval(num_class=4)
    '''Initial and register the EMA'''
    ema = EMA(model=net, decay=0.99)
    ema.register()

    if check_epoch < pre_train_epochs - 1:
        print("Starting STAGE I: pre-training the model using source data")

        best_f1_s = 0.0

        for epoch in range(max(0, check_epoch), pre_train_epochs):
            for idb, data_batch in enumerate(dataloader):
                net.train()

                s_batch, sl_batch, t_batch, tl_batch = data_batch
                s_batch = s_batch.cuda()
                sl_batch = sl_batch.cuda()
                t_batch = t_batch.cuda()
                tl_batch = tl_batch.cuda()

                _, preds = net(s_batch)
                loss = criterion_cls_4(preds, sl_batch)

                # Add an entropy regularizer
                # p_softmax = nn.Softmax(dim=1)(preds)
                # loss -= get_entropy_loss(p_softmax, entropy_w)

                optimizer_pre.zero_grad()
                loss.backward()
                optimizer_pre.step()
                scheduler_pre.step()
                if args.use_ema:
                    ema.update()
                    ema.apply_shadow()

                running_lr = optimizer_pre.state_dict(
                )['param_groups'][0]['lr']

                print("[{}, {}] cls loss: {:.4f}, lr: {:.4f}".format(
                    epoch, idb, loss, running_lr),
                      end='\r')
                if idb == iter_num - 1:
                    torch.save({"model_state_dict": net.state_dict()},
                               osp.join(save_path, '{}.pkl'.format(epoch)))

            if epoch % 10 == 9:
                net.eval()
                preds_entire = []
                labels_entire = []

                with torch.no_grad():
                    for idb, data_batch in enumerate(dloader_val):
                        s_batch, l_batch = data_batch

                        s_batch = s_batch.cuda()
                        l_batch = l_batch.numpy()

                        _, logits = net(s_batch)

                        preds_softmax = F.log_softmax(logits, dim=1).exp()
                        preds_softmax_np = preds_softmax.detach().cpu().numpy()
                        preds = np.argmax(preds_softmax_np, axis=1)

                        preds_entire.append(preds)
                        labels_entire.append(l_batch)

                        torch.cuda.empty_cache()

                preds_entire = np.concatenate(preds_entire, axis=0)
                labels_entire = np.concatenate(labels_entire, axis=0)

                Pp, Se = evaluator._sklean_metrics(y_pred=preds_entire,
                                                   y_label=labels_entire)
                results = evaluator._metrics(predictions=preds_entire,
                                             labels=labels_entire)
                # con_matrix = evaluator._confusion_matrix(y_pred=preds_entire,
                #                                          y_label=labels_entire)

                f1_scores = evaluator._f1_score(y_pred=preds_entire,
                                                y_true=labels_entire)

                print('The overall accuracy is: {}'.format(results['Acc']))
                print("The confusion matrix is: ")
                print('Pp: ')
                pprint.pprint(Pp)
                print('Se: ')
                pprint.pprint(Se)
                print('The F1 score is: {}'.format(f1_scores))

                if f1_scores[2] >= best_f1_s:
                    best_f1_s = f1_scores[2]
                    torch.save({"model_state_dict": net.state_dict()},
                               osp.join(save_path, 'best_model.pkl'))

                torch.cuda.empty_cache()

    print('Start obtaining centers of each cluster and distribution')

    best_model_dir = osp.join(save_path, 'best_model.pkl')
    net.load_state_dict(torch.load(best_model_dir)['model_state_dict'])

    centers_source_dir = osp.join(save_path, "centers_source.mat")
    centers_target_dir = osp.join(save_path, "centers_target.mat")
    flag_centers = osp.exists(centers_source_dir) and osp.exists(
        centers_target_dir)

    center_source_dir = osp.join(save_path, "center_s.mat")
    center_target_dir = osp.join(save_path, "center_t.mat")
    flag_center = osp.exists(center_source_dir) and osp.exists(
        center_target_dir)

    if flag_centers:
        centers_s_ = sio.loadmat(centers_source_dir)
        centers_t_ = sio.loadmat(centers_target_dir)
        centers_s = {}
        centers_t = {}
        for l in range(4):
            if 'c{}'.format(l) in centers_s_.keys():
                centers_s[l] = torch.from_numpy(
                    centers_s_['c{}'.format(l)].squeeze()).cuda()
            if 'c{}'.format(l) in centers_t_.keys():
                centers_t[l] = torch.from_numpy(
                    centers_t_['c{}'.format(l)].squeeze()).cuda()
    else:
        net.eval()
        centers_s, counter_s = init_source_centers(
            net,
            source,
            source_records,
            data_dict_source,
            batch_size=batch_size,
            num_workers=cfg.SYSTEM.NUM_WORKERS,
            beat_num=cfg.SETTING.BEAT_NUM,
            fixed_len=cfg.SETTING.FIXED_LEN,
            lead=cfg.SETTING.LEAD)
        centers_t, counter_t = init_target_centers(
            net,
            target,
            target_records,
            data_dict_target,
            batch_size=batch_size,
            num_workers=cfg.SYSTEM.NUM_WORKERS,
            beat_num=cfg.SETTING.BEAT_NUM,
            fixed_len=cfg.SETTING.FIXED_LEN,
            lead=cfg.SETTING.LEAD,
            thrs=thrs)
        centers_s_np = {}
        centers_t_np = {}
        for l in range(4):
            if l in centers_s.keys():
                centers_s_np['c{}'.format(
                    l)] = centers_s[l].detach().cpu().numpy()
            if l in centers_t.keys():
                centers_t_np['c{}'.format(
                    l)] = centers_t[l].detach().cpu().numpy()
        sio.savemat(centers_source_dir, centers_s_np)
        sio.savemat(centers_target_dir, centers_t_np)

    if flag_center:
        center_s = torch.from_numpy(
            sio.loadmat(center_source_dir)['c'].squeeze()).cuda()
        center_t = torch.from_numpy(
            sio.loadmat(center_target_dir)['c'].squeeze()).cuda()
    else:

        center_s = init_entire_center(net,
                                      source,
                                      source_records,
                                      data_dict_source,
                                      batch_size=batch_size,
                                      num_workers=cfg.SYSTEM.NUM_WORKERS,
                                      beat_num=cfg.SETTING.BEAT_NUM,
                                      fixed_len=cfg.SETTING.FIXED_LEN,
                                      lead=cfg.SETTING.LEAD)
        center_t = init_entire_center(net,
                                      target,
                                      target_records,
                                      data_dict_target,
                                      batch_size=batch_size,
                                      num_workers=cfg.SYSTEM.NUM_WORKERS,
                                      beat_num=cfg.SETTING.BEAT_NUM,
                                      fixed_len=cfg.SETTING.FIXED_LEN,
                                      lead=cfg.SETTING.LEAD)
        sio.savemat(center_source_dir, {'c': center_s.detach().cpu().numpy()})
        sio.savemat(center_target_dir, {'c': center_t.detach().cpu().numpy()})

    print("Starting STAGE III: adaptation process")

    low_bound = max(2 * pre_train_epochs,
                    check_epoch) if cfg.SETTING.RE_TRAIN else max(
                        pre_train_epochs, check_epoch)
    high_bound = 2 * pre_train_epochs + epochs if cfg.SETTING.RE_TRAIN else pre_train_epochs + epochs
    for epoch in range(low_bound, high_bound):
        best_f1_s = 0.0
        dataset.shuffle_target()

        # load_dir = osp.join(save_path, 'best_model.pkl')
        # loaded_models = torch.load(load_dir)['model_state_dict']
        loaded_models = []
        for idx in range(emsemble_num):
            load_dir = osp.join(
                save_path, '{}.pkl'.format(epoch - idx * emsemble_step - 1))
            loaded_models.append(torch.load(load_dir)['model_state_dict'])

        for idb, data_batch in enumerate(dataloader):
            net.train()

            s_batch, sl_batch, t_batch, tl_batch = data_batch
            s_batch = s_batch.cuda()
            sl_batch = sl_batch.cuda()
            t_batch = t_batch.cuda()
            tl_batch = tl_batch.cuda()

            feat_s, preds_s = net(s_batch)
            feat_t, preds_t = net(t_batch)

            loss_cls = criterion_cls_4(preds_s, sl_batch)
            loss = loss_cls * w_cls

            # Add an entropy regularizer
            # p_softmax = nn.Softmax(dim=1)(preds_s)
            # loss -= get_entropy_loss(p_softmax, entropy_w)

            delta_s = center_s - torch.mean(feat_s, dim=0)
            delta_t = center_t - torch.mean(feat_t, dim=0)

            center_s = center_s - lr_c * delta_s
            center_t = center_t - lr_c * delta_t

            loss_mmd = criterion_dist(center_s, center_t)
            loss += loss_mmd * w_mmd

            loss_intra = 0
            loss_inter = 0
            loss_ct = 0
            loss_cs = 0
            loss_cst = 0

            if flag_norm:
                if cfg.SETTING.ALIGN_SET == 'soft':
                    loss += get_L2norm_loss_self_driven(feat_s, w_norm)
                    loss += get_L2norm_loss_self_driven(feat_t, w_norm)
                else:
                    loss += get_L2norm_loss_self_driven_hard(
                        feat_s, cfg.PARAMETERS.RADIUS, w_norm)
                    loss += get_L2norm_loss_self_driven_hard(
                        feat_t, cfg.PARAMETERS.RADIUS, w_norm)
            '''Obtaining the pesudo labels of target samples'''
            pseudo_label_nums = {0: 0, 1: 0, 2: 0, 3: 0}
            pseudo_labels, legal_indices = obtain_pseudo_labels(
                teacher_net, loaded_models, t_batch, thrs)
            # pseudo_labels: (NUM, ); legal_indices: (NUM, ),the indices of legal pseudo labels;

            tmp_centers_t = {}
            tmp_feats_t = {}
            # if len(pesudo_labels):
            if pseudo_labels.size(0) > 0:
                # feat_t_pesudo = torch.index_select(feat_t, dim=0, index=torch.LongTensor(legal_indices).cuda())
                feat_t_pseudo = torch.index_select(feat_t,
                                                   dim=0,
                                                   index=legal_indices)

                for l in range(4):
                    # _index = np.argwhere(pseudo_labels == l)
                    _index = torch.nonzero(pseudo_labels == l).squeeze(dim=1)
                    if _index.size(0) > 0:
                        pseudo_label_nums[l] = _index.size(0)
                        # _index = np.squeeze(_index, axis=1)
                        # _feat_t = torch.index_select(feat_t_pesudo, dim=0, index=torch.LongTensor(_index).cuda())
                        _feat_t = torch.index_select(feat_t_pseudo,
                                                     dim=0,
                                                     index=_index)
                        tmp_feats_t[l] = _feat_t
                        bs_ = _feat_t.size(0)

                        local_centers_tl = torch.mean(_feat_t, dim=0)
                        tmp_centers_t[l] = local_centers_tl

                        if l in centers_t.keys():
                            delta_ct = centers_t[l] - local_centers_tl
                            centers_t[l] = centers_t[l] - lr_ct * delta_ct
                            loss_ct_l = criterion_dist(local_centers_tl,
                                                       centers_t[l])
                            loss_ct += loss_ct_l
                        else:
                            centers_t[l] = local_centers_tl

                        if flag_intra:
                            m_feat_t = centers_t[l].repeat((bs_, 1))
                            loss_intra_l = criterion_dist(_feat_t,
                                                          m_feat_t,
                                                          dim=1)
                            loss_intra += loss_intra_l

            if cfg.SETTING.CLoss:
                loss += loss_ct * w_ct

            # sl_batch_np = sl_batch.detach().cpu().numpy()
            sl_batch_ = sl_batch.detach()
            true_label_nums = {0: 0, 1: 0, 2: 0, 3: 0}

            tmp_centers_s = {}
            tmp_feats_s = {}

            for l in range(4):
                # _index = np.argwhere(sl_batch_np == l)
                _index = torch.nonzero(sl_batch_ == l).squeeze(dim=1)
                if _index.size(0) > 0:
                    true_label_nums[l] = _index.size(0)
                    # _feat_s = torch.index_select(feat_s, dim=0, index=torch.LongTensor(_index).cuda())
                    _feat_s = torch.index_select(feat_s, dim=0, index=_index)
                    tmp_feats_s[l] = _feat_s
                    bs_ = _feat_s.size(0)

                    local_centers_sl = torch.mean(_feat_s, dim=0)
                    tmp_centers_s[l] = local_centers_sl
                    delta_cs = centers_s[l] - local_centers_sl
                    centers_s[l] = centers_s[l] - lr_cs * delta_cs

                    loss_cs_l = criterion_dist(local_centers_sl, centers_s[l])
                    loss_cs += loss_cs_l

                    if flag_intra:
                        m_feat_s = centers_s[l].repeat((bs_, 1))
                        loss_intra_l = criterion_dist(_feat_s, m_feat_s, dim=1)
                        loss_intra += loss_intra_l

            if cfg.SETTING.CLoss:
                loss += loss_cs * w_cs

            for l in centers_t.keys():
                loss_cst_l = criterion_dist(centers_s[l], centers_t[l])
                loss_cst += loss_cst_l

            if cfg.SETTING.CLoss:
                loss += loss_cst * w_cst

            for i in range(4 - 1):
                for j in range(i + 1, 4):
                    loss_inter_ij_s = torch.max(
                        thr_m - criterion_dist(centers_s[i], centers_s[j]),
                        torch.FloatTensor([0]).cuda()).squeeze()
                    loss_inter_ij_t = torch.max(
                        thr_m - criterion_dist(centers_t[i], centers_t[j]),
                        torch.FloatTensor([0]).cuda()).squeeze()
                    '''Add items between two domains'''
                    loss_inter_ij_st = torch.max(
                        thr_m - criterion_dist(centers_s[i], centers_t[j]),
                        torch.FloatTensor([0]).cuda()).squeeze()
                    loss_inter_ij_ts = torch.max(
                        thr_m - criterion_dist(centers_t[i], centers_s[j]),
                        torch.FloatTensor([0]).cuda()).squeeze()

                    loss_inter_ij = (loss_inter_ij_s + loss_inter_ij_t +
                                     loss_inter_ij_st + loss_inter_ij_ts) / 4
                    # loss_inter_ij = (loss_inter_ij_s + loss_inter_ij_t)
                    # loss_inter_ij = loss_inter_ij_s

                    loss_inter += loss_inter_ij

            if flag_inter:
                loss += loss_inter * w_inter
            if flag_intra:
                loss += loss_intra * w_intra

            loss_coral = 0
            if cfg.SETTING.CORAL:
                for l in tmp_feats_s.keys():
                    if l in tmp_feats_t.keys():
                        loss_coral += coral(tmp_feats_s[l], tmp_feats_t[l])
                loss += loss_coral

            optimizer_main.zero_grad()
            loss.backward(retain_graph=True)
            optimizer_main.step()
            scheduler_main.step()
            if args.use_ema:
                ema.update()
                ema.apply_shadow()

            running_lr = optimizer_main.state_dict()['param_groups'][0]['lr']
            torch.cuda.empty_cache()
            for l in centers_s.keys():
                centers_s[l] = centers_s[l].detach()
            for l in centers_t.keys():
                centers_t[l] = centers_t[l].detach()
            center_s = center_s.detach()
            center_t = center_t.detach()

            if idb == iter_num - 1:
                print("[{}, {}] cls loss: {:.4f}, cs loss: {:.4f}, "
                      "ct loss: {:.4f}, cst loss: {:.4f}, mmd loss: {:.4f}, "
                      "inter loss: {:.4f}, intra loss: {:.4f}, "
                      "CORAL: {:.4f}, "
                      "lr: {:.5f}".format(epoch, idb, loss_cls, loss_cs,
                                          loss_ct, loss_cst, loss_mmd,
                                          loss_inter, loss_intra, loss_coral,
                                          running_lr))

                print("The number of pesudo labels and true labels:")
                pprint.pprint(pseudo_label_nums)
                pprint.pprint(true_label_nums)

                torch.save({'model_state_dict': net.state_dict()},
                           osp.join(save_path, '{}.pkl'.format(epoch)))

        if epoch % 10 == 9:
            net.eval()

            preds_entire = []
            labels_entire = []

            with torch.no_grad():
                for idb, data_batch in enumerate(dloader_val):
                    s_batch, l_batch = data_batch

                    s_batch = s_batch.cuda()
                    l_batch = l_batch.numpy()

                    _, logits = net(s_batch)

                    preds_softmax = F.log_softmax(logits, dim=1).exp()
                    preds_softmax_np = preds_softmax.detach().cpu().numpy()
                    preds = np.argmax(preds_softmax_np, axis=1)

                    preds_entire.append(preds)
                    labels_entire.append(l_batch)

                    torch.cuda.empty_cache()

            preds_entire = np.concatenate(preds_entire, axis=0)
            labels_entire = np.concatenate(labels_entire, axis=0)

            Pp, Se = evaluator._sklean_metrics(y_pred=preds_entire,
                                               y_label=labels_entire)
            results = evaluator._metrics(predictions=preds_entire,
                                         labels=labels_entire)
            f1_scores = evaluator._f1_score(y_pred=preds_entire,
                                            y_true=labels_entire)
            con_matrix = evaluator._confusion_matrix(preds_entire,
                                                     labels_entire)

            print('The overall accuracy is: {}'.format(results['Acc']))
            print("The confusion matrix is: ")
            print(con_matrix)
            # print('The sklearn metrics are: ')
            print('Pp: ')
            pprint.pprint(Pp)
            print('Se: ')
            pprint.pprint(Se)
            print('The F1 score is: {}'.format(f1_scores))

            if f1_scores[2] >= best_f1_s:
                best_f1_s = f1_scores[2]
                torch.save({"model_state_dict": net.state_dict()},
                           osp.join(save_path, 'best_model.pkl'))

        # Updating thresholds for pesudo labels
        if cfg.SETTING.INCRE_THRS:
            if cfg.SETTING.RE_TRAIN:
                epoch_ = epoch - 2 * pre_train_epochs
            else:
                epoch_ = epoch - pre_train_epochs
            thrs = update_thrs(thrs, epoch_, epochs)
# Example 3 (scraped separator; original text: "Esempio n. 3" with vote count "0")
def eval(args):
    """Evaluate a trained ACNN model on the target validation dataset.

    Loads the experiment configuration from ``args.config``, restores the
    checkpoint ``<save_path>/<args.check_epoch>.pkl``, runs inference over the
    validation records of the target dataset, and prints the overall accuracy,
    confusion matrix, per-class Pp/Se metrics and F1 scores.

    Args:
        args: parsed command-line namespace; uses ``args.config`` (path to the
            YAML config), ``args.check_epoch`` (checkpoint epoch to load) and
            optionally ``args.target`` (overrides the configured test dataset).

    NOTE(review): the function name shadows the ``eval`` builtin; it is kept
    unchanged for backward compatibility with existing callers/CLI dispatch.
    """
    cfg_dir = args.config
    cfg = get_cfg_defaults()
    cfg.merge_from_file(cfg_dir)
    cfg.freeze()

    # A target given on the command line overrides the configured one.
    target = cfg.SETTING.TEST_DATASET
    if args.target:
        target = args.target

    batch_size = cfg.TRAIN.BATCH_SIZE

    # Experiment id is the config file name without extension; it must match
    # the id used at training time so the checkpoint path resolves.
    exp_id = os.path.basename(cfg_dir).split('.')[0]
    save_path = os.path.join(cfg.SYSTEM.SAVE_PATH, exp_id)

    check_epoch = args.check_epoch
    check_point_dir = osp.join(save_path, '{}.pkl'.format(check_epoch))

    data_dict_target = load_dataset_to_memory(target)
    target_records = build_validation_records(target)

    net = build_acnn_models(cfg.SETTING.NETWORK,
                            aspp_bn=cfg.SETTING.ASPP_BN,
                            aspp_act=cfg.SETTING.ASPP_ACT,
                            lead=cfg.SETTING.LEAD,
                            p=cfg.PARAMETERS.P,
                            dilations=cfg.SETTING.DILATIONS,
                            act_func=cfg.SETTING.ACT,
                            f_act_func=cfg.SETTING.F_ACT,
                            apply_residual=cfg.SETTING.RESIDUAL,
                            bank_num=cfg.SETTING.BANK_NUM)
    net.load_state_dict(torch.load(check_point_dir)['model_state_dict'])
    net = net.cuda()
    net.eval()

    evaluator = Eval(num_class=4)

    print("The network {} has {} "
          "parameters in total".format(cfg.SETTING.NETWORK,
                                       sum(x.numel() for x in net.parameters())))

    dataset = MULTI_ECG_EVAL_DATASET(target,
                                     load_beat_with_rr,
                                     data_dict_target,
                                     test_records=target_records,
                                     beat_num=cfg.SETTING.BEAT_NUM,
                                     fixed_len=cfg.SETTING.FIXED_LEN,
                                     lead=cfg.SETTING.LEAD,
                                     unlabel_num=cfg.SETTING.UDA_NUM)
    dataloader = DataLoader(dataset, batch_size=batch_size,
                            num_workers=cfg.SYSTEM.NUM_WORKERS)

    print("The size of the validation dataset is {}".format(len(dataset)))

    preds_entire = []
    labels_entire = []

    with torch.no_grad():
        for idb, data_batch in enumerate(dataloader):
            s_batch, l_batch = data_batch

            s_batch = s_batch.cuda()
            l_batch = l_batch.numpy()

            # The network returns (features, logits); only logits are needed.
            _, logits = net(s_batch)

            # log_softmax(...).exp() is the numerically stabler softmax.
            preds_softmax = F.log_softmax(logits, dim=1).exp()
            preds_softmax_np = preds_softmax.detach().cpu().numpy()
            preds = np.argmax(preds_softmax_np, axis=1)

            preds_entire.append(preds)
            labels_entire.append(l_batch)

            torch.cuda.empty_cache()

    preds_entire = np.concatenate(preds_entire, axis=0)
    labels_entire = np.concatenate(labels_entire, axis=0)

    Pp, Se = evaluator._sklean_metrics(y_pred=preds_entire,
                                       y_label=labels_entire)
    results = evaluator._metrics(predictions=preds_entire, labels=labels_entire)
    con_matrix = evaluator._confusion_matrix(y_pred=preds_entire,
                                             y_label=labels_entire)

    print('The overall accuracy is: {}'.format(results['Acc']))
    print("The confusion matrix is: ")
    print(con_matrix)
    print('The sklearn metrics are: ')
    print('Pp: ')
    pprint.pprint(Pp)
    print('Se: ')
    pprint.pprint(Se)
    print('The F1 score is: {}'.format(evaluator._f1_score(y_pred=preds_entire, y_true=labels_entire)))
# Example 4 (scraped separator; original text: "Esempio n. 4" with vote count "0")
def train(args):

    cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    cfg_dir = args.config
    cfg = get_cfg_defaults()
    cfg.merge_from_file(cfg_dir)
    cfg.freeze()
    '''Setting the random seed used in the experiment'''

    if cfg.SETTING.SEED != -1:
        torch.manual_seed(cfg.SETTING.SEED)
        torch.cuda.manual_seed(cfg.SETTING.SEED)

    source = cfg.SETTING.TRAIN_DATASET
    target = cfg.SETTING.TEST_DATASET

    batch_size = cfg.TRAIN.BATCH_SIZE
    pre_train_epochs = cfg.TRAIN.PRE_TRAIN_EPOCHS
    epochs = cfg.TRAIN.EPOCHS
    lr = cfg.TRAIN.LR
    decay_rate = cfg.TRAIN.DECAY_RATE
    decay_step = cfg.TRAIN.DECAY_STEP
    flag_intra = cfg.SETTING.INTRA_LOSS
    flag_inter = cfg.SETTING.INTER_LOSS
    flag_norm = cfg.SETTING.NORM_ALIGN
    optimizer_ = cfg.SETTING.OPTIMIZER

    w_l2 = cfg.PARAMETERS.W_L2
    w_cls = cfg.PARAMETERS.W_CLS
    w_norm = cfg.PARAMETERS.W_NORM
    w_cs = cfg.PARAMETERS.BETA1
    w_ct = cfg.PARAMETERS.BETA2
    w_cst = cfg.PARAMETERS.BETA
    w_mmd = cfg.PARAMETERS.BETA_MMD
    w_inter = cfg.PARAMETERS.BETA_INTER
    w_intra = cfg.PARAMETERS.BETA_INTRA
    thr_m = cfg.PARAMETERS.THR_M
    thrs_ = cfg.PARAMETERS.THRS
    entropy_w = 0.001

    emsemble_num = cfg.PARAMETERS.EMSEMBLE_NUM
    emsemble_step = cfg.PARAMETERS.EMSEMBLE_STEP

    lr_c = cfg.PARAMETERS.LR_C
    lr_cs = cfg.PARAMETERS.LR_C_S
    lr_ct = cfg.PARAMETERS.LR_C_T

    thrs = {}
    for l in range(len(thrs_)):
        thrs[l] = thrs_[l]

    exp_id = os.path.basename(cfg_dir).split('.')[0]

    save_path = os.path.join(cfg.SYSTEM.SAVE_PATH, exp_id)
    if not osp.exists(save_path):
        os.makedirs(save_path)

    check_epoch = args.check_epoch
    check_point_dir = osp.join(save_path, '{}.pkl'.format(check_epoch))
    flag_loading = True if osp.exists(check_point_dir) else False

    data_dict_source = load_dataset_to_memory(source)
    data_dict_target = load_dataset_to_memory(target) if (
        source != target) else data_dict_source

    transform = augmentation_transform_with_rr if cfg.SETTING.AUGMENTATION else None

    source_records = build_training_records(source)
    target_records = build_validation_records(target)

    dataset = UDA_DATASET(source,
                          target,
                          data_dict_source,
                          data_dict_target,
                          source_records,
                          target_records,
                          cfg.SETTING.UDA_NUM,
                          load_beat_with_rr,
                          transform=transform,
                          beat_num=cfg.SETTING.BEAT_NUM,
                          fixed_len=cfg.SETTING.FIXED_LEN,
                          lead=cfg.SETTING.LEAD)

    dset_val = MULTI_ECG_EVAL_DATASET(target,
                                      load_beat_with_rr,
                                      data_dict_target,
                                      test_records=target_records,
                                      beat_num=cfg.SETTING.BEAT_NUM,
                                      fixed_len=cfg.SETTING.FIXED_LEN,
                                      lead=cfg.SETTING.LEAD,
                                      unlabel_num=0)
    dloader_val = DataLoader(dset_val,
                             batch_size=batch_size,
                             num_workers=cfg.SYSTEM.NUM_WORKERS)

    if cfg.TRAIN.IMBALANCE_SAMPLE:
        dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                num_workers=cfg.SYSTEM.NUM_WORKERS,
                                sampler=UDAImbalancedDatasetSampler(dataset))
    else:
        dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                num_workers=cfg.SYSTEM.NUM_WORKERS,
                                shuffle=True)

    iter_num = int(len(dataset) / batch_size)

    net = build_acnn_models(cfg.SETTING.NETWORK,
                            cfg.SETTING.ASPP_BN,
                            cfg.SETTING.ASPP_ACT,
                            cfg.SETTING.LEAD,
                            cfg.PARAMETERS.P,
                            cfg.SETTING.DILATIONS,
                            act_func=cfg.SETTING.ACT,
                            f_act_func=cfg.SETTING.F_ACT,
                            apply_residual=cfg.SETTING.RESIDUAL,
                            bank_num=cfg.SETTING.BANK_NUM)
    # Initialization of the model
    net.apply(init_weights)

    teacher_net = build_acnn_models(cfg.SETTING.NETWORK,
                                    cfg.SETTING.ASPP_BN,
                                    cfg.SETTING.ASPP_ACT,
                                    cfg.SETTING.LEAD,
                                    cfg.PARAMETERS.P,
                                    cfg.SETTING.DILATIONS,
                                    act_func=cfg.SETTING.ACT,
                                    f_act_func=cfg.SETTING.F_ACT,
                                    apply_residual=cfg.SETTING.RESIDUAL,
                                    bank_num=cfg.SETTING.BANK_NUM)

    print("The network {} has {} parameters in total".format(
        cfg.SETTING.NETWORK, sum(x.numel() for x in net.parameters())))

    if flag_loading:
        net.load_state_dict(torch.load(check_point_dir)['model_state_dict'])
        print("The saved model is loaded.")
    net = net.cuda()

    criterion_cls_4 = loss_function(cfg.SETTING.LOSS,
                                    dataset=source,
                                    num_ew=cfg.PARAMETERS.N,
                                    T=cfg.PARAMETERS.T)
    criterion_dist = build_distance(cfg.SETTING.DISTANCE)

    optimizer_pre = get_optimizer(optimizer_, net.parameters(), lr, w_l2)
    scheduler_pre = optim.lr_scheduler.StepLR(optimizer_pre,
                                              step_size=decay_step,
                                              gamma=decay_rate)

    optimizer_main = get_optimizer(optimizer_, net.parameters(), lr * 0.1,
                                   w_l2)
    scheduler_main = optim.lr_scheduler.StepLR(optimizer_main,
                                               step_size=decay_step * 10,
                                               gamma=decay_rate)
    evaluator = Eval(num_class=4)
    '''Initial and register the EMA'''
    ema = EMA(model=net, decay=0.99)
    ema.register()

    if check_epoch < pre_train_epochs - 1:
        print("Starting STAGE I: pre-training the model using source data")

        best_f1_s = 0.0

        for epoch in range(max(0, check_epoch), pre_train_epochs):
            for idb, data_batch in enumerate(dataloader):
                net.train()

                s_batch, sl_batch, t_batch, tl_batch = data_batch
                s_batch = s_batch.cuda()
                sl_batch = sl_batch.cuda()
                t_batch = t_batch.cuda()
                tl_batch = tl_batch.cuda()

                _, preds = net(s_batch)
                loss = criterion_cls_4(preds, sl_batch)

                # Add an entropy regularizer
                # p_softmax = nn.Softmax(dim=1)(preds)
                # loss -= get_entropy_loss(p_softmax, entropy_w)

                optimizer_pre.zero_grad()
                loss.backward()
                optimizer_pre.step()
                scheduler_pre.step()
                if args.use_ema:
                    ema.update()
                    ema.apply_shadow()

                running_lr = optimizer_pre.state_dict(
                )['param_groups'][0]['lr']

                print("[{}, {}] cls loss: {:.4f}, lr: {:.4f}".format(
                    epoch, idb, loss, running_lr),
                      end='\r')
                if idb == iter_num - 1:
                    torch.save({"model_state_dict": net.state_dict()},
                               osp.join(save_path, '{}.pkl'.format(epoch)))

            if epoch % 10 == 9:
                net.eval()
                preds_entire = []
                labels_entire = []

                with torch.no_grad():
                    for idb, data_batch in enumerate(dloader_val):
                        s_batch, l_batch = data_batch

                        s_batch = s_batch.cuda()
                        l_batch = l_batch.numpy()

                        _, logits = net(s_batch)

                        preds_softmax = F.log_softmax(logits, dim=1).exp()
                        preds_softmax_np = preds_softmax.detach().cpu().numpy()
                        preds = np.argmax(preds_softmax_np, axis=1)

                        preds_entire.append(preds)
                        labels_entire.append(l_batch)

                        torch.cuda.empty_cache()

                preds_entire = np.concatenate(preds_entire, axis=0)
                labels_entire = np.concatenate(labels_entire, axis=0)

                Pp, Se = evaluator._sklean_metrics(y_pred=preds_entire,
                                                   y_label=labels_entire)
                results = evaluator._metrics(predictions=preds_entire,
                                             labels=labels_entire)
                # con_matrix = evaluator._confusion_matrix(y_pred=preds_entire,
                #                                          y_label=labels_entire)

                f1_scores = evaluator._f1_score(y_pred=preds_entire,
                                                y_true=labels_entire)

                print('The overall accuracy is: {}'.format(results['Acc']))
                print("The confusion matrix is: ")
                print('Pp: ')
                pprint.pprint(Pp)
                print('Se: ')
                pprint.pprint(Se)
                print('The F1 score is: {}'.format(f1_scores))

                if f1_scores[2] >= best_f1_s:
                    best_f1_s = f1_scores[2]
                    torch.save({"model_state_dict": net.state_dict()},
                               osp.join(save_path, 'best_model.pkl'))

                torch.cuda.empty_cache()