Пример #1
0
def vis(args):
    """Visualize source/target ECG beat features of a CNN with t-SNE.

    Loads a trained CNN checkpoint, collects features for a limited
    number of batches from both the source (training) and the target
    (validation) datasets, embeds the concatenated features in 2-D via
    t-SNE, and saves a class-colored scatter plot under ``figures/``.

    NOTE(review): this block visibly ends right after the class-specific
    plot; the promised domain-specific view is not part of this chunk.

    Args:
        args: parsed CLI namespace; fields read here are ``config``
            (path to the .yaml config file), ``target`` (optional
            dataset-name override), ``check_epoch`` (checkpoint id),
            ``epochs`` (number of batches to sample per domain) and
            ``component`` ('entire' | 'wave' | anything else = rest).
    """
    # Class index -> scatter color (labels are assumed to be 0..3).
    colors = {0: 'red', 1: 'blue', 2: 'green', 3: 'black'}
    # Config distance name -> sklearn metric identifier for t-SNE.
    metrics = {'Cosine': 'cosine', 'L2': 'euclidean'}

    cfg_dir = args.config
    cfg = get_cfg_defaults()
    cfg.merge_from_file(cfg_dir)
    cfg.freeze()
    print(cfg)

    source = cfg.SETTING.TRAIN_DATASET
    target = cfg.SETTING.TEST_DATASET
    # A target passed on the command line overrides the config value.
    if args.target:
        target = args.target

    batch_size = cfg.TRAIN.BATCH_SIZE
    distance = cfg.SETTING.DISTANCE

    # Experiment id is the config file name without its extension.
    exp_id = os.path.basename(cfg_dir).split('.')[0]
    save_path = os.path.join(cfg.SYSTEM.SAVE_PATH, exp_id)

    check_epoch = args.check_epoch
    check_point_dir = osp.join(save_path, '{}.pkl'.format(check_epoch))

    # Load both domains into memory plus their record lists.
    data_dict_source = load_dataset_to_memory(source)
    data_dict_target = load_dataset_to_memory(target)
    source_records = build_training_records(source)
    target_records = build_validation_records(target)

    # Restore the trained model and switch to inference mode.
    net = build_cnn_models(cfg.SETTING.NETWORK, cfg.SETTING.FIXED_LEN)
    net.load_state_dict(torch.load(check_point_dir)['model_state_dict'])
    net.cuda()
    net.eval()

    dataset_source = MULTI_ECG_EVAL_DATASET(source,
                                            load_beat_with_rr,
                                            data_dict_source,
                                            test_records=source_records,
                                            beat_num=cfg.SETTING.BEAT_NUM,
                                            fixed_len=cfg.SETTING.FIXED_LEN)
    # Class-balanced sampling so every class is visible in the plot.
    dataloader_source = DataLoader(
        dataset_source,
        batch_size=batch_size,
        num_workers=cfg.SYSTEM.NUM_WORKERS,
        sampler=ImbalancedDatasetSampler(dataset_source))

    dataset_target = MULTI_ECG_EVAL_DATASET(target,
                                            load_beat_with_rr,
                                            data_dict_target,
                                            test_records=target_records,
                                            beat_num=cfg.SETTING.BEAT_NUM,
                                            fixed_len=cfg.SETTING.FIXED_LEN)
    dataloader_target = DataLoader(
        dataset_target,
        batch_size=batch_size,
        num_workers=cfg.SYSTEM.NUM_WORKERS,
        sampler=ImbalancedDatasetSampler(dataset_target))

    # Per-domain accumulators (numpy chunks, concatenated later).
    features_source = []
    features_target = []
    labels_source = []
    labels_target = []

    source_logits = []
    target_logits = []

    raw_signals_source = []
    raw_signals_target = []

    with torch.no_grad():

        tsne = manifold.TSNE(n_components=2,
                             metric=metrics[distance],
                             perplexity=30,
                             early_exaggeration=4.0,
                             learning_rate=500.0,
                             n_iter=2000,
                             init='pca',
                             random_state=2389)

        # ---- Source domain: collect features / logits / labels ----
        for idb, data_batch in enumerate(dataloader_source):
            s_batch, l_batch, sr_batch, _ = data_batch
            s_batch_cpu = s_batch.detach().numpy()
            # Add the channel dimension expected by the CNN.
            s_batch = s_batch.unsqueeze(dim=1)
            s_batch = s_batch.cuda()
            sr_batch = sr_batch.cuda()

            _, _, features_s, logits_s = net(s_batch, sr_batch)

            feat_s_cpu = features_s.detach().cpu().numpy()
            logits_s_cpu = logits_s.detach().cpu().numpy()

            source_logits.append(logits_s_cpu)
            features_source.append(feat_s_cpu)
            raw_signals_source.append(s_batch_cpu)
            labels_source.append(l_batch)

            # args.epochs caps the number of batches visualized.
            if idb == args.epochs - 1:
                break

        # ---- Target domain: same collection loop ----
        for idb, data_batch in enumerate(dataloader_target):
            s_batch, l_batch, tr_batch, _ = data_batch
            s_batch_cpu = s_batch.detach().numpy()
            s_batch = s_batch.unsqueeze(dim=1)
            s_batch = s_batch.cuda()
            tr_batch = tr_batch.cuda()

            _, _, features_t, logits_t = net(s_batch, tr_batch)

            feat_t_cpu = features_t.detach().cpu().numpy()
            logits_t_cpu = logits_t.detach().cpu().numpy()

            target_logits.append(logits_t_cpu)
            features_target.append(feat_t_cpu)
            raw_signals_target.append(s_batch_cpu)
            labels_target.append(l_batch)

            if idb == args.epochs - 1:
                break

        labels_source = np.concatenate(labels_source, axis=0)
        labels_target = np.concatenate(labels_target, axis=0)

        labels = np.concatenate([labels_source, labels_target], axis=0)

        # Per-class counts, printed as a sanity check of the sampling.
        count_source = {'N': 0, 'V': 0, 'S': 0, 'F': 0}
        count_target = {'N': 0, 'V': 0, 'S': 0, 'F': 0}
        keys = ['N', 'V', 'S', 'F']

        num_source = len(labels_source)
        num_target = len(labels_target)

        for i in range(num_source):
            count_source[keys[labels_source[i]]] += 1
        for j in range(num_target):
            count_target[keys[labels_target[j]]] += 1

        for k in keys:
            print('The number of {} in source: {}; in target: {}'.format(
                k, count_source[k], count_target[k]))

        features_source = np.concatenate(features_source, axis=0)
        features_target = np.concatenate(features_target, axis=0)

        features = np.concatenate([features_source, features_target], axis=0)

        # Choose which slice of the feature vector to embed.
        # NOTE(review): the 512 split point presumably separates the
        # waveform features from the rest -- confirm against the model.
        if args.component == 'entire':
            features = features
        elif args.component == 'wave':
            features = features[:, 0:512]
        else:
            features = features[:, 512:]

        feat_tsne = tsne.fit_transform(features)

        # Min-max normalize the embedding to [0, 1] for plotting.
        x_min, x_max = feat_tsne.min(0), feat_tsne.max(0)
        feat_norm = (feat_tsne - x_min) / (x_max - x_min)
        '''The class-specific view'''
        # Source points are drawn as '.', target points as 'x';
        # color encodes the class label.
        plt.figure(figsize=(20, 20))
        for i in range(feat_norm.shape[0]):
            if i < num_source:
                plt.scatter(feat_norm[i, 0],
                            feat_norm[i, 1],
                            marker='.',
                            color=colors[labels[i]])
            else:
                plt.scatter(feat_norm[i, 0],
                            feat_norm[i, 1],
                            marker='x',
                            color=colors[labels[i]])
        plt.xticks([])
        plt.yticks([])
        img_save_path = 'figures/tsne_{}_{}_{}_cls.png'.format(
            exp_id, args.check_epoch, args.component)
        plt.savefig(img_save_path, bbox_inches='tight')
        plt.close()
        '''The domain-specific view'''
Пример #2
0
def eval_epochs():
    """Plot per-class F1 / sensitivity / precision across checkpoints.

    Scans ``./results/<exp_id>`` for ``*_<epoch>.npz`` result files
    (each holding 'Se', 'Pp' and 'F1' arrays with one entry per class
    N/V/S/F), orders them by epoch and saves three curve plots
    (``F1.png``, ``sen.png``, ``pre.png``) into that results directory.

    The experiment id is derived from the ``--config`` file name.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        default=None,
                        type=str,
                        help='The directory of config .yaml file')
    parser.add_argument('--target',
                        dest='target',
                        default=None,
                        type=str,
                        choices=['mitdb', 'fmitdb', 'svdb', 'incartdb'],
                        help='One can choose another dataset for validation '
                        'beyond the training setting up')
    args = parser.parse_args()

    cfg_dir = args.config
    cfg = get_cfg_defaults()
    cfg.merge_from_file(cfg_dir)
    cfg.freeze()

    exp_id = os.path.basename(cfg_dir).split('.')[0]

    img_path = os.path.join('./figures', exp_id)
    if not os.path.exists(img_path):
        os.makedirs(img_path)
    res_path = os.path.join('./results', exp_id)

    # endswith() instead of split('.')[1] == 'npz': the latter raises
    # IndexError for file names that contain no dot.
    result_files = [
        filename for filename in os.listdir(res_path)
        if filename.endswith('.npz')
    ]

    F1 = []
    SE = []
    PP = []
    checkepochs = []

    for filename in result_files:
        # File names look like '<prefix>_<epoch>.npz'.
        checkepoch = int(filename.split('.')[0].split('_')[1])

        if checkepoch >= 0:
            checkepochs.append(checkepoch)
            data = np.load(os.path.join(res_path, filename))
            SE.append(data['Se'])
            PP.append(data['Pp'])
            F1.append(data['F1'])

    # Shape (num_classes, num_checkpoints): one row per class N/V/S/F.
    F1 = np.stack(F1, axis=1)
    SE = np.stack(SE, axis=1)
    PP = np.stack(PP, axis=1)

    # BUG FIX: previously only checkepochs was sorted while the metric
    # columns stayed in os.listdir() order, so every curve was plotted
    # against the wrong epochs.  Apply one permutation to all arrays.
    checkepochs = np.array(checkepochs)
    order = np.argsort(checkepochs)
    checkepochs = checkepochs[order]
    F1 = F1[:, order]
    SE = SE[:, order]
    PP = PP[:, order]

    _plot_metric_curves(checkepochs, F1, os.path.join(res_path, 'F1.png'))
    _plot_metric_curves(checkepochs, SE, os.path.join(res_path, 'sen.png'))
    _plot_metric_curves(checkepochs, PP, os.path.join(res_path, 'pre.png'))


def _plot_metric_curves(epochs, values, out_path):
    """Save one figure with a curve per class (rows of ``values``)."""
    class_styles = ((0, 'red', 'N'), (1, 'blue', 'V'),
                    (2, 'green', 'S'), (3, 'orange', 'F'))
    plt.figure(figsize=(20, 15))
    for row, color, label in class_styles:
        plt.plot(epochs, values[row], color=color, linestyle='-', label=label)
    plt.legend(loc='upper right')
    plt.savefig(out_path, bbox_inches='tight')
    plt.close()
Пример #3
0
def vis(args):
    """Visualize source/target ECG beat features of an ACNN with t-SNE.

    Loads a trained ACNN checkpoint, collects features for a limited
    number of batches from both the source (training) and the target
    (validation) datasets, embeds the concatenated features in 2-D via
    t-SNE, and saves a class-colored scatter plot (source vs. target
    markers) under ``figures/<exp_id>``.

    NOTE(review): this block visibly ends right after the class-specific
    plot; the promised domain-specific view is not part of this chunk.

    Args:
        args: parsed CLI namespace; fields read here are ``config``
            (path to the .yaml config file), ``target`` (optional
            dataset-name override), ``distance`` (optional metric
            override), ``check_epoch`` (checkpoint id) and ``epochs``
            (number of batches to sample per domain).
    """
    # Class index -> color; dark palette for source, light for target.
    colors_s = {0: 'red', 1: 'blue', 2: 'green', 3: 'black'}
    colors_t = {0: 'lightcoral', 1: 'lightskyblue', 2: 'lightgreen', 3: 'gray'}
    # Class index -> AAMI beat category name.
    categories = {0: 'N', 1: 'V', 2: 'S', 3: 'F'}
    # Distance name -> sklearn metric identifier for t-SNE.
    metrics = {'Cosine': 'cosine', 'L2': 'euclidean'}

    cfg_dir = args.config
    cfg = get_cfg_defaults()
    cfg.merge_from_file(cfg_dir)
    cfg.freeze()

    source = cfg.SETTING.TRAIN_DATASET
    target = cfg.SETTING.TEST_DATASET
    # A target passed on the command line overrides the config value.
    if args.target:
        target = args.target

    batch_size = cfg.TRAIN.BATCH_SIZE
    # CLI distance (if given) takes precedence over the config value.
    distance = args.distance if args.distance else cfg.SETTING.DISTANCE

    # Experiment id is the config file name without its extension.
    exp_id = os.path.basename(cfg_dir).split('.')[0]
    save_path = os.path.join(cfg.SYSTEM.SAVE_PATH, exp_id)

    img_path = os.path.join('./figures', exp_id)
    if not os.path.exists(img_path):
        os.makedirs(img_path)

    check_epoch = args.check_epoch
    check_point_dir = osp.join(save_path, '{}.pkl'.format(check_epoch))

    # Load both domains into memory plus their record lists.
    data_dict_source = load_dataset_to_memory(source)
    data_dict_target = load_dataset_to_memory(target)
    source_records = build_training_records(source)
    target_records = build_validation_records(target)

    # Restore the trained ACNN and switch to inference mode.
    net = build_acnn_models(cfg.SETTING.NETWORK,
                            cfg.SETTING.ASPP_BN,
                            cfg.SETTING.ASPP_ACT,
                            cfg.SETTING.LEAD,
                            cfg.PARAMETERS.P,
                            cfg.SETTING.DILATIONS,
                            act_func=cfg.SETTING.ACT,
                            f_act_func=cfg.SETTING.F_ACT,
                            apply_residual=cfg.SETTING.RESIDUAL,
                            bank_num=cfg.SETTING.BANK_NUM)
    net.load_state_dict(torch.load(check_point_dir)['model_state_dict'])
    net.cuda()
    net.eval()

    dataset_source = MULTI_ECG_EVAL_DATASET(source,
                                            load_beat_with_rr,
                                            data_dict_source,
                                            test_records=source_records,
                                            beat_num=cfg.SETTING.BEAT_NUM,
                                            fixed_len=cfg.SETTING.FIXED_LEN,
                                            lead=cfg.SETTING.LEAD)
    # Class-balanced sampling so every class is visible in the plot.
    dataloader_source = DataLoader(
        dataset_source,
        batch_size=batch_size,
        num_workers=cfg.SYSTEM.NUM_WORKERS,
        sampler=ImbalancedDatasetSampler(dataset_source))

    dataset_target = MULTI_ECG_EVAL_DATASET(target,
                                            load_beat_with_rr,
                                            data_dict_target,
                                            test_records=target_records,
                                            beat_num=cfg.SETTING.BEAT_NUM,
                                            fixed_len=cfg.SETTING.FIXED_LEN,
                                            lead=cfg.SETTING.LEAD)
    dataloader_target = DataLoader(
        dataset_target,
        batch_size=batch_size,
        num_workers=cfg.SYSTEM.NUM_WORKERS,
        sampler=ImbalancedDatasetSampler(dataset_target))

    # Per-domain accumulators (numpy chunks, concatenated later).
    features_source = []
    features_target = []
    labels_source = []
    labels_target = []

    source_logits = []
    target_logits = []
    source_probs = []
    target_probs = []

    raw_signals_source = []
    raw_signals_target = []

    with torch.no_grad():

        tsne = manifold.TSNE(n_components=2,
                             metric=metrics[distance],
                             perplexity=30,
                             early_exaggeration=4.0,
                             learning_rate=500.0,
                             n_iter=2000,
                             init='pca',
                             random_state=2389)

        # ---- Source domain: collect features / logits / probs ----
        for idb, data_batch in enumerate(dataloader_source):
            s_batch, l_batch = data_batch
            s_batch_cpu = s_batch.detach().numpy()
            s_batch = s_batch.cuda()

            features_s, logits_s = net(s_batch)

            # feats = net.get_feature_maps(s_batch)
            # feats = feats.detach().cpu().numpy()
            # plt.figure(figsize=(12.5, 10))
            # plt.plot(feats[0])
            # plt.savefig(osp.join(img_path, '{}.png'.format(idb)), bbox_inches='tight')
            # plt.close()

            feat_s_cpu = features_s.detach().cpu().numpy()
            logits_s_cpu = logits_s.detach().cpu().numpy()
            # log_softmax().exp() == softmax, but numerically stabler.
            probs_s = F.log_softmax(logits_s,
                                    dim=1).exp().detach().cpu().numpy()

            source_logits.append(logits_s_cpu)
            source_probs.append(probs_s)
            features_source.append(feat_s_cpu)
            raw_signals_source.append(s_batch_cpu)
            labels_source.append(l_batch)

            # args.epochs caps the number of batches visualized.
            if idb == args.epochs - 1:
                break

        # ---- Target domain: same collection loop ----
        for idb, data_batch in enumerate(dataloader_target):
            s_batch, l_batch = data_batch
            s_batch_cpu = s_batch.detach().numpy()
            s_batch = s_batch.cuda()

            features_t, logits_t = net(s_batch)

            feat_t_cpu = features_t.detach().cpu().numpy()
            logits_t_cpu = logits_t.detach().cpu().numpy()
            probs_t = F.log_softmax(logits_t,
                                    dim=1).exp().detach().cpu().numpy()

            target_logits.append(logits_t_cpu)
            target_probs.append(probs_t)
            features_target.append(feat_t_cpu)
            raw_signals_target.append(s_batch_cpu)
            labels_target.append(l_batch)

            if idb == args.epochs - 1:
                break

        labels_source = np.concatenate(labels_source, axis=0)
        labels_target = np.concatenate(labels_target, axis=0)

        # target_probs = np.concatenate(target_probs, axis=0)
        # preds_t = np.argmax(target_probs, axis=1)
        # probs_t = np.max(target_probs, axis=1)
        #
        # for l in range(4):
        #     indices_tl = np.argwhere(preds_t == l)
        #     if len(indices_tl) > 0:
        #         indices_tl = indices_tl.squeeze(axis=1)
        #         probs_tl = probs_t[indices_tl]
        #         gt_tl = labels_target[indices_tl]
        #         indices_l = np.where(gt_tl == l, 1, 0)
        #
        #         plt.figure(figsize=(20, 15))
        #         n, bins, patches = plt.hist(probs_tl, bins=300)
        #         plt.savefig(osp.join(img_path, 'cls_{}.png'.format(l)), bbox_inches='tight')
        #         plt.close()
        #
        #         corr_indices_l = np.argwhere(indices_l == 1)
        #         incorr_indices_l = np.argwhere(indices_l == 0)
        #
        #         if len(corr_indices_l):
        #             plt.figure(figsize=(20, 15))
        #             corr_indices_l = corr_indices_l.squeeze(axis=1)
        #             corr_probs_tl = probs_tl[corr_indices_l]
        #             _, _, _ = plt.hist(corr_probs_tl, bins=300)
        #             plt.savefig(osp.join(img_path, 'corr_cls{}.png'.format(l)), bbox_inches='tight')
        #             plt.close()
        #         if len(incorr_indices_l):
        #             plt.figure(figsize=(20, 15))
        #             incorr_indices_l = incorr_indices_l.squeeze(axis=1)
        #             incorr_probs_tl = probs_tl[incorr_indices_l]
        #             _, _, _ = plt.hist(incorr_probs_tl, bins=300, color='red')
        #             plt.savefig(osp.join(img_path, 'incorr_cls{}.png'.format(l)), bbox_inches='tight')
        #             plt.close()

        labels = np.concatenate([labels_source, labels_target], axis=0)

        # Per-class counts, printed as a sanity check of the sampling.
        count_source = {'N': 0, 'V': 0, 'S': 0, 'F': 0}
        count_target = {'N': 0, 'V': 0, 'S': 0, 'F': 0}
        keys = ['N', 'V', 'S', 'F']

        num_source = len(labels_source)
        num_target = len(labels_target)

        for i in range(num_source):
            count_source[keys[labels_source[i]]] += 1
        for j in range(num_target):
            count_target[keys[labels_target[j]]] += 1

        for k in keys:
            print('The number of {} in source: {}; in target: {}'.format(
                k, count_source[k], count_target[k]))

        features_source = np.concatenate(features_source, axis=0)
        features_target = np.concatenate(features_target, axis=0)

        features = np.concatenate([features_source, features_target], axis=0)
        feat_tsne = tsne.fit_transform(features)

        # Min-max normalize the embedding to [0, 1] for plotting.
        x_min, x_max = feat_tsne.min(0), feat_tsne.max(0)
        feat_norm = (feat_tsne - x_min) / (x_max - x_min)

        # Split the normalized embedding back into the two domains.
        feat_norm_s = feat_norm[0:num_source]
        feat_norm_t = feat_norm[num_source:num_target + num_source]

        # Group features/embeddings by class label for each domain.
        features_s_dict = {}
        feat_norm_s_dict = {}
        features_t_dict = {}
        feat_norm_t_dict = {}
        for l in range(4):
            l_indices = np.argwhere(labels_source == l).squeeze(axis=1)
            features_s_dict[l] = features_source[l_indices]
            feat_norm_s_dict[l] = feat_norm_s[l_indices]

            l_indices_t = np.argwhere(labels_target == l).squeeze(axis=1)

            features_t_dict[l] = features_target[l_indices_t]
            feat_norm_t_dict[l] = feat_norm_t[l_indices_t]
        '''The feature visualization'''
        # plt.figure(figsize=(30, 15))
        # for i in range(features_source.shape[0]):
        #     if labels_source[i] == 0:
        #         plt.subplot(411)
        #         plt.plot(features_source[i], color=colors[labels_source[i]])
        #     elif labels_source[i] == 1:
        #         plt.subplot(412)
        #         plt.plot(features_source[i], color=colors[labels_source[i]])
        #     elif labels_source[i] == 2:
        #         plt.subplot(413)
        #         plt.plot(features_source[i], color=colors[labels_source[i]])
        #     else:
        #         plt.subplot(414)
        #         plt.plot(features_source[i], color=colors[labels_source[i]])
        # img_save_path = osp.join(img_path, 'feat_s_{}_{}.png'.format(exp_id, args.check_epoch))
        # plt.savefig(img_save_path, bbox_inches='tight')
        # plt.close()
        #
        # plt.figure(figsize=(30, 15))
        # for i in range(features_target.shape[0]):
        #     if labels_target[i] == 0:
        #         plt.subplot(411)
        #         plt.plot(features_target[i], color=colors[labels_target[i]])
        #     elif labels_target[i] == 1:
        #         plt.subplot(412)
        #         plt.plot(features_target[i], color=colors[labels_target[i]])
        #     elif labels_target[i] == 2:
        #         plt.subplot(413)
        #         plt.plot(features_target[i], color=colors[labels_target[i]])
        #     else:
        #         plt.subplot(414)
        #         plt.plot(features_target[i], color=colors[labels_target[i]])
        # img_save_path = osp.join(img_path, 'feat_t_{}_{}.png'.format(exp_id, args.check_epoch))
        # plt.savefig(img_save_path, bbox_inches='tight')
        # plt.close()
        '''The class-specific view'''

        # mitdb-family targets have all 4 classes; otherwise the F class
        # is skipped (range(3)).  The two branches differ only in that.
        if 'mitdb' in target:
            plt.figure(figsize=(20, 20))
            for l in range(4):
                plt.scatter(feat_norm_s_dict[l][:, 0],
                            feat_norm_s_dict[l][:, 1],
                            marker='o',
                            color=colors_s[l],
                            label='source {}'.format(categories[l]))
                plt.scatter(feat_norm_t_dict[l][:, 0],
                            feat_norm_t_dict[l][:, 1],
                            marker='X',
                            color=colors_t[l],
                            label='target {}'.format(categories[l]))
            plt.xticks([])
            plt.yticks([])
            plt.legend(loc='upper right', fontsize=30)
            img_save_path = osp.join(
                img_path, 'tsne_{}_{}_cls.png'.format(exp_id,
                                                      args.check_epoch))
            plt.savefig(img_save_path, bbox_inches='tight')
            plt.close()

        else:
            plt.figure(figsize=(20, 20))
            for l in range(3):
                plt.scatter(feat_norm_s_dict[l][:, 0],
                            feat_norm_s_dict[l][:, 1],
                            marker='o',
                            color=colors_s[l],
                            label='source {}'.format(categories[l]))
                plt.scatter(feat_norm_t_dict[l][:, 0],
                            feat_norm_t_dict[l][:, 1],
                            marker='X',
                            color=colors_t[l],
                            label='target {}'.format(categories[l]))
            plt.xticks([])
            plt.yticks([])
            plt.legend(loc='upper right', fontsize=30)
            img_save_path = osp.join(
                img_path, 'tsne_{}_{}_cls.png'.format(exp_id,
                                                      args.check_epoch))
            plt.savefig(img_save_path, bbox_inches='tight')
            plt.close()
        '''The domain-specific view'''
Пример #4
0
def eval(args):
    """Validate a trained CNN checkpoint on the target ECG dataset.

    Restores the model from ``<save_path>/<check_epoch>.pkl``, runs the
    whole validation set through it, and prints the evaluation metrics
    (per-class stats, confusion matrix, sklearn Pp/Se and the F1 score).

    Args:
        args: parsed CLI namespace; fields read here are ``config``
            (path to the .yaml config file), ``target`` (optional
            dataset-name override) and ``check_epoch`` (checkpoint id).
    """
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.config)
    cfg.freeze()

    # Command-line target overrides the configured test dataset.
    target = args.target if args.target else cfg.SETTING.TEST_DATASET
    batch_size = cfg.TRAIN.BATCH_SIZE

    # Experiment id is the config file name without its extension.
    exp_id = os.path.basename(args.config).split('.')[0]
    save_path = os.path.join(cfg.SYSTEM.SAVE_PATH, exp_id)
    check_point_dir = osp.join(save_path, '{}.pkl'.format(args.check_epoch))

    data_dict_target = load_dataset_to_memory(target)
    target_records = build_validation_records(target)

    # Restore the trained model and switch to inference mode.
    net = build_cnn_models(cfg.SETTING.NETWORK,
                           fixed_len=cfg.SETTING.FIXED_LEN,
                           p=cfg.PARAMETERS.P)
    net.load_state_dict(torch.load(check_point_dir)['model_state_dict'])
    net = net.cuda()
    net.eval()

    evaluator = Eval(num_class=4)

    total_params = sum(x.numel() for x in net.parameters())
    print("The network {} has {} "
          "parameters in total".format(cfg.SETTING.NETWORK, total_params))

    dataset = MULTI_ECG_EVAL_DATASET(target,
                                     load_beat_with_rr,
                                     data_dict_target,
                                     test_records=target_records,
                                     beat_num=cfg.SETTING.BEAT_NUM,
                                     fixed_len=cfg.SETTING.FIXED_LEN)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            num_workers=cfg.SYSTEM.NUM_WORKERS)

    print("The size of the validation dataset is {}".format(len(dataset)))

    pred_chunks = []
    label_chunks = []

    with torch.no_grad():
        for signals, labels, rrs, _ in dataloader:
            # Add the channel dimension expected by the CNN.
            signals = signals.unsqueeze(dim=1).cuda()
            rrs = rrs.cuda()

            _, _, _, logits = net(signals, rrs)

            # log_softmax().exp() == softmax, but numerically stabler.
            probs = F.log_softmax(logits, dim=1).exp()
            pred_chunks.append(
                np.argmax(probs.detach().cpu().numpy(), axis=1))
            label_chunks.append(labels.numpy())

            torch.cuda.empty_cache()

    preds_entire = np.concatenate(pred_chunks, axis=0)
    labels_entire = np.concatenate(label_chunks, axis=0)

    results = evaluator._metrics(predictions=preds_entire,
                                 labels=labels_entire)
    Pp, Se = evaluator._sklean_metrics(y_pred=preds_entire,
                                       y_label=labels_entire)
    con_matrix = evaluator._confusion_matrix(y_pred=preds_entire,
                                             y_label=labels_entire)

    pprint.pprint(results)

    print("The confusion matrix is: ")
    print(con_matrix)
    print('The sklearn metrics are: ')
    print('Pp: ')
    pprint.pprint(Pp)
    print('Se: ')
    pprint.pprint(Se)
    print('The F1 score is: {}'.format(
        evaluator._f1_score(y_pred=preds_entire, y_true=labels_entire)))
Пример #5
0
def eval(args):
    """Validate a trained ACNN checkpoint on the target ECG dataset.

    Restores the model from ``<save_path>/<check_epoch>.pkl``, runs the
    whole validation set through it, and prints the evaluation metrics
    (overall accuracy, confusion matrix, sklearn Pp/Se and the F1
    score).  Predicted probabilities and raw input samples are also
    accumulated (previously used by now-disabled error-visualization
    code).

    Args:
        args: parsed CLI namespace; fields read here are ``config``
            (path to the .yaml config file), ``target`` (optional
            dataset-name override) and ``check_epoch`` (checkpoint id).
    """
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.config)
    cfg.freeze()

    # Command-line target overrides the configured test dataset.
    target = args.target if args.target else cfg.SETTING.TEST_DATASET
    batch_size = cfg.TRAIN.BATCH_SIZE

    # Experiment id is the config file name without its extension.
    exp_id = os.path.basename(args.config).split('.')[0]
    save_path = os.path.join(cfg.SYSTEM.SAVE_PATH, exp_id)
    check_point_dir = osp.join(save_path, '{}.pkl'.format(args.check_epoch))

    data_dict_target = load_dataset_to_memory(target)
    target_records = build_validation_records(target)

    # Restore the trained ACNN and switch to inference mode.
    net = build_acnn_models(cfg.SETTING.NETWORK,
                            aspp_bn=cfg.SETTING.ASPP_BN,
                            aspp_act=cfg.SETTING.ASPP_ACT,
                            lead=cfg.SETTING.LEAD,
                            p=cfg.PARAMETERS.P,
                            dilations=cfg.SETTING.DILATIONS,
                            act_func=cfg.SETTING.ACT,
                            f_act_func=cfg.SETTING.F_ACT,
                            apply_residual=cfg.SETTING.RESIDUAL,
                            bank_num=cfg.SETTING.BANK_NUM)
    net.load_state_dict(torch.load(check_point_dir)['model_state_dict'])
    net = net.cuda()
    net.eval()

    evaluator = Eval(num_class=4)

    param_count = sum(x.numel() for x in net.parameters())
    print("The network {} has {} "
          "parameters in total".format(cfg.SETTING.NETWORK, param_count))

    dataset = MULTI_ECG_EVAL_DATASET(target,
                                     load_beat_with_rr,
                                     data_dict_target,
                                     test_records=target_records,
                                     beat_num=cfg.SETTING.BEAT_NUM,
                                     fixed_len=cfg.SETTING.FIXED_LEN,
                                     lead=cfg.SETTING.LEAD,
                                     unlabel_num=cfg.SETTING.UDA_NUM)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            num_workers=cfg.SYSTEM.NUM_WORKERS)

    print("The size of the validation dataset is {}".format(len(dataset)))

    pred_chunks = []
    label_chunks = []
    prob_chunks = []
    sample_chunks = []

    with torch.no_grad():
        for signals, labels in dataloader:
            signals = signals.cuda()

            _, logits = net(signals)

            # log_softmax().exp() == softmax, but numerically stabler.
            probs_np = F.log_softmax(logits,
                                     dim=1).exp().detach().cpu().numpy()

            pred_chunks.append(np.argmax(probs_np, axis=1))
            label_chunks.append(labels.numpy())
            prob_chunks.append(probs_np)
            sample_chunks.append(signals.detach().cpu().numpy())

            torch.cuda.empty_cache()

    preds_entire = np.concatenate(pred_chunks, axis=0)
    labels_entire = np.concatenate(label_chunks, axis=0)
    probs_entire = np.concatenate(prob_chunks, axis=0)
    samples = np.concatenate(sample_chunks, axis=0)

    Pp, Se = evaluator._sklean_metrics(y_pred=preds_entire,
                                       y_label=labels_entire)
    results = evaluator._metrics(predictions=preds_entire,
                                 labels=labels_entire)
    con_matrix = evaluator._confusion_matrix(y_pred=preds_entire,
                                             y_label=labels_entire)

    print('The overall accuracy is: {}'.format(results['Acc']))
    print("The confusion matrix is: ")
    print(con_matrix)
    print('The sklearn metrics are: ')
    print('Pp: ')
    pprint.pprint(Pp)
    print('Se: ')
    pprint.pprint(Se)
    print('The F1 score is: {}'.format(
        evaluator._f1_score(y_pred=preds_entire, y_true=labels_entire)))
Пример #6
0
def train(args):
    """Multi-stage unsupervised-domain-adaptation (UDA) training loop for
    ECG beat classification.

    Stage I pre-trains the network on labelled source data. An intermediate
    step estimates per-class and global feature centers for both domains
    (cached as .mat files). Stage III adapts the model to the target domain
    using pseudo labels from an ensemble of recent checkpoints plus
    center-alignment, inter/intra-class, MMD and optional CORAL losses.

    Args:
        args: parsed CLI namespace; this function reads ``args.config``
            (path to a yacs YAML config), ``args.check_epoch`` (epoch whose
            checkpoint to resume from) and ``args.use_ema`` (maintain an
            EMA shadow of the weights).
    """

    # NOTE(review): benchmark is enabled here and then overridden two lines
    # below — the effective setting is deterministic=True, benchmark=False.
    cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    cfg_dir = args.config
    cfg = get_cfg_defaults()
    cfg.merge_from_file(cfg_dir)
    cfg.freeze()
    '''Setting the random seed used in the experiment'''

    # SEED == -1 means "do not seed" (non-reproducible run).
    if cfg.SETTING.SEED != -1:
        torch.manual_seed(cfg.SETTING.SEED)
        torch.cuda.manual_seed(cfg.SETTING.SEED)

    source = cfg.SETTING.TRAIN_DATASET
    target = cfg.SETTING.TEST_DATASET

    # Hyper-parameters pulled out of the config for brevity below.
    batch_size = cfg.TRAIN.BATCH_SIZE
    pre_train_epochs = cfg.TRAIN.PRE_TRAIN_EPOCHS
    epochs = cfg.TRAIN.EPOCHS
    lr = cfg.TRAIN.LR
    decay_rate = cfg.TRAIN.DECAY_RATE
    decay_step = cfg.TRAIN.DECAY_STEP
    flag_intra = cfg.SETTING.INTRA_LOSS
    flag_inter = cfg.SETTING.INTER_LOSS
    flag_norm = cfg.SETTING.NORM_ALIGN
    optimizer_ = cfg.SETTING.OPTIMIZER

    # Loss weights (w_*) for the individual adaptation terms.
    w_l2 = cfg.PARAMETERS.W_L2
    w_cls = cfg.PARAMETERS.W_CLS
    w_norm = cfg.PARAMETERS.W_NORM
    w_cs = cfg.PARAMETERS.BETA1
    w_ct = cfg.PARAMETERS.BETA2
    w_cst = cfg.PARAMETERS.BETA
    w_mmd = cfg.PARAMETERS.BETA_MMD
    w_inter = cfg.PARAMETERS.BETA_INTER
    w_intra = cfg.PARAMETERS.BETA_INTRA
    thr_m = cfg.PARAMETERS.THR_M
    thrs_ = cfg.PARAMETERS.THRS
    # Only referenced by the commented-out entropy regularizer below.
    entropy_w = 0.001

    # Number of recent checkpoints (and the stride between them) used as a
    # pseudo-labelling ensemble in Stage III.
    emsemble_num = cfg.PARAMETERS.EMSEMBLE_NUM
    emsemble_step = cfg.PARAMETERS.EMSEMBLE_STEP

    # Step sizes for the exponential-moving-average updates of the global
    # centers (lr_c) and of the per-class source/target centers.
    lr_c = cfg.PARAMETERS.LR_C
    lr_cs = cfg.PARAMETERS.LR_C_S
    lr_ct = cfg.PARAMETERS.LR_C_T

    # Per-class confidence thresholds for accepting pseudo labels,
    # keyed by class index.
    thrs = {}
    for l in range(len(thrs_)):
        thrs[l] = thrs_[l]

    # Experiment id = config file name without extension.
    exp_id = os.path.basename(cfg_dir).split('.')[0]

    save_path = os.path.join(cfg.SYSTEM.SAVE_PATH, exp_id)
    if not osp.exists(save_path):
        os.makedirs(save_path)

    # Resume from the requested epoch's checkpoint if it exists.
    check_epoch = args.check_epoch
    check_point_dir = osp.join(save_path, '{}.pkl'.format(check_epoch))
    flag_loading = True if osp.exists(check_point_dir) else False

    # Reuse the source data dict when source and target datasets coincide.
    data_dict_source = load_dataset_to_memory(source)
    data_dict_target = load_dataset_to_memory(target) if (
        source != target) else data_dict_source

    transform = augmentation_transform_with_rr if cfg.SETTING.AUGMENTATION else None

    source_records = build_training_records(source)
    target_records = build_validation_records(target)

    # Paired source/target training dataset used for both pre-training and
    # the adaptation stage.
    dataset = UDA_DATASET(source,
                          target,
                          data_dict_source,
                          data_dict_target,
                          source_records,
                          target_records,
                          cfg.SETTING.UDA_NUM,
                          load_beat_with_rr,
                          transform=transform,
                          beat_num=cfg.SETTING.BEAT_NUM,
                          fixed_len=cfg.SETTING.FIXED_LEN,
                          lead=cfg.SETTING.LEAD)

    # Target-domain validation set/loader used by the periodic evaluations.
    dset_val = MULTI_ECG_EVAL_DATASET(target,
                                      load_beat_with_rr,
                                      data_dict_target,
                                      test_records=target_records,
                                      beat_num=cfg.SETTING.BEAT_NUM,
                                      fixed_len=cfg.SETTING.FIXED_LEN,
                                      lead=cfg.SETTING.LEAD,
                                      unlabel_num=0)
    dloader_val = DataLoader(dset_val,
                             batch_size=batch_size,
                             num_workers=cfg.SYSTEM.NUM_WORKERS)

    # Optionally rebalance classes with a weighted sampler instead of
    # plain shuffling.
    if cfg.TRAIN.IMBALANCE_SAMPLE:
        dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                num_workers=cfg.SYSTEM.NUM_WORKERS,
                                sampler=UDAImbalancedDatasetSampler(dataset))
    else:
        dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                num_workers=cfg.SYSTEM.NUM_WORKERS,
                                shuffle=True)

    # Batches per epoch; used to detect the last batch for checkpointing.
    iter_num = int(len(dataset) / batch_size)

    net = build_acnn_models(cfg.SETTING.NETWORK,
                            cfg.SETTING.ASPP_BN,
                            cfg.SETTING.ASPP_ACT,
                            cfg.SETTING.LEAD,
                            cfg.PARAMETERS.P,
                            cfg.SETTING.DILATIONS,
                            act_func=cfg.SETTING.ACT,
                            f_act_func=cfg.SETTING.F_ACT,
                            apply_residual=cfg.SETTING.RESIDUAL,
                            bank_num=cfg.SETTING.BANK_NUM)
    # Initialization of the model
    net.apply(init_weights)

    # NOTE(review): teacher_net is built but never moved to CUDA or loaded
    # with weights here — presumably obtain_pseudo_labels() loads the
    # ensemble state dicts into it internally. TODO confirm.
    teacher_net = build_acnn_models(cfg.SETTING.NETWORK,
                                    cfg.SETTING.ASPP_BN,
                                    cfg.SETTING.ASPP_ACT,
                                    cfg.SETTING.LEAD,
                                    cfg.PARAMETERS.P,
                                    cfg.SETTING.DILATIONS,
                                    act_func=cfg.SETTING.ACT,
                                    f_act_func=cfg.SETTING.F_ACT,
                                    apply_residual=cfg.SETTING.RESIDUAL,
                                    bank_num=cfg.SETTING.BANK_NUM)

    print("The network {} has {} parameters in total".format(
        cfg.SETTING.NETWORK, sum(x.numel() for x in net.parameters())))

    if flag_loading:
        net.load_state_dict(torch.load(check_point_dir)['model_state_dict'])
        print("The saved model is loaded.")
    net = net.cuda()

    # Classification loss (4 classes) and the feature-distance used by all
    # center-alignment terms.
    criterion_cls_4 = loss_function(cfg.SETTING.LOSS,
                                    dataset=source,
                                    num_ew=cfg.PARAMETERS.N,
                                    T=cfg.PARAMETERS.T)
    criterion_dist = build_distance(cfg.SETTING.DISTANCE)

    # Separate optimizer/scheduler pairs for pre-training and adaptation;
    # the adaptation stage runs at 10x lower LR with a 10x longer decay.
    # NOTE(review): schedulers are stepped once per batch below, so
    # decay_step effectively counts iterations, not epochs — confirm
    # that this is intended.
    optimizer_pre = get_optimizer(optimizer_, net.parameters(), lr, w_l2)
    scheduler_pre = optim.lr_scheduler.StepLR(optimizer_pre,
                                              step_size=decay_step,
                                              gamma=decay_rate)

    optimizer_main = get_optimizer(optimizer_, net.parameters(), lr * 0.1,
                                   w_l2)
    scheduler_main = optim.lr_scheduler.StepLR(optimizer_main,
                                               step_size=decay_step * 10,
                                               gamma=decay_rate)
    evaluator = Eval(num_class=4)
    '''Initial and register the EMA'''
    ema = EMA(model=net, decay=0.99)
    ema.register()

    # ---------------- STAGE I: supervised pre-training on source ---------
    if check_epoch < pre_train_epochs - 1:
        print("Starting STAGE I: pre-training the model using source data")

        best_f1_s = 0.0

        for epoch in range(max(0, check_epoch), pre_train_epochs):
            for idb, data_batch in enumerate(dataloader):
                net.train()

                # Target batch is loaded but unused in this stage.
                s_batch, sl_batch, t_batch, tl_batch = data_batch
                s_batch = s_batch.cuda()
                sl_batch = sl_batch.cuda()
                t_batch = t_batch.cuda()
                tl_batch = tl_batch.cuda()

                _, preds = net(s_batch)
                loss = criterion_cls_4(preds, sl_batch)

                # Add an entropy regularizer
                # p_softmax = nn.Softmax(dim=1)(preds)
                # loss -= get_entropy_loss(p_softmax, entropy_w)

                optimizer_pre.zero_grad()
                loss.backward()
                optimizer_pre.step()
                scheduler_pre.step()
                if args.use_ema:
                    ema.update()
                    ema.apply_shadow()

                running_lr = optimizer_pre.state_dict(
                )['param_groups'][0]['lr']

                print("[{}, {}] cls loss: {:.4f}, lr: {:.4f}".format(
                    epoch, idb, loss, running_lr),
                      end='\r')
                # Save a checkpoint on the last batch of every epoch.
                if idb == iter_num - 1:
                    torch.save({"model_state_dict": net.state_dict()},
                               osp.join(save_path, '{}.pkl'.format(epoch)))

            # Validate on the target domain every 10 epochs.
            if epoch % 10 == 9:
                net.eval()
                preds_entire = []
                labels_entire = []

                with torch.no_grad():
                    for idb, data_batch in enumerate(dloader_val):
                        s_batch, l_batch = data_batch

                        s_batch = s_batch.cuda()
                        l_batch = l_batch.numpy()

                        _, logits = net(s_batch)

                        # log_softmax(...).exp() == softmax, computed in a
                        # numerically stable way.
                        preds_softmax = F.log_softmax(logits, dim=1).exp()
                        preds_softmax_np = preds_softmax.detach().cpu().numpy()
                        preds = np.argmax(preds_softmax_np, axis=1)

                        preds_entire.append(preds)
                        labels_entire.append(l_batch)

                        torch.cuda.empty_cache()

                preds_entire = np.concatenate(preds_entire, axis=0)
                labels_entire = np.concatenate(labels_entire, axis=0)

                Pp, Se = evaluator._sklean_metrics(y_pred=preds_entire,
                                                   y_label=labels_entire)
                results = evaluator._metrics(predictions=preds_entire,
                                             labels=labels_entire)
                # con_matrix = evaluator._confusion_matrix(y_pred=preds_entire,
                #                                          y_label=labels_entire)

                f1_scores = evaluator._f1_score(y_pred=preds_entire,
                                                y_true=labels_entire)

                print('The overall accuracy is: {}'.format(results['Acc']))
                print("The confusion matrix is: ")
                print('Pp: ')
                pprint.pprint(Pp)
                print('Se: ')
                pprint.pprint(Se)
                print('The F1 score is: {}'.format(f1_scores))

                # Model selection is driven by the F1 of class index 2 —
                # presumably the clinically important beat class; confirm.
                if f1_scores[2] >= best_f1_s:
                    best_f1_s = f1_scores[2]
                    torch.save({"model_state_dict": net.state_dict()},
                               osp.join(save_path, 'best_model.pkl'))

                torch.cuda.empty_cache()

    # ------------- STAGE II: estimate feature centers (cached) -----------
    print('Start obtaining centers of each cluster and distribution')

    best_model_dir = osp.join(save_path, 'best_model.pkl')
    net.load_state_dict(torch.load(best_model_dir)['model_state_dict'])

    # Per-class centers for both domains, cached as .mat files so repeated
    # runs skip the expensive forward passes.
    centers_source_dir = osp.join(save_path, "centers_source.mat")
    centers_target_dir = osp.join(save_path, "centers_target.mat")
    flag_centers = osp.exists(centers_source_dir) and osp.exists(
        centers_target_dir)

    # Global (all-class) feature centers for each domain, also cached.
    center_source_dir = osp.join(save_path, "center_s.mat")
    center_target_dir = osp.join(save_path, "center_t.mat")
    flag_center = osp.exists(center_source_dir) and osp.exists(
        center_target_dir)

    if flag_centers:
        # Load cached per-class centers; classes absent from the cache
        # (no confident samples) are simply missing from the dicts.
        centers_s_ = sio.loadmat(centers_source_dir)
        centers_t_ = sio.loadmat(centers_target_dir)
        centers_s = {}
        centers_t = {}
        for l in range(4):
            if 'c{}'.format(l) in centers_s_.keys():
                centers_s[l] = torch.from_numpy(
                    centers_s_['c{}'.format(l)].squeeze()).cuda()
            if 'c{}'.format(l) in centers_t_.keys():
                centers_t[l] = torch.from_numpy(
                    centers_t_['c{}'.format(l)].squeeze()).cuda()
    else:
        # Compute centers with the pre-trained model; target centers use
        # the confidence thresholds to pick reliable samples.
        net.eval()
        centers_s, counter_s = init_source_centers(
            net,
            source,
            source_records,
            data_dict_source,
            batch_size=batch_size,
            num_workers=cfg.SYSTEM.NUM_WORKERS,
            beat_num=cfg.SETTING.BEAT_NUM,
            fixed_len=cfg.SETTING.FIXED_LEN,
            lead=cfg.SETTING.LEAD)
        centers_t, counter_t = init_target_centers(
            net,
            target,
            target_records,
            data_dict_target,
            batch_size=batch_size,
            num_workers=cfg.SYSTEM.NUM_WORKERS,
            beat_num=cfg.SETTING.BEAT_NUM,
            fixed_len=cfg.SETTING.FIXED_LEN,
            lead=cfg.SETTING.LEAD,
            thrs=thrs)
        # Persist as numpy arrays keyed 'c0'..'c3'.
        centers_s_np = {}
        centers_t_np = {}
        for l in range(4):
            if l in centers_s.keys():
                centers_s_np['c{}'.format(
                    l)] = centers_s[l].detach().cpu().numpy()
            if l in centers_t.keys():
                centers_t_np['c{}'.format(
                    l)] = centers_t[l].detach().cpu().numpy()
        sio.savemat(centers_source_dir, centers_s_np)
        sio.savemat(centers_target_dir, centers_t_np)

    if flag_center:
        center_s = torch.from_numpy(
            sio.loadmat(center_source_dir)['c'].squeeze()).cuda()
        center_t = torch.from_numpy(
            sio.loadmat(center_target_dir)['c'].squeeze()).cuda()
    else:

        center_s = init_entire_center(net,
                                      source,
                                      source_records,
                                      data_dict_source,
                                      batch_size=batch_size,
                                      num_workers=cfg.SYSTEM.NUM_WORKERS,
                                      beat_num=cfg.SETTING.BEAT_NUM,
                                      fixed_len=cfg.SETTING.FIXED_LEN,
                                      lead=cfg.SETTING.LEAD)
        center_t = init_entire_center(net,
                                      target,
                                      target_records,
                                      data_dict_target,
                                      batch_size=batch_size,
                                      num_workers=cfg.SYSTEM.NUM_WORKERS,
                                      beat_num=cfg.SETTING.BEAT_NUM,
                                      fixed_len=cfg.SETTING.FIXED_LEN,
                                      lead=cfg.SETTING.LEAD)
        sio.savemat(center_source_dir, {'c': center_s.detach().cpu().numpy()})
        sio.savemat(center_target_dir, {'c': center_t.detach().cpu().numpy()})

    # ---------------- STAGE III: domain adaptation -----------------------
    print("Starting STAGE III: adaptation process")

    # Epoch numbering continues after pre-training (doubled when RE_TRAIN),
    # so checkpoint filenames stay unique across stages.
    low_bound = max(2 * pre_train_epochs,
                    check_epoch) if cfg.SETTING.RE_TRAIN else max(
                        pre_train_epochs, check_epoch)
    high_bound = 2 * pre_train_epochs + epochs if cfg.SETTING.RE_TRAIN else pre_train_epochs + epochs
    for epoch in range(low_bound, high_bound):
        # NOTE(review): best_f1_s is reset to 0.0 every epoch, so each
        # 10-epoch evaluation below overwrites best_model.pkl regardless of
        # previous scores — this initialization likely belongs before the
        # loop. TODO confirm intent.
        best_f1_s = 0.0
        dataset.shuffle_target()

        # load_dir = osp.join(save_path, 'best_model.pkl')
        # loaded_models = torch.load(load_dir)['model_state_dict']
        # Collect the state dicts of the last few checkpoints; they act as
        # an ensemble of teachers for pseudo-labelling.
        loaded_models = []
        for idx in range(emsemble_num):
            load_dir = osp.join(
                save_path, '{}.pkl'.format(epoch - idx * emsemble_step - 1))
            loaded_models.append(torch.load(load_dir)['model_state_dict'])

        for idb, data_batch in enumerate(dataloader):
            net.train()

            s_batch, sl_batch, t_batch, tl_batch = data_batch
            s_batch = s_batch.cuda()
            sl_batch = sl_batch.cuda()
            t_batch = t_batch.cuda()
            tl_batch = tl_batch.cuda()

            feat_s, preds_s = net(s_batch)
            feat_t, preds_t = net(t_batch)

            # Supervised classification loss on the source batch.
            loss_cls = criterion_cls_4(preds_s, sl_batch)
            loss = loss_cls * w_cls

            # Add an entropy regularizer
            # p_softmax = nn.Softmax(dim=1)(preds_s)
            # loss -= get_entropy_loss(p_softmax, entropy_w)

            # EMA update of the global domain centers, then an MMD-style
            # penalty pulling the two domain centers together.
            delta_s = center_s - torch.mean(feat_s, dim=0)
            delta_t = center_t - torch.mean(feat_t, dim=0)

            center_s = center_s - lr_c * delta_s
            center_t = center_t - lr_c * delta_t

            loss_mmd = criterion_dist(center_s, center_t)
            loss += loss_mmd * w_mmd

            loss_intra = 0
            loss_inter = 0
            loss_ct = 0
            loss_cs = 0
            loss_cst = 0

            # Optional feature-norm alignment (soft self-driven or hard
            # fixed-radius variant).
            if flag_norm:
                if cfg.SETTING.ALIGN_SET == 'soft':
                    loss += get_L2norm_loss_self_driven(feat_s, w_norm)
                    loss += get_L2norm_loss_self_driven(feat_t, w_norm)
                else:
                    loss += get_L2norm_loss_self_driven_hard(
                        feat_s, cfg.PARAMETERS.RADIUS, w_norm)
                    loss += get_L2norm_loss_self_driven_hard(
                        feat_t, cfg.PARAMETERS.RADIUS, w_norm)
            '''Obtaining the pesudo labels of target samples'''
            pseudo_label_nums = {0: 0, 1: 0, 2: 0, 3: 0}
            pseudo_labels, legal_indices = obtain_pseudo_labels(
                teacher_net, loaded_models, t_batch, thrs)
            # pseudo_labels: (NUM, ); legal_indices: (NUM, ),the indices of legal pseudo labels;

            tmp_centers_t = {}
            tmp_feats_t = {}
            # if len(pesudo_labels):
            if pseudo_labels.size(0) > 0:
                # feat_t_pesudo = torch.index_select(feat_t, dim=0, index=torch.LongTensor(legal_indices).cuda())
                # Keep only the target features whose pseudo label passed
                # the confidence thresholds.
                feat_t_pseudo = torch.index_select(feat_t,
                                                   dim=0,
                                                   index=legal_indices)

                for l in range(4):
                    # _index = np.argwhere(pseudo_labels == l)
                    _index = torch.nonzero(pseudo_labels == l).squeeze(dim=1)
                    if _index.size(0) > 0:
                        pseudo_label_nums[l] = _index.size(0)
                        # _index = np.squeeze(_index, axis=1)
                        # _feat_t = torch.index_select(feat_t_pesudo, dim=0, index=torch.LongTensor(_index).cuda())
                        _feat_t = torch.index_select(feat_t_pseudo,
                                                     dim=0,
                                                     index=_index)
                        tmp_feats_t[l] = _feat_t
                        bs_ = _feat_t.size(0)

                        # Batch-local class center for this target class.
                        local_centers_tl = torch.mean(_feat_t, dim=0)
                        tmp_centers_t[l] = local_centers_tl

                        # EMA-update the running target class center and
                        # penalize its drift from the batch-local one.
                        if l in centers_t.keys():
                            delta_ct = centers_t[l] - local_centers_tl
                            centers_t[l] = centers_t[l] - lr_ct * delta_ct
                            loss_ct_l = criterion_dist(local_centers_tl,
                                                       centers_t[l])
                            loss_ct += loss_ct_l
                        else:
                            centers_t[l] = local_centers_tl

                        # Intra-class compactness: pull each sample toward
                        # its class center.
                        if flag_intra:
                            m_feat_t = centers_t[l].repeat((bs_, 1))
                            loss_intra_l = criterion_dist(_feat_t,
                                                          m_feat_t,
                                                          dim=1)
                            loss_intra += loss_intra_l

            if cfg.SETTING.CLoss:
                loss += loss_ct * w_ct

            # sl_batch_np = sl_batch.detach().cpu().numpy()
            sl_batch_ = sl_batch.detach()
            true_label_nums = {0: 0, 1: 0, 2: 0, 3: 0}

            tmp_centers_s = {}
            tmp_feats_s = {}

            # Same per-class center update on the source side, using the
            # ground-truth labels instead of pseudo labels.
            for l in range(4):
                # _index = np.argwhere(sl_batch_np == l)
                _index = torch.nonzero(sl_batch_ == l).squeeze(dim=1)
                if _index.size(0) > 0:
                    true_label_nums[l] = _index.size(0)
                    # _feat_s = torch.index_select(feat_s, dim=0, index=torch.LongTensor(_index).cuda())
                    _feat_s = torch.index_select(feat_s, dim=0, index=_index)
                    tmp_feats_s[l] = _feat_s
                    bs_ = _feat_s.size(0)

                    local_centers_sl = torch.mean(_feat_s, dim=0)
                    tmp_centers_s[l] = local_centers_sl
                    delta_cs = centers_s[l] - local_centers_sl
                    centers_s[l] = centers_s[l] - lr_cs * delta_cs

                    loss_cs_l = criterion_dist(local_centers_sl, centers_s[l])
                    loss_cs += loss_cs_l

                    if flag_intra:
                        m_feat_s = centers_s[l].repeat((bs_, 1))
                        loss_intra_l = criterion_dist(_feat_s, m_feat_s, dim=1)
                        loss_intra += loss_intra_l

            if cfg.SETTING.CLoss:
                loss += loss_cs * w_cs

            # Cross-domain class-center alignment (only for classes that
            # have a target center).
            for l in centers_t.keys():
                loss_cst_l = criterion_dist(centers_s[l], centers_t[l])
                loss_cst += loss_cst_l

            if cfg.SETTING.CLoss:
                loss += loss_cst * w_cst

            # Inter-class separation: hinge (margin thr_m) on all pairwise
            # center distances, within and across domains.
            # NOTE(review): this indexes centers_s[i]/centers_t[i] for all
            # 4 classes — raises KeyError if any class center is missing.
            for i in range(4 - 1):
                for j in range(i + 1, 4):
                    loss_inter_ij_s = torch.max(
                        thr_m - criterion_dist(centers_s[i], centers_s[j]),
                        torch.FloatTensor([0]).cuda()).squeeze()
                    loss_inter_ij_t = torch.max(
                        thr_m - criterion_dist(centers_t[i], centers_t[j]),
                        torch.FloatTensor([0]).cuda()).squeeze()
                    '''Add items between two domains'''
                    loss_inter_ij_st = torch.max(
                        thr_m - criterion_dist(centers_s[i], centers_t[j]),
                        torch.FloatTensor([0]).cuda()).squeeze()
                    loss_inter_ij_ts = torch.max(
                        thr_m - criterion_dist(centers_t[i], centers_s[j]),
                        torch.FloatTensor([0]).cuda()).squeeze()

                    loss_inter_ij = (loss_inter_ij_s + loss_inter_ij_t +
                                     loss_inter_ij_st + loss_inter_ij_ts) / 4
                    # loss_inter_ij = (loss_inter_ij_s + loss_inter_ij_t)
                    # loss_inter_ij = loss_inter_ij_s

                    loss_inter += loss_inter_ij

            if flag_inter:
                loss += loss_inter * w_inter
            if flag_intra:
                loss += loss_intra * w_intra

            # Optional per-class CORAL loss between matching source/target
            # class features present in this batch.
            loss_coral = 0
            if cfg.SETTING.CORAL:
                for l in tmp_feats_s.keys():
                    if l in tmp_feats_t.keys():
                        loss_coral += coral(tmp_feats_s[l], tmp_feats_t[l])
                loss += loss_coral

            optimizer_main.zero_grad()
            # retain_graph=True because the EMA-updated centers still hold
            # references into this step's graph until detached below.
            loss.backward(retain_graph=True)
            optimizer_main.step()
            scheduler_main.step()
            if args.use_ema:
                ema.update()
                ema.apply_shadow()

            running_lr = optimizer_main.state_dict()['param_groups'][0]['lr']
            torch.cuda.empty_cache()
            # Detach all centers so the computation graph of this batch can
            # be freed instead of growing across iterations.
            for l in centers_s.keys():
                centers_s[l] = centers_s[l].detach()
            for l in centers_t.keys():
                centers_t[l] = centers_t[l].detach()
            center_s = center_s.detach()
            center_t = center_t.detach()

            # Log and checkpoint on the last batch of the epoch.
            if idb == iter_num - 1:
                print("[{}, {}] cls loss: {:.4f}, cs loss: {:.4f}, "
                      "ct loss: {:.4f}, cst loss: {:.4f}, mmd loss: {:.4f}, "
                      "inter loss: {:.4f}, intra loss: {:.4f}, "
                      "CORAL: {:.4f}, "
                      "lr: {:.5f}".format(epoch, idb, loss_cls, loss_cs,
                                          loss_ct, loss_cst, loss_mmd,
                                          loss_inter, loss_intra, loss_coral,
                                          running_lr))

                print("The number of pesudo labels and true labels:")
                pprint.pprint(pseudo_label_nums)
                pprint.pprint(true_label_nums)

                torch.save({'model_state_dict': net.state_dict()},
                           osp.join(save_path, '{}.pkl'.format(epoch)))

        # Target-domain validation every 10 epochs.
        if epoch % 10 == 9:
            net.eval()

            preds_entire = []
            labels_entire = []

            with torch.no_grad():
                for idb, data_batch in enumerate(dloader_val):
                    s_batch, l_batch = data_batch

                    s_batch = s_batch.cuda()
                    l_batch = l_batch.numpy()

                    _, logits = net(s_batch)

                    preds_softmax = F.log_softmax(logits, dim=1).exp()
                    preds_softmax_np = preds_softmax.detach().cpu().numpy()
                    preds = np.argmax(preds_softmax_np, axis=1)

                    preds_entire.append(preds)
                    labels_entire.append(l_batch)

                    torch.cuda.empty_cache()

            preds_entire = np.concatenate(preds_entire, axis=0)
            labels_entire = np.concatenate(labels_entire, axis=0)

            Pp, Se = evaluator._sklean_metrics(y_pred=preds_entire,
                                               y_label=labels_entire)
            results = evaluator._metrics(predictions=preds_entire,
                                         labels=labels_entire)
            f1_scores = evaluator._f1_score(y_pred=preds_entire,
                                            y_true=labels_entire)
            con_matrix = evaluator._confusion_matrix(preds_entire,
                                                     labels_entire)

            print('The overall accuracy is: {}'.format(results['Acc']))
            print("The confusion matrix is: ")
            print(con_matrix)
            # print('The sklearn metrics are: ')
            print('Pp: ')
            pprint.pprint(Pp)
            print('Se: ')
            pprint.pprint(Se)
            print('The F1 score is: {}'.format(f1_scores))

            if f1_scores[2] >= best_f1_s:
                best_f1_s = f1_scores[2]
                torch.save({"model_state_dict": net.state_dict()},
                           osp.join(save_path, 'best_model.pkl'))

        # Updating thresholds for pesudo labels
        if cfg.SETTING.INCRE_THRS:
            if cfg.SETTING.RE_TRAIN:
                epoch_ = epoch - 2 * pre_train_epochs
            else:
                epoch_ = epoch - pre_train_epochs
            thrs = update_thrs(thrs, epoch_, epochs)
Пример #7
0
        epoch_result = {
            'epoch': epoch,
            'train_result': train_result,
            'val_result': val_result
        }
        pickle.dump(epoch_result, open(os.path.join(results_dir, 'result_epoch_{0}.p'.format(epoch)), 'wb'))

        # output_path_base = os.path.basename(output_path)
        # os.system('aws s3 sync /root/bengali_data/{0} s3://eaitest1/{1}'.format(output_path_base, output_path_base))
        # os.system('rm -r /root/bengali_data/{0}/model_backups'.format(output_path_base))
        # os.system('mkdir /root/bengali_data/{0}/model_backups'.format(output_path_base))

        
        # Add model to Tensorboard to inspect the details of the architecture
        writer_tensorboard.add_graph(model, input_data)
        writer_tensorboard.close()


if __name__ == '__main__':

    # Parse CLI arguments; docopt uses the module docstring as the spec.
    arguments = docopt(__doc__, argv=None, help=True, version=None, options_first=False)
    output_path = arguments['-path_output']
    data_path = arguments['--path_cfg_data']
    cfg_path = arguments['--path_cfg_override']
    cfg = get_cfg_defaults()
    # Merge the override config only when one was supplied. The original
    # code called merge_from_file(cfg_path) unconditionally before this
    # check, which merged the file twice when a path was given and crashed
    # on a None path when it was not.
    if cfg_path is not None:
        cfg.merge_from_file(cfg_path)
    cfg.OUTPUT_PATH = output_path
    train(cfg)
Пример #8
0
def train(args):
    """Stage-I pre-training of an ACNN model on the source ECG dataset.

    Loads a YACS config, builds the joint source/target UDA dataset and a
    target-only validation loader, constructs student and teacher ACNN
    networks, then runs supervised pre-training on source batches.  Every
    10 epochs it evaluates on the target validation set and checkpoints
    the best model by F1 score.

    Args:
        args: parsed CLI arguments; must provide ``config`` (path to a
            YACS config file), ``check_epoch`` (epoch checkpoint to resume
            from) and ``use_ema`` (whether to maintain/apply EMA weights).
    """

    cudnn.benchmark = True
    # NOTE(review): these two lines override the benchmark=True above —
    # presumably reproducibility is meant to win; confirm intent.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Load and freeze the experiment configuration.
    cfg_dir = args.config
    cfg = get_cfg_defaults()
    cfg.merge_from_file(cfg_dir)
    cfg.freeze()
    '''Setting the random seed used in the experiment'''

    # SEED == -1 means "do not seed" (non-deterministic run).
    if cfg.SETTING.SEED != -1:
        torch.manual_seed(cfg.SETTING.SEED)
        torch.cuda.manual_seed(cfg.SETTING.SEED)

    source = cfg.SETTING.TRAIN_DATASET
    target = cfg.SETTING.TEST_DATASET

    # Unpack training hyper-parameters from the config.
    batch_size = cfg.TRAIN.BATCH_SIZE
    pre_train_epochs = cfg.TRAIN.PRE_TRAIN_EPOCHS
    epochs = cfg.TRAIN.EPOCHS
    lr = cfg.TRAIN.LR
    decay_rate = cfg.TRAIN.DECAY_RATE
    decay_step = cfg.TRAIN.DECAY_STEP
    flag_intra = cfg.SETTING.INTRA_LOSS
    flag_inter = cfg.SETTING.INTER_LOSS
    flag_norm = cfg.SETTING.NORM_ALIGN
    optimizer_ = cfg.SETTING.OPTIMIZER

    # Loss-term weights; most are used only in later adaptation stages
    # that are not part of this function's visible body.
    w_l2 = cfg.PARAMETERS.W_L2
    w_cls = cfg.PARAMETERS.W_CLS
    w_norm = cfg.PARAMETERS.W_NORM
    w_cs = cfg.PARAMETERS.BETA1
    w_ct = cfg.PARAMETERS.BETA2
    w_cst = cfg.PARAMETERS.BETA
    w_mmd = cfg.PARAMETERS.BETA_MMD
    w_inter = cfg.PARAMETERS.BETA_INTER
    w_intra = cfg.PARAMETERS.BETA_INTRA
    thr_m = cfg.PARAMETERS.THR_M
    thrs_ = cfg.PARAMETERS.THRS
    entropy_w = 0.001  # weight for the (currently commented-out) entropy regularizer

    emsemble_num = cfg.PARAMETERS.EMSEMBLE_NUM
    emsemble_step = cfg.PARAMETERS.EMSEMBLE_STEP

    lr_c = cfg.PARAMETERS.LR_C
    lr_cs = cfg.PARAMETERS.LR_C_S
    lr_ct = cfg.PARAMETERS.LR_C_T

    # Per-class pseudo-label confidence thresholds: class index -> threshold.
    thrs = {}
    for l in range(len(thrs_)):
        thrs[l] = thrs_[l]

    # Experiment id is the config filename without extension; checkpoints
    # and results are stored under SAVE_PATH/<exp_id>/.
    exp_id = os.path.basename(cfg_dir).split('.')[0]

    save_path = os.path.join(cfg.SYSTEM.SAVE_PATH, exp_id)
    if not osp.exists(save_path):
        os.makedirs(save_path)

    # Resume from "<check_epoch>.pkl" if that checkpoint already exists.
    check_epoch = args.check_epoch
    check_point_dir = osp.join(save_path, '{}.pkl'.format(check_epoch))
    flag_loading = True if osp.exists(check_point_dir) else False

    data_dict_source = load_dataset_to_memory(source)
    # Reuse the source data dict when source and target datasets coincide.
    data_dict_target = load_dataset_to_memory(target) if (
        source != target) else data_dict_source

    transform = augmentation_transform_with_rr if cfg.SETTING.AUGMENTATION else None

    source_records = build_training_records(source)
    target_records = build_validation_records(target)

    # Joint dataset yielding paired (source, target) beats for UDA training.
    dataset = UDA_DATASET(source,
                          target,
                          data_dict_source,
                          data_dict_target,
                          source_records,
                          target_records,
                          cfg.SETTING.UDA_NUM,
                          load_beat_with_rr,
                          transform=transform,
                          beat_num=cfg.SETTING.BEAT_NUM,
                          fixed_len=cfg.SETTING.FIXED_LEN,
                          lead=cfg.SETTING.LEAD)

    # Target-only validation set (unlabel_num=0: no unlabeled samples held out).
    dset_val = MULTI_ECG_EVAL_DATASET(target,
                                      load_beat_with_rr,
                                      data_dict_target,
                                      test_records=target_records,
                                      beat_num=cfg.SETTING.BEAT_NUM,
                                      fixed_len=cfg.SETTING.FIXED_LEN,
                                      lead=cfg.SETTING.LEAD,
                                      unlabel_num=0)
    dloader_val = DataLoader(dset_val,
                             batch_size=batch_size,
                             num_workers=cfg.SYSTEM.NUM_WORKERS)

    # Class-imbalanced ECG data: optionally rebalance with a weighted sampler.
    if cfg.TRAIN.IMBALANCE_SAMPLE:
        dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                num_workers=cfg.SYSTEM.NUM_WORKERS,
                                sampler=UDAImbalancedDatasetSampler(dataset))
    else:
        dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                num_workers=cfg.SYSTEM.NUM_WORKERS,
                                shuffle=True)

    # Number of full batches per epoch (drop-last semantics for the
    # checkpointing condition below).
    iter_num = int(len(dataset) / batch_size)

    # Student network.
    net = build_acnn_models(cfg.SETTING.NETWORK,
                            cfg.SETTING.ASPP_BN,
                            cfg.SETTING.ASPP_ACT,
                            cfg.SETTING.LEAD,
                            cfg.PARAMETERS.P,
                            cfg.SETTING.DILATIONS,
                            act_func=cfg.SETTING.ACT,
                            f_act_func=cfg.SETTING.F_ACT,
                            apply_residual=cfg.SETTING.RESIDUAL,
                            bank_num=cfg.SETTING.BANK_NUM)
    # Initialization of the model
    net.apply(init_weights)

    # Teacher network for the mean-teacher/pseudo-label scheme.
    # NOTE(review): teacher_net is not used in this visible stage — it is
    # presumably consumed by a later adaptation stage; confirm.
    teacher_net = build_acnn_models(cfg.SETTING.NETWORK,
                                    cfg.SETTING.ASPP_BN,
                                    cfg.SETTING.ASPP_ACT,
                                    cfg.SETTING.LEAD,
                                    cfg.PARAMETERS.P,
                                    cfg.SETTING.DILATIONS,
                                    act_func=cfg.SETTING.ACT,
                                    f_act_func=cfg.SETTING.F_ACT,
                                    apply_residual=cfg.SETTING.RESIDUAL,
                                    bank_num=cfg.SETTING.BANK_NUM)

    print("The network {} has {} parameters in total".format(
        cfg.SETTING.NETWORK, sum(x.numel() for x in net.parameters())))

    if flag_loading:
        net.load_state_dict(torch.load(check_point_dir)['model_state_dict'])
        print("The saved model is loaded.")
    net = net.cuda()

    # 4-class classification loss and the feature-distance criterion.
    criterion_cls_4 = loss_function(cfg.SETTING.LOSS,
                                    dataset=source,
                                    num_ew=cfg.PARAMETERS.N,
                                    T=cfg.PARAMETERS.T)
    criterion_dist = build_distance(cfg.SETTING.DISTANCE)

    # Pre-training optimizer/scheduler; note scheduler_pre.step() is called
    # per *batch* below, so decay_step counts iterations, not epochs.
    optimizer_pre = get_optimizer(optimizer_, net.parameters(), lr, w_l2)
    scheduler_pre = optim.lr_scheduler.StepLR(optimizer_pre,
                                              step_size=decay_step,
                                              gamma=decay_rate)

    # Main-stage optimizer (10x lower LR, 10x longer decay step); not used
    # in this visible stage.
    optimizer_main = get_optimizer(optimizer_, net.parameters(), lr * 0.1,
                                   w_l2)
    scheduler_main = optim.lr_scheduler.StepLR(optimizer_main,
                                               step_size=decay_step * 10,
                                               gamma=decay_rate)
    evaluator = Eval(num_class=4)
    '''Initial and register the EMA'''
    ema = EMA(model=net, decay=0.99)
    ema.register()

    # Stage I runs only when the resume point lies before its end.
    if check_epoch < pre_train_epochs - 1:
        print("Starting STAGE I: pre-training the model using source data")

        best_f1_s = 0.0

        for epoch in range(max(0, check_epoch), pre_train_epochs):
            for idb, data_batch in enumerate(dataloader):
                net.train()

                # Stage I uses only the labeled source half of the batch;
                # the target half is moved to GPU but not consumed here.
                s_batch, sl_batch, t_batch, tl_batch = data_batch
                s_batch = s_batch.cuda()
                sl_batch = sl_batch.cuda()
                t_batch = t_batch.cuda()
                tl_batch = tl_batch.cuda()

                _, preds = net(s_batch)
                loss = criterion_cls_4(preds, sl_batch)

                # Add an entropy regularizer
                # p_softmax = nn.Softmax(dim=1)(preds)
                # loss -= get_entropy_loss(p_softmax, entropy_w)

                optimizer_pre.zero_grad()
                loss.backward()
                optimizer_pre.step()
                scheduler_pre.step()
                if args.use_ema:
                    # NOTE(review): apply_shadow() swaps EMA weights into the
                    # live model every step, so subsequent updates and saved
                    # checkpoints use EMA weights — confirm this is intended.
                    ema.update()
                    ema.apply_shadow()

                running_lr = optimizer_pre.state_dict(
                )['param_groups'][0]['lr']

                print("[{}, {}] cls loss: {:.4f}, lr: {:.4f}".format(
                    epoch, idb, loss, running_lr),
                      end='\r')
                # Checkpoint once per epoch, on the last full batch.
                if idb == iter_num - 1:
                    torch.save({"model_state_dict": net.state_dict()},
                               osp.join(save_path, '{}.pkl'.format(epoch)))

            # Evaluate on the target validation set every 10 epochs.
            if epoch % 10 == 9:
                net.eval()
                preds_entire = []
                labels_entire = []

                with torch.no_grad():
                    for idb, data_batch in enumerate(dloader_val):
                        s_batch, l_batch = data_batch

                        s_batch = s_batch.cuda()
                        l_batch = l_batch.numpy()

                        _, logits = net(s_batch)

                        # log_softmax(...).exp() == softmax, computed stably.
                        preds_softmax = F.log_softmax(logits, dim=1).exp()
                        preds_softmax_np = preds_softmax.detach().cpu().numpy()
                        preds = np.argmax(preds_softmax_np, axis=1)

                        preds_entire.append(preds)
                        labels_entire.append(l_batch)

                        torch.cuda.empty_cache()

                preds_entire = np.concatenate(preds_entire, axis=0)
                labels_entire = np.concatenate(labels_entire, axis=0)

                # Pp: positive predictivity, Se: sensitivity (per class).
                Pp, Se = evaluator._sklean_metrics(y_pred=preds_entire,
                                                   y_label=labels_entire)
                results = evaluator._metrics(predictions=preds_entire,
                                             labels=labels_entire)
                # con_matrix = evaluator._confusion_matrix(y_pred=preds_entire,
                #                                          y_label=labels_entire)

                f1_scores = evaluator._f1_score(y_pred=preds_entire,
                                                y_true=labels_entire)

                print('The overall accuracy is: {}'.format(results['Acc']))
                print("The confusion matrix is: ")
                print('Pp: ')
                pprint.pprint(Pp)
                print('Se: ')
                pprint.pprint(Se)
                print('The F1 score is: {}'.format(f1_scores))

                # Model selection on the F1 of class index 2 — presumably the
                # clinically important arrhythmia class; confirm class order.
                if f1_scores[2] >= best_f1_s:
                    best_f1_s = f1_scores[2]
                    torch.save({"model_state_dict": net.state_dict()},
                               osp.join(save_path, 'best_model.pkl'))

                torch.cuda.empty_cache()
# ---- Example #9 (scraped separator; original marker: "Пример #9", votes: 0) ----
def train(args):

    cudnn.benchmark = True

    cfg_dir = args.config
    cfg = get_cfg_defaults()
    cfg.merge_from_file(cfg_dir)
    cfg.freeze()
    print(cfg)

    source = cfg.SETTING.TRAIN_DATASET
    target = cfg.SETTING.TEST_DATASET

    batch_size = cfg.TRAIN.BATCH_SIZE
    pre_train_epochs = cfg.TRAIN.PRE_TRAIN_EPOCHS
    epochs = cfg.TRAIN.EPOCHS
    lr = cfg.TRAIN.LR
    decay_rate = cfg.TRAIN.DECAY_RATE
    decay_step = cfg.TRAIN.DECAY_STEP
    flag_c = cfg.SETTING.CENTER
    flag_intra = cfg.SETTING.INTRA_LOSS
    flag_inter = cfg.SETTING.INTER_LOSS
    flag_norm = cfg.SETTING.NORM_ALIGN
    optimizer_ = cfg.SETTING.OPTIMIZER

    w_l2 = cfg.PARAMETERS.W_L2
    w_cls = cfg.PARAMETERS.W_CLS
    w_norm = cfg.PARAMETERS.W_NORM
    w_c = cfg.PARAMETERS.BETA_C
    w_cs = cfg.PARAMETERS.BETA1
    w_ct = cfg.PARAMETERS.BETA2
    w_cst = cfg.PARAMETERS.BETA
    w_bin = cfg.PARAMETERS.W_BIN
    w_mmd = cfg.PARAMETERS.BETA_MMD
    w_inter = cfg.PARAMETERS.BETA_INTER
    w_intra = cfg.PARAMETERS.BETA_INTRA
    thr_m = cfg.PARAMETERS.THR_M
    thrs_ = cfg.PARAMETERS.THRS

    emsemble_num = cfg.PARAMETERS.EMSEMBLE_NUM
    emsemble_step = cfg.PARAMETERS.EMSEMBLE_STEP

    lr_c = cfg.PARAMETERS.LR_C
    lr_cs = cfg.PARAMETERS.LR_C_S
    lr_ct = cfg.PARAMETERS.LR_C_T
    beta_cb = cfg.PARAMETERS.BETA_CB

    weights = get_cb_weights(source, beta_cb)

    thrs = {}
    for l in range(len(thrs_)):
        thrs[l] = thrs_[l]

    exp_id = os.path.basename(cfg_dir).split('.')[0]
    save_path = os.path.join(cfg.SYSTEM.SAVE_PATH, exp_id)
    if not osp.exists(save_path):
        os.makedirs(save_path)

    check_epoch = args.check_epoch
    check_point_dir = osp.join(save_path, '{}.pkl'.format(check_epoch))
    flag_loading = True if osp.exists(check_point_dir) else False

    data_dict_source = load_dataset_to_memory(source)
    data_dict_target = load_dataset_to_memory(target) if (source != target) else data_dict_source

    transform = augmentation_transform_with_rr if cfg.SETTING.AUGMENTATION else None

    source_records = build_training_records(source)
    target_records = build_validation_records(target)

    dataset = UDA_DATASET(source, target,
                          data_dict_source, data_dict_target,
                          source_records, target_records,
                          cfg.SETTING.UDA_NUM,
                          load_beat_with_rr,
                          transform=transform,
                          beat_num=cfg.SETTING.BEAT_NUM,
                          fixed_len=cfg.SETTING.FIXED_LEN,
                          use_dbscan=cfg.SETTING.USE_DBSCAN)

    dataloader = DataLoader(dataset, batch_size=batch_size,
                            num_workers=cfg.SYSTEM.NUM_WORKERS,
                            sampler=UDAImbalancedDatasetSampler(dataset))

    iter_num = int(len(dataset) / batch_size)

    net = build_cnn_models(cfg.SETTING.NETWORK,
                           fixed_len=cfg.SETTING.FIXED_LEN,
                           p=cfg.PARAMETERS.P)
    # Initialization of the model
    net.apply(init_weights)

    teacher_net = build_cnn_models(cfg.SETTING.NETWORK,
                                   fixed_len=cfg.SETTING.FIXED_LEN,
                                   p=cfg.PARAMETERS.P)

    print("The network {} has {} parameters in total".format(cfg.SETTING.NETWORK,
                                                             sum(x.numel() for x in net.parameters())))

    if flag_loading:
        net.load_state_dict(torch.load(check_point_dir)['model_state_dict'])
    net = net.cuda()

    criterion_cls_4 = loss_function(cfg.SETTING.LOSS, dataset=source, num_ew=cfg.PARAMETERS.N)
    criterion_cls_2 = loss_function('BinCBLoss', dataset=source)
    criterion_dist = build_distance(cfg.SETTING.DISTANCE)

    optimizer_pre = get_optimizer(optimizer_, net.parameters(), lr, w_l2)
    scheduler_pre = optim.lr_scheduler.StepLR(optimizer_pre,
                                              step_size=decay_step,
                                              gamma=decay_rate)

    optimizer_re = get_optimizer(optimizer_, net.parameters(), lr * 0.1, w_l2)
    scheduler_re = optim.lr_scheduler.StepLR(optimizer_re,
                                             step_size=decay_step * 10,
                                             gamma=decay_rate)

    optimizer_main = get_optimizer(optimizer_, net.parameters(), lr * 0.1, w_l2)
    scheduler_main = optim.lr_scheduler.StepLR(optimizer_main,
                                               step_size=decay_step * 10,
                                               gamma=decay_rate)
    evaluator = Eval(num_class=4)

    if check_epoch <= pre_train_epochs - 1:
        print("Starting STAGE I: pre-training the model using source data")

        for epoch in range(max(0, check_epoch), pre_train_epochs):
            for idb, data_batch in enumerate(dataloader):
                net.train()

                s_batch, sl_batch, sr_batch, sb_batch, \
                t_batch, tl_batch, tr_batch, tb_batch = data_batch

                s_batch = s_batch.unsqueeze(dim=1)
                t_batch = t_batch.unsqueeze(dim=1)
                s_batch = s_batch.cuda()
                sl_batch = sl_batch.cuda()
                t_batch = t_batch.cuda()
                tl_batch = tl_batch.cuda()
                sr_batch = sr_batch.cuda()
                sb_batch = sb_batch.cuda()
                tr_batch = tr_batch.cuda()
                tb_batch = tb_batch.cuda()

                _, pef, _, preds = net(s_batch, sr_batch)

                cls_loss = criterion_cls_4(preds, sl_batch)
                bin_loss = criterion_cls_2(pef, sb_batch)
                loss = cls_loss * w_cls + bin_loss * w_bin

                optimizer_pre.zero_grad()
                loss.backward()
                optimizer_pre.step()
                scheduler_pre.step()

                running_lr = optimizer_pre.state_dict()['param_groups'][0]['lr']

                if idb % 10 == 9:
                    print("[{}, {}] cls loss: {:.4f}, lr: {:.4f}".format(
                        epoch, idb, cls_loss, running_lr
                    ))

                    torch.save({"model_state_dict": net.state_dict()},
                               osp.join(save_path, '{}.pkl'.format(epoch)))

                if idb == iter_num - 1:
                    net.eval()
                    _, _, _, preds = net(t_batch, tr_batch)
                    preds_softmax = F.log_softmax(preds, dim=1).exp()
                    preds_softmax_np = preds_softmax.detach().cpu().numpy()
                    preds_ = np.argmax(preds_softmax_np, axis=1)

                    loss_eval = criterion_cls_4(preds, tl_batch)
                    print("The loss on target mini-batch is {:.4f}".format(loss_eval))
                    results = evaluator._metrics(predictions=preds_,
                                                 labels=tl_batch.detach().cpu().numpy())
                    pprint.pprint(results)

                torch.cuda.empty_cache()

    net.eval()
    centers_s, _ = init_source_centers(net, source, source_records, data_dict_source,
                                       batch_size=batch_size, num_workers=cfg.SYSTEM.NUM_WORKERS,
                                       beat_num=cfg.SETTING.BEAT_NUM, fixed_len=cfg.SETTING.FIXED_LEN)

    if cfg.SETTING.RE_TRAIN and (check_epoch <= pre_train_epochs * 2 - 1):
        print("Starting STAGE II: re-training the model using source data and extra constraints")

        for epoch in range(max(pre_train_epochs, check_epoch), 2 * pre_train_epochs):
            for idb, data_batch in enumerate(dataloader):
                net.train()

                s_batch, sl_batch, sr_batch, sb_batch,\
                t_batch, tl_batch, tr_batch, tb_batch = data_batch

                s_batch = s_batch.unsqueeze(dim=1)
                t_batch = t_batch.unsqueeze(dim=1)
                s_batch = s_batch.cuda()
                sl_batch = sl_batch.cuda()
                t_batch = t_batch.cuda()
                tl_batch = tl_batch.cuda()
                sr_batch = sr_batch.cuda()
                sb_batch = sb_batch.cuda()
                tr_batch = tr_batch.cuda()
                tb_batch = tb_batch.cuda()

                _, pef_s, feat_s, preds = net(s_batch, sr_batch)

                loss_cls = criterion_cls_4(preds, sl_batch)
                loss_bin = criterion_cls_2(pef_s, sb_batch)
                loss = loss_cls * w_cls + loss_bin * w_bin

                loss_cs = 0
                loss_intra = 0

                sl_batch_np = sl_batch.detach().cpu().numpy()

                for l in range(4):
                    _index = np.argwhere(sl_batch_np == l)
                    if len(_index):
                        _index = np.squeeze(_index, axis=1)
                        _feat_s = torch.index_select(feat_s, dim=0, index=torch.LongTensor(_index).cuda())
                        bs_ = _feat_s.size()[0]
                        m_feat_s = torch.mean(_feat_s, dim=0)

                        delta_cs_l = centers_s[l] - m_feat_s
                        centers_s[l] = centers_s[l] - lr_cs * delta_cs_l

                        loss_cs_l = criterion_dist(m_feat_s, centers_s[l])
                        loss_cs += loss_cs_l

                        if flag_intra:
                            cl_feat_s = centers_s[l].repeat((bs_, 1))
                            loss_intra_l = criterion_dist(_feat_s, cl_feat_s, dim=1) / bs_
                            loss_intra += loss_intra_l

                loss_intra = loss_intra / 4
                loss += loss_cs * w_cs

                loss_inter = 0
                for i in range(4 - 1):
                    for j in range(i + 1, 4):
                        loss_inter_ij = torch.max(thr_m - criterion_dist(centers_s[i], centers_s[j]),
                                                  torch.FloatTensor([0]).cuda()).squeeze()
                        loss_inter += loss_inter_ij
                loss_inter = loss_inter / 6

                if flag_inter:
                    loss += loss_inter * w_inter
                if flag_intra:
                    loss += loss_intra * w_intra

                optimizer_re.zero_grad()
                loss.backward(retain_graph=True)
                optimizer_re.step()
                scheduler_re.step()

                running_lr = optimizer_pre.state_dict()['param_groups'][0]['lr']

                if idb % 10 == 9:
                    print("[{}, {}] cls loss: {:.4f}, cs loss: {:.4f}, intra loss: {:.4f}, "
                          "inter_loss: {:.4f}, lr: {:.4f}".format(epoch, idb, loss_cls, loss_cs,
                                                                  loss_intra, loss_inter, running_lr))
                    torch.save({'model_state_dict': net.state_dict()},
                               osp.join(save_path, '{}.pkl'.format(epoch)))

                if idb == iter_num - 1:
                    net.eval()
                    _, _, _, preds = net(t_batch, tr_batch)
                    preds_softmax = F.log_softmax(preds, dim=1).exp()
                    preds_softmax_np = preds_softmax.detach().cpu().numpy()
                    preds_ = np.argmax(preds_softmax_np, axis=1)

                    loss_eval = criterion_cls_4(preds, tl_batch)
                    print("The loss on target mini-batch is: {:.4f}".format(loss_eval))
                    results = evaluator._metrics(predictions=preds_,
                                                 labels=tl_batch.detach().cpu().numpy())
                    pprint.pprint(results)

                for l in range(4):
                    centers_s[l] = centers_s[l].detach()

                torch.cuda.empty_cache()

    print('Start obtaining centers of each cluster and distribution')

    centers_source_dir = osp.join(save_path, "centers_source.mat")
    centers_target_dir = osp.join(save_path, "centers_target.mat")
    centers_dir = osp.join(save_path, "centers.mat")
    flag_centers = osp.exists(centers_source_dir) and osp.exists(centers_target_dir) and osp.exists(centers_dir)

    center_source_dir = osp.join(save_path, "center_s.mat")
    center_target_dir = osp.join(save_path, "center_t.mat")
    flag_center = osp.exists(center_source_dir) and osp.exists(center_target_dir)

    if flag_centers:
        centers_s = sio.loadmat(centers_source_dir)
        centers_t = sio.loadmat(centers_target_dir)
        centers = sio.loadmat(centers_dir)
        for l in range(4):
            centers_s[l] = torch.from_numpy(centers_s['c{}'.format(l)].squeeze()).cuda()
            centers_t[l] = torch.from_numpy(centers_t['c{}'.format(l)].squeeze()).cuda()
            centers[l] = torch.from_numpy(centers['c{}'.format(l)].squeeze().astype(np.float32)).cuda()

    else:
        net.eval()
        centers_s, counter_s = init_source_centers(net, source, source_records, data_dict_source,
                                                   batch_size=batch_size, num_workers=cfg.SYSTEM.NUM_WORKERS,
                                                   beat_num=cfg.SETTING.BEAT_NUM, fixed_len=cfg.SETTING.FIXED_LEN)
        centers_t, counter_t = init_target_centers(net, target, target_records, data_dict_target,
                                                   batch_size=batch_size, num_workers=cfg.SYSTEM.NUM_WORKERS,
                                                   beat_num=cfg.SETTING.BEAT_NUM, fixed_len=cfg.SETTING.FIXED_LEN,
                                                   thrs=thrs)
        centers_s_np = {}
        centers_t_np = {}
        centers_np = {}
        centers = {}
        for l in range(4):
            centers_s_np['c{}'.format(l)] = centers_s[l].detach().cpu().numpy()
            centers_t_np['c{}'.format(l)] = centers_t[l].detach().cpu().numpy()
            centers_np['c{}'.format(l)] = ((centers_s_np['c{}'.format(l)] * counter_s[l] +
                                            centers_t_np['c{}'.format(l)] * counter_t[l])
                                           / (counter_s[l] + counter_t[l])).astype(np.float32)
            centers[l] = torch.from_numpy(centers_np['c{}'.format(l)]).cuda()
        sio.savemat(centers_source_dir, centers_s_np)
        sio.savemat(centers_target_dir, centers_t_np)
        sio.savemat(centers_dir, centers_np)

    if flag_center:
        center_s = torch.from_numpy(sio.loadmat(center_source_dir)['c'].squeeze()).cuda()
        center_t = torch.from_numpy(sio.loadmat(center_target_dir)['c'].squeeze()).cuda()
    else:

        center_s = init_entire_center(net, source, source_records, data_dict_source,
                                      batch_size=batch_size, num_workers=cfg.SYSTEM.NUM_WORKERS,
                                      beat_num=cfg.SETTING.BEAT_NUM, fixed_len=cfg.SETTING.FIXED_LEN)
        center_t = init_entire_center(net, target, target_records, data_dict_target,
                                      batch_size=batch_size, num_workers=cfg.SYSTEM.NUM_WORKERS,
                                      beat_num=cfg.SETTING.BEAT_NUM, fixed_len=cfg.SETTING.FIXED_LEN)

        sio.savemat(center_source_dir, {'c': center_s.detach().cpu().numpy()})
        sio.savemat(center_target_dir, {'c': center_t.detach().cpu().numpy()})

    print("Starting STAGE III: adaptation process")

    low_bound = max(2 * pre_train_epochs, check_epoch) if cfg.SETTING.RE_TRAIN else max(pre_train_epochs, check_epoch)
    high_bound = 2 * pre_train_epochs + epochs if cfg.SETTING.RE_TRAIN else pre_train_epochs + epochs
    for epoch in range(low_bound, high_bound):
        dataset.shuffle_target()

        loaded_models = []
        for idx in range(emsemble_num):
            load_dir = osp.join(save_path, '{}.pkl'.format(epoch - idx * emsemble_step - 1))
            loaded_models.append(torch.load(load_dir)['model_state_dict'])

        for idb, data_batch in enumerate(dataloader):
            net.train()

            s_batch, sl_batch, sr_batch, sb_batch,\
            t_batch, tl_batch, tr_batch, tb_batch = data_batch

            s_batch = s_batch.unsqueeze(dim=1)
            t_batch = t_batch.unsqueeze(dim=1)
            s_batch = s_batch.cuda()
            sl_batch = sl_batch.cuda()
            t_batch = t_batch.cuda()
            tl_batch = tl_batch.cuda()
            sr_batch = sr_batch.cuda()
            sb_batch = sb_batch.cuda()
            tr_batch = tr_batch.cuda()
            tb_batch = tb_batch.cuda()

            _, pef_s, feat_s, preds_s = net(s_batch, sr_batch)
            _, _, feat_t, preds_t = net(t_batch, tr_batch)

            loss_cls = criterion_cls_4(preds_s, sl_batch)
            loss_bin = criterion_cls_2(pef_s, sb_batch)
            loss = loss_cls * w_cls + loss_bin * w_bin

            delta_s = center_s - torch.mean(feat_s, dim=0)
            delta_t = center_t - torch.mean(feat_t, dim=0)

            center_s = center_s - lr_c * delta_s
            center_t = center_t - lr_c * delta_t

            loss_mmd = criterion_dist(center_s, center_t)
            loss += loss_mmd * w_mmd

            loss_intra = 0
            loss_inter = 0
            loss_ct = 0
            loss_cs = 0
            loss_cst = 0

            if flag_norm:
                if cfg.SETTING.ALIGN_SET == 'soft':
                    loss += get_L2norm_loss_self_driven(feat_s, w_norm)
                    loss += get_L2norm_loss_self_driven(feat_t, w_norm)
                else:
                    loss += get_L2norm_loss_self_driven_hard(feat_s, cfg.PARAMETERS.RADIUS, w_norm)
                    loss += get_L2norm_loss_self_driven_hard(feat_t, cfg.PARAMETERS.RADIUS, w_norm)

            pesudo_label_nums = {0: 0, 1: 0, 2: 0, 3: 0}
            pesudo_labels, legal_indices = obtain_pesudo_labels(teacher_net, loaded_models, t_batch, tr_batch, thrs)

            tmp_centers_t = {}
            if len(pesudo_labels):
                feat_t_pesudo = torch.index_select(feat_t, dim=0, index=torch.LongTensor(legal_indices).cuda())

                for l in range(4):
                    _index = np.argwhere(pesudo_labels == l)
                    if len(_index):
                        pesudo_label_nums[l] = len(_index)
                        _index = np.squeeze(_index, axis=1)
                        _feat_t = torch.index_select(feat_t_pesudo, dim=0, index=torch.LongTensor(_index).cuda())
                        bs_ = _feat_t.size()[0]

                        local_centers_tl = torch.mean(_feat_t, dim=0)
                        tmp_centers_t[l] = local_centers_tl
                        delta_ct = centers_t[l] - local_centers_tl
                        centers_t[l] = centers_t[l] - lr_ct * delta_ct

                        loss_ct_l = criterion_dist(local_centers_tl, centers_t[l])
                        loss_ct += loss_ct_l

                        if flag_intra:
                            m_feat_t = centers_t[l].repeat((bs_, 1))
                            loss_intra_l = criterion_dist(_feat_t, m_feat_t, dim=1)
                            loss_intra += loss_intra_l

            loss += loss_ct * w_ct

            sl_batch_np = sl_batch.detach().cpu().numpy()
            true_label_nums = {0: 0, 1: 0, 2: 0, 3: 0}

            tmp_centers_s = {}
            for l in range(4):
                _index = np.argwhere(sl_batch_np == l)
                if len(_index):
                    true_label_nums[l] = len(_index)
                    _index = np.squeeze(_index, axis=1)
                    _feat_s = torch.index_select(feat_s, dim=0, index=torch.LongTensor(_index).cuda())
                    bs_ = _feat_s.size()[0]
                    local_centers_sl = torch.mean(_feat_s, dim=0)
                    tmp_centers_s[l] = local_centers_sl
                    delta_cs = centers_s[l] - local_centers_sl
                    centers_s[l] = centers_s[l] - lr_cs * delta_cs

                    loss_cs_l = criterion_dist(local_centers_sl, centers_s[l])
                    loss_cs += loss_cs_l

                    if flag_intra:
                        m_feat_s = centers_s[l].repeat((bs_, 1))
                        loss_intra_l = criterion_dist(_feat_s, m_feat_s, dim=1)
                        loss_intra += loss_intra_l

            loss += loss_cs * w_cs

            for l in range(4):
                loss_cst_l = criterion_dist(centers_s[l], centers_t[l])
                loss_cst += loss_cst_l

            loss += loss_cst * w_cst

            for i in range(4 - 1):
                for j in range(i + 1, 4):
                    loss_inter_ij_s = torch.max(thr_m - criterion_dist(centers_s[i], centers_s[j]),
                                                torch.FloatTensor([0]).cuda()).squeeze()
                    loss_inter_ij_t = torch.max(thr_m - criterion_dist(centers_t[i], centers_t[j]),
                                                torch.FloatTensor([0]).cuda()).squeeze()
                    loss_inter_ij = (loss_inter_ij_s + loss_inter_ij_t)
                    loss_inter += loss_inter_ij

            if flag_inter:
                loss += loss_inter * w_inter
            if flag_intra:
                loss += loss_intra * w_intra

            loss_c = 0

            for l in range(4):
                if (l in tmp_centers_s.keys()) and (l in tmp_centers_t.keys()):
                    tmp_centers_sl = tmp_centers_s[l]
                    tmp_centers_tl = tmp_centers_t[l]
                    m_centers_stl = (pesudo_label_nums[l] * tmp_centers_tl + true_label_nums[l] * tmp_centers_sl) \
                                    / (pesudo_label_nums[l] + true_label_nums[l])
                    delta_l = centers[l] - m_centers_stl
                    centers[l] = centers[l] - lr_c * delta_l

                    loss_cl = criterion_dist(m_centers_stl, centers[l])
                    loss_c += loss_cl

            if flag_c:
                loss += loss_c * w_c

            optimizer_main.zero_grad()
            # NOTE(review): retain_graph=True keeps the whole graph alive after
            # this backward — presumably required because center tensors feed
            # multiple loss terms; confirm it is needed, as it increases memory.
            loss.backward(retain_graph=True)
            optimizer_main.step()
            # Scheduler is stepped per batch (not per epoch) in this loop.
            scheduler_main.step()

            running_lr = optimizer_main.state_dict()['param_groups'][0]['lr']
            torch.cuda.empty_cache()
            # Detach every center tensor so gradients from this iteration's
            # graph do not flow into the next iteration.
            for l in range(4):
                centers[l] = centers[l].detach()
                centers_s[l] = centers_s[l].detach()
                centers_t[l] = centers_t[l].detach()
            center_s = center_s.detach()
            center_t = center_t.detach()

            # Every 10 batches: log all individual loss terms and the current
            # learning rate, dump the pseudo/true label counts, and checkpoint
            # the model. The checkpoint path is keyed by epoch only, so within
            # an epoch the same file is overwritten each time.
            if idb % 10 == 9:
                print("[{}, {}] cls loss: {:.4f}, cs loss: {:.4f}, "
                      "ct loss: {:.4f}, cst loss: {:.4f}, mmd loss: {:.4f}, "
                      "inter loss: {:.4f}, intra loss: {:.4f}, c loss: {:.4f}, "
                      "lr: {:.5f}".format(epoch, idb, loss_cls, loss_cs, loss_ct, loss_cst, loss_mmd,
                                          loss_inter, loss_intra, loss_c, running_lr))

                print("The number of pesudo labels and true labels:")
                pprint.pprint(pesudo_label_nums)
                pprint.pprint(true_label_nums)

                torch.save({'model_state_dict': net.state_dict()},
                           osp.join(save_path, '{}.pkl'.format(epoch)))

            # On the last batch of the epoch: quick evaluation on the current
            # target mini-batch (t_batch / tr_batch / tl_batch come from the
            # enclosing loop, outside this view).
            if idb == iter_num - 1:
                net.eval()
                _, _, _, preds = net(t_batch, tr_batch)
                # log_softmax(...).exp() == softmax; kept as written.
                preds_softmax = F.log_softmax(preds, dim=1).exp()
                preds_softmax_np = preds_softmax.detach().cpu().numpy()
                preds_ = np.argmax(preds_softmax_np, axis=1)
                loss_eval = criterion_cls_4(preds, tl_batch)
                print('The loss on target mini-batch is: {:.4f}'.format(loss_eval))
                results = evaluator._metrics(predictions=preds_,
                                             labels=tl_batch.detach().cpu().numpy())
                pprint.pprint(results)

        # Update the confidence thresholds for pseudo labels as training
        # progresses; the epoch index is offset by the pre-training epochs
        # already consumed (doubled when a re-training phase also ran).
        if cfg.SETTING.INCRE_THRS:
            if cfg.SETTING.RE_TRAIN:
                epoch_ = epoch - 2 * pre_train_epochs
            else:
                epoch_ = epoch - pre_train_epochs
            thrs = update_thrs(thrs, epoch_, epochs)