def run_model_event_range_generator(model_name, participator, timesteps, stride,
                                    nb_epoch, event_range, load_weight_from=None):
    """Train an EEG event model for one participant using a fresh data
    generator each epoch, then save event predictions to a timestamped
    output directory.

    Parameters
    ----------
    model_name : str
        Architecture name passed to ``EEG_model.select_model``.
    participator : int
        Participant index into the GAL dataset.
    timesteps : int
        Window length handed to the data generator.
    stride : int
        Step between successive windows.
    nb_epoch : int
        Number of training epochs; a new generator is created per epoch
        because the previous one is exhausted after a full pass.
    event_range : object
        Event-range argument forwarded to ``GAL_data.X_y_part_generator``
        (semantics defined by that generator — see GAL_data).
    load_weight_from : str, optional
        If given, path of saved weights to warm-start the model from.
    """
    # Per-run logger name so repeated calls don't keep stacking handlers
    # on the same logger (which would duplicate every log line); matches
    # the convention already used by run_model_duration.
    logger_name = (model_name + str(participator) + str(timesteps)
                   + str(stride) + str(nb_epoch) + str(load_weight_from))
    logger = logging.getLogger(logger_name)
    f_time = datetime.datetime.today()
    output_dir = os.path.join('output', str(f_time))
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    hdlr = logging.FileHandler(os.path.join(output_dir, 'rnn.log'))
    logger.addHandler(hdlr)
    console_handler = logging.StreamHandler()
    logger.addHandler(console_handler)
    logger.setLevel(logging.INFO)

    gal = GAL_data()
    gal.set_logger(logger)
    gal.load_data()
    data_description = gal.get_data_description()
    logger.info('participator : {0}'.format(participator))

    event_list = ['tHandStart', 'tFirstDigitTouch', 'tBothStartLoadPhase',
                  'tLiftOff', 'tReplace', 'tBothReleased', 'tHandStop']
    rnn = EEG_model(event_list)
    rnn.set_logger(logger)
    rnn.select_model(model_name)
    if load_weight_from is not None:
        rnn.load_model_weight(model_name, load_weight_from)

    logger.info('running model data from a generator')
    data_len = gal.part_data_count[participator]
    # 80/10/10 train/validate/test split; the index lists below hold the
    # per-partition sample counts as ranges (test gets the remainder so
    # the three partitions always cover data_len exactly).
    data_split_ratio = [0.8, 0.1, 0.1]
    train_list = np.arange(int(data_len * data_split_ratio[0]))
    validate_list = np.arange(int(data_len * data_split_ratio[1]))
    test_list = np.arange(data_len
                          - int(data_len * data_split_ratio[0])
                          - int(data_len * data_split_ratio[1]))

    for epoch in range(nb_epoch):
        # Rebuild the generator every epoch: it is consumed by a full pass.
        generator = gal.X_y_part_generator(part=participator,
                                           timesteps=timesteps,
                                           stride=stride,
                                           event_list=event_list,
                                           event_range=event_range)
        logger.info('epoch : {0}'.format(epoch))
        # time.clock() was removed in Python 3.8; perf_counter is the
        # recommended replacement for elapsed-time measurement.
        start = time.perf_counter()
        rnn.run_model_with_generator_event(generator=generator,
                                           train_list=train_list,
                                           validate_list=validate_list,
                                           test_list=test_list)
        logger.info('epoch {0} ran for {1} minutes'.format(
            epoch, (time.perf_counter() - start) / 60))

    rnn.set_data_description(data_description)
    rnn.set_model_config('epoch', nb_epoch)
    # Fresh generator for the save pass — the training ones are exhausted.
    generator = gal.X_y_part_generator(part=participator,
                                       timesteps=timesteps,
                                       stride=stride,
                                       event_list=event_list,
                                       event_range=event_range)
    rnn.save_event(generator=generator, train_list=train_list,
                   validate_list=validate_list, test_list=test_list,
                   event_list=event_list, output_dir=output_dir)
def run_model_duration(model_name, participator, timesteps, stride, nb_epoch,
                       load_weight_from=None):
    """Train an EEG model on movement-phase *duration* targets for one
    participant, persist the per-epoch losses as JSON, and save the event
    outputs to a timestamped output directory.

    Parameters
    ----------
    model_name : str
        Architecture name passed to ``EEG_model.select_model``.
    participator : int
        Participant index into the GAL dataset.
    timesteps : int
        Window length handed to ``GAL_data.data_event``.
    stride : int
        Step between successive windows.
    nb_epoch : int
        Number of training epochs for ``EEG_model.run_model_event``.
    load_weight_from : str, optional
        If given, path of saved weights to warm-start the model from.
    """
    # Unique logger name per parameter combination so repeated calls do
    # not attach duplicate handlers to one shared logger.
    logger_name = (model_name + str(participator) + str(timesteps)
                   + str(stride) + str(nb_epoch) + str(load_weight_from))
    logger = logging.getLogger(logger_name)
    f_time = datetime.datetime.today()
    output_dir = os.path.join('output', 'dur_' + str(f_time))
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    hdlr = logging.FileHandler(os.path.join(output_dir, 'rnn.log'))
    logger.addHandler(hdlr)
    console_handler = logging.StreamHandler()
    logger.addHandler(console_handler)
    logger.setLevel(logging.INFO)

    gal = GAL_data()
    gal.set_logger(logger)
    # Only EEG signals and trial info are needed for duration targets.
    gal.load_data(load_list=['eeg', 'info'])
    data_description = gal.get_data_description()
    logger.info('participator : {0}'.format(participator))

    # Movement-phase duration targets. (An earlier variant used
    # Reach/Preload/LoadPhase/Release/Retract phase names.)
    event_list = ['Dur_Reach', 'Dur_LoadReach', 'Dur_LoadMaintain',
                  'Dur_LoadRetract', 'Dur_Retract']
    rnn = EEG_model(event_list)
    rnn.set_logger(logger)
    rnn.select_model(model_name)
    if load_weight_from is not None:
        rnn.load_model_weight(model_name, load_weight_from)

    logger.info('running model data as a whole')
    # 80/10/10 train/validate/test partition, resolved inside data_event.
    data_split_ratio = [0.8, 0.1, 0.1]
    data = gal.data_event(part=participator, timesteps=timesteps,
                          stride=stride, event_list=event_list,
                          partition_ratio=data_split_ratio, input_dim=32)
    loss_train, loss_val, loss_test = rnn.run_model_event(data=data,
                                                          nb_epoch=nb_epoch)

    # Persist per-epoch loss curves for later analysis/plotting.
    with open(os.path.join(output_dir, 'train_loss.json'), 'w') as f:
        json.dump(loss_train, f)
    with open(os.path.join(output_dir, 'validate_loss.json'), 'w') as f:
        json.dump(loss_val, f)
    with open(os.path.join(output_dir, 'test_loss.json'), 'w') as f:
        json.dump(loss_test, f)

    rnn.set_data_description(data_description)
    rnn.set_model_config('epoch', nb_epoch)
    # NOTE(review): the original also built gal.data_generator_event(...)
    # here, but its result was never used (save_event consumes `data`);
    # the dead call has been removed.
    rnn.save_event(data=data, event_list=event_list, output_dir=output_dir)