def train_proc(network, **kwargs):
    """Train ``network`` with stdout redirected to a per-network log file.

    The log is written to ``os.path.join(save_dir, network.checkpoint_name()
    + '.log')``. ``np.random.seed`` is seeded with ``network.seed`` before the
    model is built.

    :param network: object providing ``checkpoint_name()``, a ``seed``
        attribute and ``get_op(**kwargs)``, which must return the pair
        ``(op, load_dic)`` passed to ``ConvNet``.
    :param kwargs: forwarded unchanged to ``network.get_op``.
    :returns: ``True`` if ``convnet.train()`` completed. ``None`` if an
        ``Exception`` was raised; its traceback is written to the log instead
        of propagating. ``KeyboardInterrupt``/``SystemExit`` are not swallowed.
    """
    log_path = os.path.join(save_dir, network.checkpoint_name() + '.log')
    real_stdout = sys.stdout
    sys.stdout = open(log_path, 'w')
    convnet = None
    try:
        try:
            # Imported inside the try so that an import failure of the
            # convnet module is caught and logged like any other error.
            from convnet import ConvNet
            np.random.seed(network.seed)
            op, load_dic = network.get_op(**kwargs)
            convnet = ConvNet(op, load_dic)
            convnet.train()
            return True
        except RuntimeError:
            print(traceback.format_exc())
            # ConvNet(...) itself may have raised, leaving no epoch to report.
            if convnet is not None:
                print("\nerrored at epoch %d" % (convnet.epoch))
        except Exception:
            # Log and report failure (implicit None) rather than crashing the
            # caller; interrupts and SystemExit still propagate.
            print(traceback.format_exc())
        finally:
            if convnet is not None:
                convnet.destroy_model_lib()
    finally:
        # Always undo the stdout redirect, even if destroy_model_lib() raised.
        # NOTE(review): reset_std presumably also closes the log file — confirm.
        reset_std('out', real_stdout)
def objective(layer_file_name, param_file_name, save_file_name):
    """Train a CIFAR-10 ConvNet from the given layer/param files and score it.

    While running, stdout is redirected to ``save_file_name + '.log'``. The net
    is first trained with ``num_epochs = 3`` as a sanity check. If the
    resulting classification error (from the last entry of
    ``convnet.train_outputs``) is not strictly between 0 and 0.85, that
    result is returned immediately. Otherwise ``train()`` is called again
    with ``num_epochs = n_epochs`` (a module-level setting), and the result
    of ``convnet.get_test_error()`` is returned.

    :param layer_file_name: value of ConvNet's ``layer_def`` option.
    :param param_file_name: value of ConvNet's ``layer_params`` option.
    :param save_file_name: value of ``save_file_override``; also the prefix
        of the log file.
    :returns: ``(logprob, error)`` with NaN replaced by ``np.inf``, or
        ``(np.inf, 1.0)`` if training raised ``RuntimeError``.
    """
    def logprob_errors(error_output):
        """Turn an ``(error_types, n)`` pair into per-case (logprob, error).

        NaN values are mapped to ``np.inf`` so they always compare as worst.
        """
        error_types, n = error_output
        logprob = error_types['logprob'][0] / n
        classifier = error_types['logprob'][1] / n
        logprob = np.inf if np.isnan(logprob) else logprob
        classifier = np.inf if np.isnan(classifier) else classifier
        return logprob, classifier

    # Keep our own handle on the log so it is closed even if something
    # rebinds sys.stdout in between.
    log_file = open(save_file_name + '.log', 'w')
    real_stdout = sys.stdout
    sys.stdout = log_file
    convnet = None
    try:
        try:
            # set up options
            op = ConvNet.get_options_parser()
            for option in op.get_options_list():
                option.set_default()
            op.set_value('data_path', os.path.expanduser('~/data/cifar-10-py-colmajor/'))
            op.set_value('dp_type', 'cifar')
            op.set_value('inner_size', '24')
            op.set_value('gpu', '0')
            op.set_value('testing_freq', '25')
            op.set_value('train_batch_range', '1-5')
            op.set_value('test_batch_range', '6')
            op.set_value('num_epochs', n_epochs, parse=False)
            op.set_value('layer_def', layer_file_name)
            op.set_value('layer_params', param_file_name)
            op.set_value('save_file_override', save_file_name)
            convnet = ConvNet(op, None)

            # train for three epochs and make sure error is okay
            convnet.num_epochs = 3
            convnet.train()
            logprob, error = logprob_errors(convnet.train_outputs[-1])
            if not 0 < error < 0.85:
                # should get at most 85% error after three epochs
                print("\naborted (%s, %s)" % (logprob, error))
                return logprob, error

            # train for full epochs
            convnet.num_epochs = n_epochs
            convnet.train()
            logprob, error = logprob_errors(convnet.get_test_error())
            print("\nfinished (%s, %s)" % (logprob, error))
            return logprob, error
        except RuntimeError:
            print(traceback.format_exc())
            # ConvNet(...) itself may have raised; convnet is then None and
            # has no epoch to report.
            if convnet is not None:
                print("\nerrored at epoch %d" % (convnet.epoch))
            else:
                print("\nerrored before training started")
            return np.inf, 1.0
        finally:
            if convnet is not None:
                convnet.destroy_model_lib()
    finally:
        # Runs even if destroy_model_lib() raised, so stdout is always restored.
        print("\n")  # end any pending lines to ensure flush
        sys.stdout = real_stdout
        log_file.close()  # close() flushes any buffered output