def train():
    dataloaders = list()
    if 'big_exam' in wandb.config.datasets:
        dataloaders.append(AbnormalBigExamLoader(wandb_config=wandb.config))
    if 'audicor_10s' in wandb.config.datasets:
        dataloaders.append(AbnormalAudicor10sLoader(wandb_config=wandb.config))

    g = BaseDataGenerator(dataloaders=dataloaders,
                          wandb_config=wandb.config,
                          preprocessing_fn=preprocessing)
    train_set, valid_set, test_set = g.get()
    big_data = np.load('big_signal_label_634_5_28.npy')

    hht_set_signal = np.load(
        './abnormal_detection/imf0_2_set_hht.npy')[:, 2000:8000]
    # drop two rows (presumably bad samples); note the second index is
    # relative to the array that is already one row shorter
    hht_set_signal = np.delete(hht_set_signal, 433, axis=0)
    hht_set_signal = np.delete(hht_set_signal, 525, axis=0)

    # the h2 arrays are assumed to already exclude the two dropped rows,
    # since they are stacked with the trimmed arrays below
    hht_set_signal_h2 = np.load('./abnormal_detection/h2_imf.npy')[:,
                                                                   2000:8000]

    amp = np.load('./abnormal_detection/set_hht_am.npy')[:, 2000:8000]
    amp = np.delete(amp, 433, axis=0)
    amp = np.delete(amp, 525, axis=0)

    amp_h2 = np.load('./abnormal_detection/h2_amp.npy')[:, 2000:8000]
    # stack the four channels: IMF, h2 IMF, amplitude, h2 amplitude
    hht_set_signal = np.stack(
        [hht_set_signal, hht_set_signal_h2, amp, amp_h2])

    hht_set_signal = np.transpose(hht_set_signal,
                                  (1, 2, 0))  # -> (n_samples, 6000, 4)
    #hht_set_signal=np.vstack((hht_set_signal,amp))
    print('hht_set_signal', hht_set_signal.shape)

    # columns 1-2 hold two abnormality flags stored as strings ('1.0'/'0.0');
    # a sample is labeled positive if either flag is set
    a = big_data[:, 1:3]
    b = ((a[:, 0] == '1.0') | (a[:, 1] == '1.0')).astype(float).reshape(-1, 1)
    #print(a[:5,0],a[:5,1])
    #print(a[:5],a.shape)
    b = np.delete(b, 433, axis=0)
    b = np.delete(b, 525, axis=0)
    print(b[:5], b.shape)
    duration = 10.0  # seconds per recording
    fs = 1000.0  # sampling rate (Hz)
    samples = int(fs * duration)
    t = np.arange(samples) / fs  # time axis (unused below)
    # 50% / 20% / 30% split boundaries, computed from the original row count
    # (i.e. before the two deletions above)
    train_set_size = int(big_data.shape[0] / 2)
    valid_set_size = int(train_set_size + big_data.shape[0] * 0.2)
    #print('set_size',train_set_size,valid_set_size)
    train_set_signal = np.array(hht_set_signal[:train_set_size],
                                dtype=float)  # 50%
    valid_set_signal = np.array(
        hht_set_signal[train_set_size:valid_set_size], dtype=float)  # 20%
    test_set_signal = np.array(hht_set_signal[valid_set_size:],
                               dtype=float)  # 30%

    #train_set_label=np.array((big_data[:298,1:2]), dtype=np.float)
    #valid_set_label=np.array((big_data[298:418,1:2]), dtype=np.float)
    #test_set_label=np.array((big_data[418:,1:2]), dtype=np.float)

    train_set_label = np.array(b[:train_set_size], dtype=float)
    valid_set_label = np.array(b[train_set_size:valid_set_size], dtype=float)
    test_set_label = np.array(b[valid_set_size:], dtype=float)
    print('test_set_label', np.sum(test_set_label), test_set_label.shape,
          test_set_signal.shape)
    #print('*',train_set_label[:5],train_set_label.shape)
    train_set_label = onehot(train_set_label)
    #print(train_set_label[:5],train_set_label.shape)
    valid_set_label = onehot(valid_set_label)
    test_set_label = onehot(test_set_label)

    #train_set_signal=train_set_signal[:,:,np.newaxis]
    #valid_set_signal=valid_set_signal[:,:,np.newaxis]
    #test_set_signal=test_set_signal[:,:,np.newaxis]
    '''
    train_set = combine(train_set_signal, train_set_label)
    valid_set = combine(valid_set_signal, valid_set_label)
    test_set = combine(test_set_signal, test_set_label)
    '''
    train_set = [np.array(train_set_signal), np.array(train_set_label)]
    valid_set = [np.array(valid_set_signal), np.array(valid_set_label)]
    test_set = [np.array(test_set_signal), np.array(test_set_label)]

    train_set[0], means_and_stds = normalize(train_set[0])
    valid_set[0], _ = normalize(valid_set[0], means_and_stds)
    test_set[0], _ = normalize(test_set[0], means_and_stds)

    #print(test_set_label[0])
    #print(valid_set[0][0][:5])
    # class balance: fraction of each one-hot column in each split
    print('have', train_set[1][:, 0].sum() / train_set[1][:, 0].shape[0],
          valid_set[1][:, 0].sum() / valid_set[1][:, 0].shape[0],
          test_set[1][:, 0].sum() / test_set[1][:, 0].shape[0])
    print('have', train_set[1][:, 1].sum() / train_set[1][:, 0].shape[0],
          valid_set[1][:, 1].sum() / valid_set[1][:, 0].shape[0],
          test_set[1][:, 1].sum() / test_set[1][:, 0].shape[0])
    print('train_set[0]', train_set[0].shape)
    print('valid_set[0]', valid_set[0].shape)
    print('test_set[0]', test_set[0].shape)
    # save means and stds to wandb
    #with open(os.path.join(wandb.run.dir, 'means_and_stds.pl'), 'wb') as f:
    #pickle.dump(g.means_and_stds, f)

    model = backbone(wandb.config,
                     include_top=True,
                     classification=True,
                     classes=2)
    model.compile(RAdam(1e-4) if wandb.config.radam else Adam(amsgrad=True),
                  'binary_crossentropy',
                  metrics=['acc'])
    model.summary()
    wandb.log({'model_params': model.count_params()}, commit=False)

    callbacks = [
        EarlyStopping(monitor='val_loss', patience=50),
        # ReduceLROnPlateau(patience=10, cooldown=5, verbose=1),
        LogBest(),
        WandbCallback(log_gradients=False, training_data=train_set),
    ]

    model.fit(train_set[0],
              train_set[1],
              batch_size=64,
              epochs=800,
              validation_data=(valid_set[0], valid_set[1]),
              callbacks=callbacks,
              shuffle=True)
    model.save(os.path.join(wandb.run.dir, 'final_model.h5'))

    # load best model from wandb and evaluate
    print('Evaluate the BEST model!')

    from tensorflow.keras.models import load_model
    from ekg.layers import LeftCropLike, CenterCropLike
    from ekg.layers.sincnet import SincConv1D

    custom_objects = {
        'SincConv1D': SincConv1D,
        'LeftCropLike': LeftCropLike,
        'CenterCropLike': CenterCropLike
    }

    model = load_model(os.path.join(wandb.run.dir, 'model-best.h5'),
                       custom_objects=custom_objects,
                       compile=False)

    evaluation.evaluation(model, test_set)
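
Note: every example on this page calls two helpers, onehot and normalize, that are defined elsewhere in the repository. The sketch below is a plausible reconstruction, assuming onehot maps an (n, 1) column of 0/1 labels to an (n, 2) one-hot matrix and normalize z-scores each channel, reusing training-set statistics when they are passed in; it is illustrative, not the repository's actual code.

import numpy as np

def onehot(labels):
    # hypothetical: (n, 1) column of 0/1 labels -> (n, 2) one-hot matrix
    out = np.zeros((labels.shape[0], 2))
    out[np.arange(labels.shape[0]), labels[:, 0].astype(int)] = 1
    return out

def normalize(X, means_and_stds=None):
    # hypothetical: per-channel z-scoring with reusable statistics
    if means_and_stds is None:
        means = X.mean(axis=(0, 1))
        stds = X.std(axis=(0, 1)) + 1e-8
    else:
        means, stds = means_and_stds
    return (X - means) / stds, (means, stds)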
Example #2
def train():
    dataloaders = list()
    if 'big_exam' in wandb.config.datasets:
        dataloaders.append(AbnormalBigExamLoader(wandb_config=wandb.config))
    if 'audicor_10s' in wandb.config.datasets:
        dataloaders.append(AbnormalAudicor10sLoader(wandb_config=wandb.config))
    '''
    g = BaseDataGenerator(dataloaders=dataloaders,
                            wandb_config=wandb.config,
                            preprocessing_fn=preprocessing)
    train_set, valid_set, test_set = g.get()
    '''

    audicor_hs_s3s4 = np.load(
        './abnormal_detection/audicor_10s/audicor_hs_s3s4.npy')

    #hht_set_signal=np.transpose(hht_set_signal,(1,2,0))
    #hht_set_signal=np.vstack((hht_set_signal,amp))
    print('audicor_hs_s3s4', audicor_hs_s3s4.shape)

    # columns 3-4 hold two abnormality flags stored as strings ('1.0'/'0.0');
    # the signal itself starts at column 5
    a = audicor_hs_s3s4[:, 3:5]
    audicor_hs_s3s4 = audicor_hs_s3s4[:, 5:]
    b = ((a[:, 0] == '1.0') | (a[:, 1] == '1.0')).astype(float).reshape(-1, 1)

    print(b[:5], b.shape, np.sum(b))
    duration = 10.0  # seconds per recording
    fs = 1000.0  # sampling rate (Hz)
    samples = int(fs * duration)
    t = np.arange(samples) / fs  # time axis (unused below)
    train_set_size = int(audicor_hs_s3s4.shape[0] / 2)
    valid_set_size = int(train_set_size + audicor_hs_s3s4.shape[0] * 0.2)
    print('set_size', train_set_size, valid_set_size)
    train_set_signal = np.array(audicor_hs_s3s4[:train_set_size],
                                dtype=float)  # 50%
    valid_set_signal = np.array(
        audicor_hs_s3s4[train_set_size:valid_set_size], dtype=float)  # 20%
    test_set_signal = np.array(audicor_hs_s3s4[valid_set_size:],
                               dtype=float)  # 30%

    train_set_signal = train_set_signal[:, :, np.newaxis]
    valid_set_signal = valid_set_signal[:, :, np.newaxis]
    test_set_signal = test_set_signal[:, :, np.newaxis]

    #train_set_label=np.array((big_data[:298,1:2]), dtype=np.float)
    #valid_set_label=np.array((big_data[298:418,1:2]), dtype=np.float)
    #test_set_label=np.array((big_data[418:,1:2]), dtype=np.float)

    train_set_label = np.array(b[:train_set_size], dtype=float)
    valid_set_label = np.array(b[train_set_size:valid_set_size], dtype=float)
    test_set_label = np.array(b[valid_set_size:], dtype=float)
    print('test_set_label', np.sum(test_set_label), test_set_label.shape,
          test_set_signal.shape)
    train_set_label = onehot(train_set_label)
    valid_set_label = onehot(valid_set_label)
    test_set_label = onehot(test_set_label)

    train_set = [np.array(train_set_signal), np.array(train_set_label)]
    valid_set = [np.array(valid_set_signal), np.array(valid_set_label)]
    test_set = [np.array(test_set_signal), np.array(test_set_label)]

    train_set[0], means_and_stds = normalize(train_set[0])
    valid_set[0], _ = normalize(valid_set[0], means_and_stds)
    test_set[0], _ = normalize(test_set[0], means_and_stds)

    #train_set[0]    = multi_input_format(train_set[0], wandb.config.include_info)
    #valid_set[0]    = multi_input_format(valid_set[0], wandb.config.include_info)
    #test_set[0]     = multi_input_format(test_set[0], wandb.config.include_info)

    print('have', train_set[1][:, 0].sum() / train_set[1][:, 0].shape[0],
          valid_set[1][:, 0].sum() / valid_set[1][:, 0].shape[0],
          test_set[1][:, 0].sum() / test_set[1][:, 0].shape[0])
    print('have', train_set[1][:, 1].sum() / train_set[1][:, 0].shape[0],
          valid_set[1][:, 1].sum() / valid_set[1][:, 0].shape[0],
          test_set[1][:, 1].sum() / test_set[1][:, 0].shape[0])
    print(train_set[1][:, 0].shape[0], valid_set[1][:, 0].shape[0],
          test_set[1][:, 0].shape[0])
    '''for X in [train_set[0], valid_set[0], test_set[0]]:
        if wandb.config.n_ekg_channels != 0:
            X['ekg_input'] = X['ekg_hs_input'][..., :wandb.config.n_ekg_channels]
        if wandb.config.n_hs_channels != 0:
            print('X',X['ekg_hs_input'].shape)
            hs = X['ekg_hs_input'][..., -wandb.config.n_hs_channels:] # (?, n_samples, n_channels)
            X['hs_input'] = mp_generate_wavelet(hs,wandb.config.sampling_rate, 
                                                            wandb.config.wavelet_scale_length,
                                                            'Generate Wavelets')
        X.pop('ekg_hs_input')
    '''
    # save means and stds to wandb
    #with open(os.path.join(wandb.run.dir, 'means_and_stds.pl'), 'wb') as f:
    #pickle.dump(g.means_and_stds, f)

    model = backbone(wandb.config,
                     include_top=True,
                     classification=True,
                     classes=2)
    model.compile(RAdam(1e-4) if wandb.config.radam else Adam(amsgrad=True),
                  'binary_crossentropy',
                  metrics=['acc'])
    model.summary()
    wandb.log({'model_params': model.count_params()}, commit=False)

    callbacks = [
        EarlyStopping(monitor='val_loss', patience=50),
        # ReduceLROnPlateau(patience=10, cooldown=5, verbose=1),
        LogBest(),
        WandbCallback(log_gradients=False, training_data=train_set),
    ]

    model.fit(train_set[0],
              train_set[1],
              batch_size=64,
              epochs=800,
              validation_data=(valid_set[0], valid_set[1]),
              callbacks=callbacks,
              shuffle=True)
    model.save(os.path.join(wandb.run.dir, 'final_model.h5'))

    # load best model from wandb and evaluate
    print('Evaluate the BEST model!')

    from tensorflow.keras.models import load_model
    from ekg.layers import LeftCropLike, CenterCropLike
    from ekg.layers.sincnet import SincConv1D

    custom_objects = {
        'SincConv1D': SincConv1D,
        'LeftCropLike': LeftCropLike,
        'CenterCropLike': CenterCropLike
    }

    model = load_model(os.path.join(wandb.run.dir, 'model-best.h5'),
                       custom_objects=custom_objects,
                       compile=False)

    evaluation.evaluation(model, test_set)
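
Note: Examples #1 and #2 both hand-roll the same in-order 50% / 20% / 30% split. A small helper capturing that pattern (hypothetical, not part of the original code):

def split_50_20_30(X, y):
    # in-order split, matching the boundary arithmetic used above
    n = X.shape[0]
    i = int(n / 2)
    j = int(i + n * 0.2)
    return (X[:i], y[:i]), (X[i:j], y[i:j]), (X[j:], y[j:])

# usage: (Xtr, ytr), (Xva, yva), (Xte, yte) = split_50_20_30(signals, labels)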
Example #3
def train():
    dataloaders = list()
    if 'big_exam' in wandb.config.datasets:
        dataloaders.append(AbnormalBigExamLoader(wandb_config=wandb.config))
    if 'audicor_10s' in wandb.config.datasets:
        dataloaders.append(AbnormalAudicor10sLoader(wandb_config=wandb.config))
    '''
    g = BaseDataGenerator(dataloaders=dataloaders,
                            wandb_config=wandb.config,
                            preprocessing_fn=preprocessing)
    train_set, valid_set, test_set = g.get()
    print(len(train_set),train_set[0].shape,train_set[1].shape)
    print_statistics(train_set, valid_set, test_set)
    '''
    big_data = np.load('big_data_signal_label.npy')
    #print('big_data',big_data.shape,big_data[:5])
    a = big_data[:, 1:3]
    # columns 1-2 hold two abnormality flags stored as strings ('1'/'0');
    # a sample is labeled positive if either flag is set (595 rows in total)
    b = ((a[:, 0] == '1') | (a[:, 1] == '1')).astype(float).reshape(-1, 1)
    #print(a[:5,0],a[:5,1])
    #print(a[:5],a.shape)
    #print(b[:5],b.shape)
    train_set_signal = np.array(big_data[:298, 3:], dtype=float)
    valid_set_signal = np.array(big_data[298:418, 3:], dtype=float)
    test_set_signal = np.array(big_data[418:, 3:], dtype=float)
    #train_set_label=np.array((big_data[:298,1:2]), dtype=np.float)
    #valid_set_label=np.array((big_data[298:418,1:2]), dtype=np.float)
    #test_set_label=np.array((big_data[418:,1:2]), dtype=np.float)

    train_set_label = np.array(b[:298], dtype=float)
    valid_set_label = np.array(b[298:418], dtype=float)
    test_set_label = np.array(b[418:], dtype=float)

    #print('*',train_set_label[:5],train_set_label.shape)
    train_set_label = onehot(train_set_label)
    #print(train_set_label[:5],train_set_label.shape)
    valid_set_label = onehot(valid_set_label)
    test_set_label = onehot(test_set_label)

    train_set_signal = train_set_signal[:, :, np.newaxis]
    valid_set_signal = valid_set_signal[:, :, np.newaxis]
    test_set_signal = test_set_signal[:, :, np.newaxis]
    '''
    train_set = combine(train_set_signal, train_set_label)
    valid_set = combine(valid_set_signal, valid_set_label)
    test_set = combine(test_set_signal, test_set_label)
    '''
    train_set = [np.array(train_set_signal), np.array(train_set_label)]
    valid_set = [np.array(valid_set_signal), np.array(valid_set_label)]
    test_set = [np.array(test_set_signal), np.array(test_set_label)]

    train_set[0], means_and_stds = normalize(train_set[0])
    valid_set[0], _ = normalize(valid_set[0], means_and_stds)
    test_set[0], _ = normalize(test_set[0], means_and_stds)
    # sanity-check plot of the first normalized training signal
    plt.figure()
    plt.plot(range(10000), train_set[0][0])
    plt.show()

    #print(train_set[0][0][:5])
    #print(valid_set[0][0][:5])
    print('have', train_set[1][:, 0].sum() / train_set[1][:, 0].shape[0],
          valid_set[1][:, 0].sum() / valid_set[1][:, 0].shape[0],
          test_set[1][:, 0].sum() / test_set[1][:, 0].shape[0])
    print('have', train_set[1][:, 1].sum() / train_set[1][:, 0].shape[0],
          valid_set[1][:, 1].sum() / valid_set[1][:, 0].shape[0],
          test_set[1][:, 1].sum() / test_set[1][:, 0].shape[0])
    # save means and stds to wandb
    #with open(os.path.join(wandb.run.dir, 'means_and_stds.pl'), 'wb') as f:
    #pickle.dump(g.means_and_stds, f)

    model = backbone(wandb.config,
                     include_top=True,
                     classification=True,
                     classes=1)

    model.compile(RAdam(1e-4) if wandb.config.radam else Adam(amsgrad=True),
                  'binary_crossentropy',
                  metrics=['acc'])
    # recent scikit-learn versions make these arguments keyword-only
    weights = class_weight.compute_class_weight(
        class_weight='balanced',
        classes=np.unique(train_set[1][:, 0]),
        y=train_set[1][:, 0])
    class_weight_dict = dict(enumerate(weights))
    print('weights', weights)
    #weights2 = {0:0.64 , 1:2.19}
    print(class_weight_dict)
    model.summary()
    wandb.log({'model_params': model.count_params()}, commit=False)

    callbacks = [
        EarlyStopping(monitor='val_loss', patience=50),
        # ReduceLROnPlateau(patience=10, cooldown=5, verbose=1),
        LogBest(),
        WandbCallback(log_gradients=False, training_data=train_set),
    ]

    model.fit(train_set[0],
              train_set[1][:, 0],
              batch_size=64,
              class_weight=class_weight_dict,
              epochs=300,
              validation_data=(valid_set[0], valid_set[1][:, 0]),
              callbacks=callbacks,
              shuffle=True)
    model.save(os.path.join(wandb.run.dir, 'final_model.h5'))

    # load best model from wandb and evaluate
    print('Evaluate the BEST model!')

    from tensorflow.keras.models import load_model
    from ekg.layers import LeftCropLike, CenterCropLike
    from ekg.layers.sincnet import SincConv1D

    custom_objects = {
        'SincConv1D': SincConv1D,
        'LeftCropLike': LeftCropLike,
        'CenterCropLike': CenterCropLike
    }

    model = load_model('model-best.h5',
                       custom_objects=custom_objects,
                       compile=False)
    model.summary()

    evaluation.evaluation(model, train_set)
    evaluation.evaluation(model, valid_set)
    evaluation.evaluation(model, test_set)
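
Note: LogBest is a custom callback imported from this repository and not shown on this page. A rough reconstruction of the pattern, assuming it re-logs the metrics of the best-val_loss epoch to wandb (the records argument matches its use in Example #5; the body is an assumption):

import wandb
from tensorflow.keras.callbacks import Callback

class LogBest(Callback):
    # hypothetical sketch: track the best monitored value and log its metrics
    def __init__(self, monitor='val_loss', records=('val_loss', 'loss')):
        super().__init__()
        self.monitor = monitor
        self.records = records
        self.best = float('inf')

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        if logs.get(self.monitor, float('inf')) < self.best:
            self.best = logs[self.monitor]
            wandb.log({'best_' + k: logs[k]
                       for k in self.records if k in logs}, commit=False)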
Example #4
def train():
    dataloaders = list()
    if 'big_exam' in wandb.config.datasets:
        dataloaders.append(AbnormalBigExamLoader(wandb_config=wandb.config))
    if 'audicor_10s' in wandb.config.datasets:
        dataloaders.append(AbnormalAudicor10sLoader(wandb_config=wandb.config))

    g = AbnormalDataGenerator(dataloaders=dataloaders,
                              wandb_config=wandb.config,
                              preprocessing_fn=preprocessing)
    train_set, valid_set, test_set = g.get()
    print_statistics(train_set, valid_set, test_set)

    # save means and stds to wandb
    with open(os.path.join(wandb.run.dir, 'means_and_stds.pl'), 'wb') as f:
        pickle.dump(g.means_and_stds, f)

    model = backbone(wandb.config,
                     include_top=True,
                     classification=True,
                     classes=2)
    model.compile(RAdam(1e-4) if wandb.config.radam else Adam(amsgrad=True),
                  'binary_crossentropy',
                  metrics=['acc'])
    model.summary()
    wandb.log({'model_params': model.count_params()}, commit=False)

    callbacks = [
        # ReduceLROnPlateau(patience=10, cooldown=5, verbose=1),
        LogBest(),
        WandbCallback(),
        GarbageCollector(),
        EarlyStopping(monitor='val_loss', patience=50),
    ]

    model.fit(train_set[0],
              train_set[1],
              batch_size=wandb.config.batch_size,
              epochs=200,
              validation_data=(valid_set[0], valid_set[1]),
              callbacks=callbacks,
              shuffle=True)
    model.save(os.path.join(wandb.run.dir, 'final_model.h5'))

    # load best model from wandb and evaluate
    print('Evaluate the BEST model!')

    from tensorflow.keras.models import load_model
    from ekg.layers import LeftCropLike, CenterCropLike
    from ekg.layers.sincnet import SincConv1D

    custom_objects = {
        'SincConv1D': SincConv1D,
        'LeftCropLike': LeftCropLike,
        'CenterCropLike': CenterCropLike
    }

    model = load_model(os.path.join(wandb.run.dir, 'model-best.h5'),
                       custom_objects=custom_objects,
                       compile=False)

    evaluation.evaluation([model], test_set)
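
Note: Example #4 pickles g.means_and_stds into the run directory so the training-set statistics can be reused at inference time. A sketch of the reload side, assuming a simple (means, stds) pair (Example #5 indexes g.means_and_stds[1][1], which suggests a nested layout, so adapt the unpacking to the actual object):

import os
import pickle

def load_and_normalize(run_dir, X):
    # reload the statistics saved during training and apply them to new data
    with open(os.path.join(run_dir, 'means_and_stds.pl'), 'rb') as f:
        means, stds = pickle.load(f)
    return (X - means) / stds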
Example #5
def train():
    dataloaders = list()
    if 'big_exam' in wandb.config.datasets:
        dataloaders.append(HazardBigExamLoader(wandb_config=wandb.config))
    if 'audicor_10s' in wandb.config.datasets:
        dataloaders.append(HazardAudicor10sLoader(wandb_config=wandb.config))

    g = HazardDataGenerator(dataloaders=dataloaders,
                            wandb_config=wandb.config,
                            preprocessing_fn=preprocessing)

    train_set, valid_set, test_set = g.get()
    print_statistics(train_set, valid_set, test_set, wandb.config.events)

    # save means and stds to wandb
    with open(os.path.join(wandb.run.dir, 'means_and_stds.pl'), 'wb') as f:
        pickle.dump(g.means_and_stds, f)

    if wandb.config.include_info and wandb.config.info_apply_noise:
        wandb.config.info_norm_noise_std = np.array(
            wandb.config.info_noise_stds) / np.array(g.means_and_stds[1][1])

    prediction_model = backbone(wandb.config,
                                include_top=True,
                                classification=False,
                                classes=len(wandb.config.events))
    trainable_model = get_trainable_model(prediction_model,
                                          get_loss_layer(wandb.config.loss))

    trainable_model.compile(
        RAdam(1e-4) if wandb.config.radam else Adam(amsgrad=True), loss=None)
    trainable_model.summary()
    wandb.log({'model_params': trainable_model.count_params()}, commit=False)

    c_index_reverse = (wandb.config.loss != 'AFT')
    scatter_exp = (wandb.config.loss == 'AFT')
    scatter_xlabel = ('predicted survival time (days)'
                      if wandb.config.loss == 'AFT' else 'predicted risk')

    callbacks = [
        # ReduceLROnPlateau(patience=10, cooldown=5, verbose=1),
        LossVariableChecker(wandb.config.events),
        ConcordanceIndex(train_set,
                         valid_set,
                         wandb.config.events,
                         prediction_model,
                         reverse=c_index_reverse),
        LogBest(records=['val_loss', 'loss'] + [
            '{}_cindex'.format(event_name)
            for event_name in wandb.config.events
        ] + [
            'val_{}_cindex'.format(event_name)
            for event_name in wandb.config.events
        ] + [
            '{}_sigma'.format(event_name) for event_name in wandb.config.events
        ]),
        WandbCallback(),
        EarlyStopping(
            monitor='val_loss',
            patience=50),  # must be placed last otherwise it won't work
    ]

    X_train, y_train = to_trainable_X(train_set), None
    X_valid, y_valid = to_trainable_X(valid_set), None
    dataset_shuffle(X_train)
    dataset_shuffle(X_valid)

    trainable_model.fit(X_train,
                        y_train,
                        batch_size=wandb.config.batch_size,
                        epochs=1000,
                        validation_data=(X_valid, y_valid),
                        callbacks=callbacks,
                        shuffle=True)
    trainable_model.save(os.path.join(wandb.run.dir, 'final_model.h5'))

    # load best model from wandb and evaluate
    print('Evaluate the BEST model!')

    custom_objects = {
        'SincConv1D': SincConv1D,
        'LeftCropLike': LeftCropLike,
        'CenterCropLike': CenterCropLike,
        'AFTLoss': AFTLoss,
        'CoxLoss': CoxLoss,
    }

    model = load_model(os.path.join(wandb.run.dir, 'model-best.h5'),
                       custom_objects=custom_objects,
                       compile=False)
    prediction_model = to_prediction_model(model, wandb.config.include_info)

    print('Training set:')
    evaluation(prediction_model,
               train_set,
               wandb.config.events,
               reverse=c_index_reverse)

    print('Testing set:')
    evaluation(prediction_model,
               test_set,
               wandb.config.events,
               reverse=c_index_reverse)

    evaluation_plot(prediction_model,
                    train_set,
                    train_set,
                    'training - ',
                    reverse=c_index_reverse,
                    scatter_exp=scatter_exp,
                    scatter_xlabel=scatter_xlabel)
    evaluation_plot(prediction_model,
                    train_set,
                    valid_set,
                    'validation - ',
                    reverse=c_index_reverse,
                    scatter_exp=scatter_exp,
                    scatter_xlabel=scatter_xlabel)
    evaluation_plot(prediction_model,
                    train_set,
                    test_set,
                    'testing - ',
                    reverse=c_index_reverse,
                    scatter_exp=scatter_exp,
                    scatter_xlabel=scatter_xlabel)
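
Note: the ConcordanceIndex callback in Example #5 reports Harrell's c-index per event. The same quantity can be computed directly with lifelines (a usage sketch on toy arrays; the data below is made up):

import numpy as np
from lifelines.utils import concordance_index

durations = np.array([5., 12., 9., 20., 3.])  # observed survival times
risks = np.array([0.9, 0.2, 0.5, 0.1, 0.8])   # model risk scores
events = np.array([1, 0, 1, 1, 1])            # 1 = event observed

# concordance_index expects scores where higher means longer survival,
# so negate the risks before ranking
print('c-index:', concordance_index(durations, -risks, events))

Example #6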
def train():
    dataloaders = list()
    if 'big_exam' in wandb.config.datasets:
        dataloaders.append(AbnormalBigExamLoader(wandb_config=wandb.config))
    if 'audicor_10s' in wandb.config.datasets:
        dataloaders.append(AbnormalAudicor10sLoader(wandb_config=wandb.config))

    g = BaseDataGenerator(dataloaders=dataloaders,
                          wandb_config=wandb.config,
                          preprocessing_fn=preprocessing)
    train_set, valid_set, test_set = g.get()
    big_data = np.load('big_data_signal_label.npy')

    #print(a[:5,0],a[:5,1])
    #print(a[:5],a.shape)
    #print(b[:5],b.shape)
    duration = 10.0  # seconds per recording
    fs = 1000.0  # sampling rate (Hz)
    samples = int(fs * duration)
    t = np.arange(samples) / fs  # time axis (unused below)
    train_set_signal = np.array(big_data[:298, 3:], dtype=float)
    valid_set_signal = np.array(big_data[298:418, 3:], dtype=float)
    test_set_signal = np.array(big_data[418:, 3:], dtype=float)

    train_set_label = np.array(big_data[:298, 1:2], dtype=float)
    valid_set_label = np.array(big_data[298:418, 1:2], dtype=float)
    test_set_label = np.array(big_data[418:, 1:2], dtype=float)

    print('label', np.sum(train_set_label), np.sum(valid_set_label),
          np.sum(test_set_label))

    train_set_label = onehot(train_set_label)
    valid_set_label = onehot(valid_set_label)
    test_set_label = onehot(test_set_label)

    train_set_signal = train_set_signal[:, :, np.newaxis]
    valid_set_signal = valid_set_signal[:, :, np.newaxis]
    test_set_signal = test_set_signal[:, :, np.newaxis]
    '''
    train_set = combine(train_set_signal, train_set_label)
    valid_set = combine(valid_set_signal, valid_set_label)
    test_set = combine(test_set_signal, test_set_label)
    '''
    train_set = [np.array(train_set_signal), np.array(train_set_label)]
    valid_set = [np.array(valid_set_signal), np.array(valid_set_label)]
    test_set = [np.array(test_set_signal), np.array(test_set_label)]

    train_set[0], means_and_stds = normalize(train_set[0])
    valid_set[0], _ = normalize(valid_set[0], means_and_stds)
    test_set[0], _ = normalize(test_set[0], means_and_stds)

    #print(test_set_label[0])
    #print(valid_set[0][0][:5])
    print('have', train_set[1][:, 0].sum() / train_set[1][:, 0].shape[0],
          valid_set[1][:, 0].sum() / valid_set[1][:, 0].shape[0],
          test_set[1][:, 0].sum() / test_set[1][:, 0].shape[0])
    print('have', train_set[1][:, 1].sum() / train_set[1][:, 0].shape[0],
          valid_set[1][:, 1].sum() / valid_set[1][:, 0].shape[0],
          test_set[1][:, 1].sum() / test_set[1][:, 0].shape[0])

    # save means and stds to wandb
    #with open(os.path.join(wandb.run.dir, 'means_and_stds.pl'), 'wb') as f:
    #pickle.dump(g.means_and_stds, f)

    model = backbone(wandb.config,
                     include_top=True,
                     classification=True,
                     classes=2)
    model.compile(RAdam(1e-4) if wandb.config.radam else Adam(amsgrad=True),
                  'binary_crossentropy',
                  metrics=['acc'])
    model.summary()
    wandb.log({'model_params': model.count_params()}, commit=False)

    callbacks = [
        EarlyStopping(monitor='val_loss', patience=50),
        # ReduceLROnPlateau(patience=10, cooldown=5, verbose=1),
        LogBest(),
        WandbCallback(log_gradients=False, training_data=train_set),
    ]

    model.fit(train_set[0],
              train_set[1],
              batch_size=64,
              epochs=700,
              validation_data=(valid_set[0], valid_set[1]),
              callbacks=callbacks,
              shuffle=True)
    model.save(os.path.join(wandb.run.dir, 'final_model.h5'))

    # load best model from wandb and evaluate
    print('Evaluate the BEST model!')

    from tensorflow.keras.models import load_model
    from ekg.layers import LeftCropLike, CenterCropLike
    from ekg.layers.sincnet import SincConv1D

    custom_objects = {
        'SincConv1D': SincConv1D,
        'LeftCropLike': LeftCropLike,
        'CenterCropLike': CenterCropLike
    }

    model = load_model(os.path.join(wandb.run.dir, 'model-best.h5'),
                       custom_objects=custom_objects,
                       compile=False)

    evaluation.evaluation(model, test_set)
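
Note: evaluation.evaluation comes from the repository's evaluation module and is not shown on this page. A minimal stand-in that reports accuracy on the one-hot test sets assembled above (purely illustrative; the real module presumably reports more than accuracy):

import numpy as np

def evaluate(model, dataset):
    # dataset = [signals, one-hot labels], as assembled in the examples above
    X, y = dataset
    y_pred = np.argmax(model.predict(X), axis=-1)
    y_true = np.argmax(y, axis=-1)
    print('accuracy:', (y_pred == y_true).mean())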