def train(args, tvt_counts=None):
    print(args)
    activations = ops.lrelu if args.activations == "lrelu" else tf.nn.relu
    nb_patients = None if args.nb_patients == -1 else args.nb_patients
    if nb_patients is not None:
        assert args.n_patient_queues <= int(0.2 * nb_patients), \
            "Too many patient queues for the number of patients used"
    # If balancing samples over classes, do not also balance the cost.
    assert not (args.balance_classes and args.balance_cost), \
        "It is probably a bad idea to both balance training samples over classes and weight the cost."

    ########## prepare data ##############
    print("PREPARING DATA...")
    shhs.prepare_dataset(shhs_base_dir, nb_patients=nb_patients,
                         filter=args.filter, channel=args.channel)
    fil = 'filtered' if args.filter else 'nofilter'

    ######## prepare experiment ############
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)
    exp_name = get_exp_name(
        args.nb_patients, args.batch_size, args.training_batches,
        args.learning_rates, args.conv_type, args.batch_norm, args.activations,
        args.featuremap_sizes, args.strides, args.filter_sizes,
        args.hiddenlayer_size, args.eps_before, args.eps_after, args.channel,
        args.balance_sm, args.balance_cost, args.balance_classes,
        args.use_if_missing_stage, args.filter)
    exp_dir = os.path.join(results_dir, exp_name)
    exp_checkpoint_dir = os.path.join(checkpoint_dir, exp_name)
    exp_checkpoint_dir_best = os.path.join(checkpoint_dir, exp_name, 'best')
    if not os.path.exists(exp_dir):
        os.makedirs(exp_dir)
    if not os.path.exists(exp_checkpoint_dir):
        os.makedirs(exp_checkpoint_dir)
    if not os.path.exists(exp_checkpoint_dir_best):
        os.makedirs(exp_checkpoint_dir_best)

    comment_text = (
        "Training batches: " + str(args.training_batches) + '\n'
        + "Batch size: " + str(args.batch_size) + '\n'
        + "Learning rates: " + str(args.learning_rates) + '\n'
        + "Featuremap sizes: " + str(args.featuremap_sizes) + '\n'
        + "Filter sizes: " + str(args.filter_sizes) + '\n'
        + "Strides: " + str(args.strides) + '\n'
        + "Hidden layer size: " + str(args.hiddenlayer_size) + '\n'
        + "Conv type: " + str(args.conv_type) + '\n'
        + "Batch norm: " + str(args.batch_norm) + '\n'
        + "Activations: " + str(args.activations) + '\n'
        + "Number of epochs before: " + str(args.eps_before) + '\n'
        + "Number of epochs after: " + str(args.eps_after) + '\n'
        + "Channel used: " + args.channel + '\n'
        + "Balance softmax: " + str(args.balance_sm) + '\n'
        + "Balance cost: " + str(args.balance_cost) + '\n'
        + "Balance classes: " + str(args.balance_classes) + '\n'
        + "Use_if_missing_stage: " + str(args.use_if_missing_stage) + '\n'
        + "Filter: " + str(args.filter) + '\n'
        + "Number of patient queues: " + str(args.n_patient_queues) + '\n')
    write_to_comment_file(os.path.join(exp_dir, "comment.txt"), comment_text)

    ########## prepare train, valid, test sets ##############
    # Load previously saved names if a saved model exists.
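    # The patient-level split is persisted to names_train.txt / names_valid.txt /
    # names_test.txt under exp_dir, so a resumed run reuses exactly the same
    # patients in each subset instead of drawing a fresh random split.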
    if tf.train.get_checkpoint_state(exp_checkpoint_dir_best) is not None:
        print("loading train/valid/test split from saved model...")
        train_file = os.path.join(exp_dir, 'names_train.txt')
        valid_file = os.path.join(exp_dir, 'names_valid.txt')
        test_file = os.path.join(exp_dir, 'names_test.txt')
        names_train = get_lines_list(train_file)
        names_train = [os.path.join(shhs_base_dir, 'preprocessed', 'shhs1',
                                    fil, args.channel, nm)
                       for nm in names_train]
        names_valid = get_lines_list(valid_file)
        names_valid = [os.path.join(shhs_base_dir, 'preprocessed', 'shhs1',
                                    fil, args.channel, nm)
                       for nm in names_valid]
        names_test = get_lines_list(test_file)
        names_test = [os.path.join(shhs_base_dir, 'preprocessed', 'shhs1',
                                   fil, args.channel, nm)
                      for nm in names_test]
    else:
        preprocessed_names = glob.glob(os.path.join(
            shhs_base_dir, 'preprocessed', 'shhs1', fil, args.channel, '*.p'))
        preprocessed_names = preprocessed_names[:nb_patients]
        # shuffle
        r = np.arange(len(preprocessed_names))
        np.random.shuffle(r)
        preprocessed_names = [preprocessed_names[i] for i in r]
        tvt_proportions = (0.5, 0.2, 0.3)
        n_train = int(tvt_proportions[0] * len(preprocessed_names))
        print('n_train: ', n_train)
        n_valid = int(tvt_proportions[1] * len(preprocessed_names))
        print('n_valid: ', n_valid)
        names_train = preprocessed_names[0:n_train]
        names_valid = preprocessed_names[n_train:n_train + n_valid]
        names_test = preprocessed_names[n_train + n_valid:]
        write_name_list_to_file(os.path.join(exp_dir, "names_train.txt"), names_train)
        write_name_list_to_file(os.path.join(exp_dir, "names_valid.txt"), names_valid)
        write_name_list_to_file(os.path.join(exp_dir, "names_test.txt"), names_test)

    ######### define model ##############
    print("DEFINING MODEL...")
    model = CNN(batch_size=args.batch_size,
                featuremap_sizes=args.featuremap_sizes,
                strides=args.strides,
                filter_sizes=args.filter_sizes,
                hiddenlayer_size=args.hiddenlayer_size,
                balance_sm=args.balance_sm,
                balance_cost=args.balance_cost,
                conv_type=args.conv_type,
                batch_norm=args.batch_norm,
                activations=activations,
                eps_before=args.eps_before,
                eps_after=args.eps_after,
                filter=args.filter)
    model.init()

    # Load previously saved model if there is one.
    if tf.train.get_checkpoint_state(exp_checkpoint_dir_best) is not None:
        model.load_model(exp_checkpoint_dir_best)
        # model.load_model(exp_dir)
        # also load sequence of previous costs
        costs = np.load(os.path.join(exp_checkpoint_dir_best, 'costs.npy')).tolist()
        costs_valid = np.load(os.path.join(exp_checkpoint_dir_best, 'costs_valid.npy')).tolist()
        accs = np.load(os.path.join(exp_checkpoint_dir_best, 'accs.npy')).tolist()
        accs_valid = np.load(os.path.join(exp_checkpoint_dir_best, 'accs_valid.npy')).tolist()
        # TODO: also save (and load) current_best_cost and current_best_acc
    else:
        costs = []
        costs_valid = []
        accs = []
        accs_valid = []

    ########## TRAIN ##############################
    print('TRAINING...')
    which_val_metric = 'acc'  # 'cost'  # whether to monitor acc or cost for best model
    try:
        # Estimate how many batches an epoch is:
        # print('iterating once over data to count the number of train batches for this split...')
        # names_iterator_train = shhs.patient_names_iterator(names_train)
        # it_1ep_train = shhs.data_iterator_1epoch(
        #     names_iterator_train,
        #     args.n_patient_queues,
        #     epochs_before=args.eps_before,
        #     epochs_after=args.eps_after,
        #     balance_classes=args.balance_classes,
        #     use_if_missing_stage=args.use_if_missing_stage)
        # ex_per_ep = 0
        # try:
        #     while True:
        #         _, _ = next(it_1ep_train)
        #         ex_per_ep += 1
        # except StopIteration:
        #     pass
        # finally:
        #     del names_iterator_train, it_1ep_train
        # batches_per_ep = ex_per_ep // args.batch_size
        # print("Number of batches per epoch: ", batches_per_ep)

        # Simply initialize iterators with a very large number of batches
        # to make sure they do not run out;
        # the count will be made in the training loop.
        it_train = shhs.data_iterator_fixedN(
            preprocessed_names=names_train,
            n=int(1e8) * args.batch_size,
            n_patient_queues=args.n_patient_queues,
            epochs_before=args.eps_before,
            epochs_after=args.eps_after,
            balance_classes=args.balance_classes,
            use_if_missing_stage=args.use_if_missing_stage)
        b_it_train = shhs.batches_iterator(it_train, args.batch_size)
        it_valid = shhs.data_iterator_fixedN(
            preprocessed_names=names_valid,
            n=int(1e8) * args.batch_size,
            n_patient_queues=args.n_patient_queues,
            epochs_before=args.eps_before,
            epochs_after=args.eps_after,
            balance_classes=args.balance_classes,
            use_if_missing_stage=args.use_if_missing_stage)
        b_it_valid = shhs.batches_iterator(it_valid, args.batch_size)

        # Training loop.
        # Validation cost is evaluated continuously,
        # one batch every 5 train batches.
        current_best_cost = 1e20
        current_best_acc = 0.
        for b in range(np.sum(args.training_batches)):  # tqdm(range(np.sum(args.training_batches))):
            if b % 200 == 0:
                print(str(100 * b / np.sum(args.training_batches)) + ' percent done...')
                sys.stdout.flush()
            # if b % (np.sum(args.training_batches)//10) == 0:
            #     # practical when working with a server
            #     write_to_comment_file(
            #         os.path.join(exp_dir, 'STATUS.txt'),
            #         '%s percent done...' % (10*(b//(np.sum(args.training_batches)//10))))

            # Variable learning rate: pick the rate of the schedule segment
            # that the current batch index falls into.
            c01 = b > np.cumsum(args.training_batches)
            idx_lr = np.where(c01 == 1)[0]
            if len(idx_lr) == 0:
                idx_lr = 0
            else:
                idx_lr = idx_lr[-1] + 1
            lr = args.learning_rates[idx_lr]

            if b % 10000 == 0 and b > 0:
                print('--- Intermediate validation... ---')
                cost, acc = test(args, model, names_valid, exp_dir, b,
                                 current_best_cost, current_best_acc,
                                 which_="valid", which_metric=which_val_metric)
                # Negate the accuracy so that "lower is better" holds for both metrics.
                metric = cost if which_val_metric == 'cost' else -acc
                metric_best = current_best_cost if which_val_metric == 'cost' else -current_best_acc
                if metric <= metric_best:
                    model.save_model(os.path.join(exp_checkpoint_dir_best, 'model.ckpt'))
                    np.save(os.path.join(exp_checkpoint_dir_best, 'costs.npy'),
                            np.array([*costs]))
                    np.save(os.path.join(exp_checkpoint_dir_best, 'accs.npy'),
                            np.array([*accs]))
                    np.save(os.path.join(exp_checkpoint_dir_best, 'costs_valid.npy'),
                            np.array([*costs_valid]))
                    np.save(os.path.join(exp_checkpoint_dir_best, 'accs_valid.npy'),
                            np.array([*accs_valid]))
                    # Test metrics. REM: another (faster) option is to reload only
                    # the optimal model after training and test then...
                    test(args, model, names_test, exp_dir, b, 1e20, 0.,
                         save=True, which_="test", which_metric=which_val_metric)
                # update current bests
                if cost <= current_best_cost:
                    current_best_cost = cost
                if acc >= current_best_acc:
                    current_best_acc = acc

            # Occasionally print stuff during training.
            if b > 0 and b % 1000 == 0:
                print("average train cost over the last 500 evals: ", np.mean(costs[-500:]))
                print("average valid cost over the last 100 evals: ", np.mean(costs_valid[-100:]))
                print("average train acc over the last 500 evals: ", np.mean(accs[-500:]))
                print("average valid acc over the last 100 evals: ", np.mean(accs_valid[-100:]))
                print("learning rate: ", lr)

            examples, labels_target = next(b_it_train)
            numeric_in = {
                'inX': examples,
                'targetY': labels_target,
                'lr': lr,
            }
            acc_value, cost_value = model.train_model(numeric_in)
            costs += [cost_value]
            accs += [acc_value]

            if b % 5 == 0:
                examples, labels_target = next(b_it_valid)
                numeric_in = {
                    'inX': examples,
                    'targetY': labels_target,
                }
                _, acc_value, cost_value = model.estimate_model(numeric_in)
                costs_valid += [cost_value]
                accs_valid += [acc_value]

            if b % 5000 == 0:
                np.save(os.path.join(exp_dir, 'costs.npy'), np.array([*costs]))
                np.save(os.path.join(exp_dir, 'accs.npy'), np.array([*accs]))
                np.save(os.path.join(exp_dir, 'costs_valid.npy'), np.array([*costs_valid]))
                np.save(os.path.join(exp_dir, 'accs_valid.npy'), np.array([*accs_valid]))
                # TODO: also save (and load) current_best_cost and current_best_acc

        # Also save everything after training.
        model.save_model(os.path.join(exp_checkpoint_dir, 'model.ckpt'))
        np.save(os.path.join(exp_checkpoint_dir, 'costs.npy'), np.array([*costs]))
        np.save(os.path.join(exp_checkpoint_dir, 'accs.npy'), np.array([*accs]))
        np.save(os.path.join(exp_checkpoint_dir, 'costs_valid.npy'), np.array([*costs_valid]))
        np.save(os.path.join(exp_checkpoint_dir, 'accs_valid.npy'), np.array([*accs_valid]))

    except KeyboardInterrupt:
        print(' !!!!!!!! TRAINING INTERRUPTED !!!!!!!!')

    # Finally, evaluate performance on the test set.
    print('Evaluating performance on the test set...')
    _ = test(args, model, names_test, exp_dir,
             np.sum(args.training_batches), current_best_cost)
    # write_to_comment_file(os.path.join(exp_dir, 'STATUS.txt'), 'FINISHED.')
    model.close()
    tf.reset_default_graph()
    print('###### DONE. ######')
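
# A minimal standalone sketch of the piecewise-constant learning-rate schedule
# used inside the training loop of `train` above, kept here for reference only:
# it mirrors the np.cumsum / np.where logic verbatim but is not called anywhere.
# The helper name `_lr_for_batch` and the example values in the docstring are
# illustrative assumptions, not part of the original code; numpy is already
# imported as np in this module.
def _lr_for_batch(b, training_batches, learning_rates):
    """Return the learning rate used at global batch index `b`.

    `training_batches` gives the number of batches to run per schedule segment
    and `learning_rates` gives one rate per segment. For example (hypothetical
    values), with training_batches=[10000, 5000] and learning_rates=[1e-3, 1e-4],
    batches 0..10000 use 1e-3 and batches 10001..14999 use 1e-4 (the boundary
    batch stays on the first rate because the comparison below is strict).
    """
    c01 = b > np.cumsum(training_batches)
    idx_lr = np.where(c01 == 1)[0]
    if len(idx_lr) == 0:
        idx_lr = 0
    else:
        idx_lr = idx_lr[-1] + 1
    return learning_rates[idx_lr]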