# NOTE: this function is excerpted from a larger module. The standard-library
# and third-party imports it needs are reproduced here; repo-local names
# (shhs, ops, CNN, test, get_exp_name, write_to_comment_file, get_lines_list,
# write_name_list_to_file, shhs_base_dir, results_dir, checkpoint_dir) are
# assumed to be defined elsewhere in that module.
import glob
import os
import sys

import numpy as np
import tensorflow as tf


def train(args, tvt_counts=None):
    print(args)

    activations = ops.lrelu if args.activations == "lrelu" else tf.nn.relu

    nb_patients = None if args.nb_patients == -1 else args.nb_patients
    if nb_patients is not None:
        assert args.n_patient_queues <= int(0.2 * nb_patients), \
            "Too many patient queues for the number of patients used"
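        # e.g. nb_patients=100 allows at most int(0.2 * 100) = 20 patient queues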

    # if balancing samples over classes, do not also balance the cost
    assert not (args.balance_classes and args.balance_cost), \
        "It is probably a bad idea to both balance training samples over classes and weight the cost."

    ########## prepare data ##############
    print("PREPARING DATA...")
    shhs.prepare_dataset(shhs_base_dir, nb_patients=nb_patients,
                        filter=args.filter, channel=args.channel)
    fil = 'filtered' if args.filter else 'nofilter'
    

    ######## prepare experiment ############
    os.makedirs(results_dir, exist_ok=True)

    exp_name = get_exp_name(
        args.nb_patients, 
        args.batch_size, args.training_batches, args.learning_rates, 
        args.conv_type, args.batch_norm, 
        args.activations, args.featuremap_sizes, args.strides, 
        args.filter_sizes, args.hiddenlayer_size, 
        args.eps_before, args.eps_after, 
        args.channel, args.balance_sm, 
        args.balance_cost, args.balance_classes, 
        args.use_if_missing_stage, args.filter)

    exp_dir = os.path.join(results_dir, exp_name)
    exp_checkpoint_dir = os.path.join(checkpoint_dir, exp_name)
    exp_checkpoint_dir_best = os.path.join(checkpoint_dir, exp_name, 'best')
    os.makedirs(exp_dir, exist_ok=True)
    os.makedirs(exp_checkpoint_dir, exist_ok=True)
    os.makedirs(exp_checkpoint_dir_best, exist_ok=True)

    comment_text = ("Training batches: " + str(args.training_batches) + '\n' +
                    "Batch size: " + str(args.batch_size) + '\n' +
                    "Learning rates: " + str(args.learning_rates) + '\n' +
                    "Featuremap_sizes: " + str(args.featuremap_sizes) + '\n' +
                    "Filter sizes: " + str(args.filter_sizes) + '\n' +
                    "Strides: " + str(args.strides) + '\n' +
                    "Hidden layer size: " + str(args.hiddenlayer_size) + '\n' +
                    "Conv type: " + str(args.conv_type) + '\n' +
                    "Batch norm: " + str(args.batch_norm) + '\n' +
                    "Activations: " + str(args.activations) + '\n' +
                    "Number of epochs before: " + str(args.eps_before) + '\n' +
                    "Number of epochs after: " + str(args.eps_after) + '\n' +
                    "Channel used: " + args.channel + '\n' + 
                    "Balance softmax: " + str(args.balance_sm) + '\n' +
                    "Balance cost: " + str(args.balance_cost) + '\n' +
                    "Balance classes: " + str(args.balance_classes) + '\n' +
                    "Use_if_missing_stage: " + str(args.use_if_missing_stage) + '\n' +
                    "Filter: " + str(args.filter) + '\n' +
                    "Number of patient queues: " + str(args.n_patient_queues) + '\n')

    write_to_comment_file(os.path.join(exp_dir, "comment.txt"), comment_text)

    ########## prepare train, valid, test sets ##############
    #load previously saved names if a saved model exists
    if tf.train.get_checkpoint_state(exp_checkpoint_dir_best) is not None:
        print("loading train/valid/test split from saved model...")
        train_file = os.path.join(exp_dir, 'names_train.txt')
        valid_file = os.path.join(exp_dir, 'names_valid.txt')
        test_file = os.path.join(exp_dir, 'names_test.txt')

        names_train = get_lines_list(train_file)
        names_train = [os.path.join(shhs_base_dir, 'preprocessed', 'shhs1', fil, args.channel, nm) for nm in names_train]
        names_valid = get_lines_list(valid_file)
        names_valid = [os.path.join(shhs_base_dir, 'preprocessed', 'shhs1', fil, args.channel, nm) for nm in names_valid]
        names_test = get_lines_list(test_file)
        names_test = [os.path.join(shhs_base_dir, 'preprocessed', 'shhs1', fil, args.channel, nm) for nm in names_test]

    else:
        preprocessed_names = glob.glob(os.path.join(
            shhs_base_dir, 'preprocessed', 'shhs1', fil, args.channel, '*.p'))
        preprocessed_names = preprocessed_names[:nb_patients]
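        # (a slice with nb_patients=None keeps every recording)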
        # shuffle
        r = np.arange(len(preprocessed_names))
        np.random.shuffle(r)
        preprocessed_names = [preprocessed_names[i] for i in r]

        tvt_proportions = (0.5, 0.2, 0.3)
        n_train = int(tvt_proportions[0]*len(preprocessed_names))
        print('n_train: ', n_train)
        n_valid = int(tvt_proportions[1]*len(preprocessed_names))
        print('n_valid: ', n_valid)
        names_train = preprocessed_names[0:n_train]
        names_valid = preprocessed_names[n_train:n_train+n_valid]
        names_test = preprocessed_names[n_train+n_valid:]
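        # e.g. with 1000 recordings this gives 500 train / 200 valid /
        # 300 test; any rounding remainder goes to the test set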

        write_name_list_to_file(os.path.join(exp_dir, "names_train.txt"), names_train)
        write_name_list_to_file(os.path.join(exp_dir, "names_valid.txt"), names_valid)
        write_name_list_to_file(os.path.join(exp_dir, "names_test.txt"), names_test)


    
    ######### define model ##############
    print("DEFINING MODEL...")

    model = CNN(batch_size=args.batch_size,
                featuremap_sizes=args.featuremap_sizes,
                strides=args.strides,
                filter_sizes=args.filter_sizes,
                hiddenlayer_size=args.hiddenlayer_size,
                balance_sm=args.balance_sm,
                balance_cost=args.balance_cost,
                conv_type=args.conv_type,
                batch_norm=args.batch_norm,
                activations=activations,
                eps_before=args.eps_before,
                eps_after=args.eps_after,
                filter=args.filter)

    model.init()

    #load previously saved model if there is one
    if tf.train.get_checkpoint_state(exp_checkpoint_dir_best) is not None:
        model.load_model(exp_checkpoint_dir_best)
        #model.load_model(exp_dir)
        # also load sequence of previous costs
        costs = np.load(os.path.join(exp_checkpoint_dir_best, 'costs.npy')).tolist()
        costs_valid = np.load(os.path.join(exp_checkpoint_dir_best, 'costs_valid.npy')).tolist()
        accs = np.load(os.path.join(exp_checkpoint_dir_best, 'accs.npy')).tolist()
        accs_valid = np.load(os.path.join(exp_checkpoint_dir_best, 'accs_valid.npy')).tolist()
        # TODO: also save (and load) current_best_cost and current_best_acc
    else:
        costs = []
        costs_valid = []
        accs = []
        accs_valid = []

    ########## TRAIN ##############################
    print('TRAINING...')

    which_val_metric = 'acc'  # monitor 'acc' or 'cost' to select the best model

    # initialized before the try block so that the final test-set evaluation
    # below still runs if training is interrupted early
    current_best_cost = 1e20
    current_best_acc = 0.

    try:
        #estimate how many batches an epoch is
        # print('iterating once over data to count the number of train batches for this split...')
        # names_iterator_train = shhs.patient_names_iterator(names_train)
        # it_1ep_train = shhs.data_iterator_1epoch(
        #     names_iterator_train, 
        #     args.n_patient_queues,
        #     epochs_before=args.eps_before,
        #     epochs_after=args.eps_after,
        #     balance_classes=args.balance_classes,
        #     use_if_missing_stage=args.use_if_missing_stage)
        # ex_per_ep = 0
        # try:
        #     while True:
        #         _, _ = next(it_1ep_train)
        #         ex_per_ep += 1
        # except StopIteration:
        #     pass
        # finally:
        #     del names_iterator_train, it_1ep_train
        # batches_per_ep = ex_per_ep // args.batch_size
        # print("Number of batches per epoch: ", batches_per_ep)

        # simply initialize iterators with a very large number
        # of batches to make sure they do not run out
        # the count will be made in the training loop. 
        it_train = shhs.data_iterator_fixedN(
            preprocessed_names=names_train, 
            n=int(1e8)*args.batch_size,
            n_patient_queues=args.n_patient_queues,
            epochs_before=args.eps_before, 
            epochs_after=args.eps_after, 
            balance_classes=args.balance_classes, 
            use_if_missing_stage=args.use_if_missing_stage)
        b_it_train = shhs.batches_iterator(it_train, args.batch_size)

        it_valid = shhs.data_iterator_fixedN(
            preprocessed_names=names_valid, 
            n=int(1e8)*args.batch_size,
            n_patient_queues=args.n_patient_queues,
            epochs_before=args.eps_before, 
            epochs_after=args.eps_after, 
            balance_classes=args.balance_classes, 
            use_if_missing_stage=args.use_if_missing_stage)
        b_it_valid = shhs.batches_iterator(it_valid, args.batch_size)

        # Training loop.
        # Validation cost is evaluated continuously, 
        # one batch every 5 train batches

        for b in range(np.sum(args.training_batches)): #tqdm(range(np.sum(args.training_batches))):

            if b % 200 == 0:
                print(str(100*b/np.sum(args.training_batches)) + ' percent done...')
                sys.stdout.flush()

            # if b % (np.sum(args.training_batches)//10) == 0: # practical when working with server
            #     write_to_comment_file(os.path.join(exp_dir, 'STATUS.txt'), 
            #                           '%s percent done...' %(10*(b//(np.sum(args.training_batches)//10))))

            # For variable learning rate:
            c01 = b > np.cumsum(args.training_batches)
            idx_lr = np.where(c01==1)[0]
            if len(idx_lr) == 0:
                idx_lr = 0
            else:
                idx_lr = idx_lr[-1] + 1
            lr = args.learning_rates[idx_lr]
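            # worked example: with training_batches=[100, 200] and
            # learning_rates=[1e-3, 1e-4] (illustrative values), cumsum is
            # [100, 300], so batches 0-100 train with 1e-3 and batches
            # 101-299 with 1e-4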


            if b % 10000 == 0 and b > 0:
                print('--- Intermediate validation... ---')
                cost, acc = test(args, model, names_valid, exp_dir, b, 
                                 current_best_cost, current_best_acc, 
                                 which_="valid", which_metric=which_val_metric)
                metric = cost if which_val_metric == 'cost' else -acc
                metric_best = current_best_cost if which_val_metric == 'cost' else -current_best_acc
                if metric <= metric_best:
                    model.save_model(os.path.join(exp_checkpoint_dir_best, 'model.ckpt'))
                    np.save(os.path.join(exp_checkpoint_dir_best, 'costs.npy'), np.array([*costs]))
                    np.save(os.path.join(exp_checkpoint_dir_best, 'accs.npy'), np.array([*accs]))
                    np.save(os.path.join(exp_checkpoint_dir_best, 'costs_valid.npy'), np.array([*costs_valid]))
                    np.save(os.path.join(exp_checkpoint_dir_best, 'accs_valid.npy'), np.array([*accs_valid]))
                    # test metrics. REM: another (faster) option is to reload only optimal model
                    # after training and test then...
                    test(args, model, names_test, exp_dir, b, 
                         1e20, 0., save=True, which_="test",
                         which_metric=which_val_metric)
                # update current bests
                if cost <= current_best_cost:
                    current_best_cost = cost
                if acc >= current_best_acc:
                    current_best_acc = acc

            # occasionally print stuff during training
            if b > 0 and b % 1000 == 0:
                print("average train cost over the last 500 evals: ",
                      np.mean(costs[-500:]))
                print("average valid cost over the last 100 evals: ",
                      np.mean(costs_valid[-100:]))
                print("average train acc over the last 100 evals: ", 
                      np.mean(accs[-500:]))
                print("average valid acc over the last 100 evals: ", 
                      np.mean(accs_valid[-100:]))
                print("learning rate: ", lr)

            examples, labels_target = next(b_it_train)
            numeric_in = {
                'inX': examples,
                'targetY': labels_target,
                'lr': lr,
            }
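            # inX: example batch, targetY: target labels, lr: current
            # learning rate -- the inputs consumed by model.train_model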
            acc_value, cost_value = model.train_model(numeric_in)
            costs += [cost_value]
            accs += [acc_value]
            
            if b % 5 == 0:
                examples, labels_target = next(b_it_valid)
                numeric_in = {
                    'inX': examples,
                    'targetY': labels_target,
                }
                _, acc_value, cost_value = model.estimate_model(numeric_in)
                costs_valid += [cost_value]
                accs_valid += [acc_value]

            if b % 5000 == 0:
                np.save(os.path.join(exp_dir, 'costs.npy'), np.array([*costs]))
                np.save(os.path.join(exp_dir, 'accs.npy'), np.array([*accs]))
                np.save(os.path.join(exp_dir, 'costs_valid.npy'), np.array([*costs_valid]))
                np.save(os.path.join(exp_dir, 'accs_valid.npy'), np.array([*accs_valid]))
                # TODO: also save (and load) current_best_cost and current_best_acc

        # also save everything after training
        model.save_model(os.path.join(exp_checkpoint_dir, 'model.ckpt'))
        np.save(os.path.join(exp_checkpoint_dir, 'costs.npy'), np.array([*costs]))
        np.save(os.path.join(exp_checkpoint_dir, 'accs.npy'), np.array([*accs]))
        np.save(os.path.join(exp_checkpoint_dir, 'costs_valid.npy'), np.array([*costs_valid]))
        np.save(os.path.join(exp_checkpoint_dir, 'accs_valid.npy'), np.array([*accs_valid]))
        
    except KeyboardInterrupt:
        print(' !!!!!!!! TRAINING INTERRUPTED !!!!!!!!')

    # finally, evaluate performance on the test set
    print('Evaluating performance on the test set...')
    _ = test(args, model, names_test, exp_dir, np.sum(args.training_batches),
             current_best_cost, current_best_acc, which_="test",
             which_metric=which_val_metric)

    #write_to_comment_file(os.path.join(exp_dir, 'STATUS.txt'), 'FINISHED.')
    model.close()
    tf.reset_default_graph()
    print('###### DONE. ######')
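
# Minimal invocation sketch (illustrative): train() only reads attributes off
# `args`, so a SimpleNamespace carrying the fields accessed above is enough
# for a smoke test. Every value below is an assumption; the repo's real entry
# point presumably builds `args` with argparse.
#
# from types import SimpleNamespace
# args = SimpleNamespace(
#     nb_patients=100, batch_size=32, training_batches=[50000, 25000],
#     learning_rates=[1e-4, 1e-5], conv_type='conv', batch_norm=True,
#     activations='lrelu', featuremap_sizes=[32, 64, 128], strides=[2, 2, 2],
#     filter_sizes=[7, 5, 3], hiddenlayer_size=256, eps_before=1, eps_after=1,
#     channel='EEG1', balance_sm=False, balance_cost=False,
#     balance_classes=True, use_if_missing_stage=False, filter=True,
#     n_patient_queues=5)
# train(args)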