def main():
    """Train and test a tied-weight autoencoder stack on MNIST.

    Relies on module-level configuration (``arch``, ``transfn``, ``weights``,
    ``weights_kwds``, ``optimiser``, ``optimiser_kwds``, ``learning_rate``,
    ``net_name``, ``input_dims``, ``write_dir``, ``n_epochs``, ``batch_size``)
    and the ``dn`` (deepnodal-style) package imported elsewhere in this file.
    """
    # INPUT DATA
    source = dn.loaders.mnist()
    source.read_data()
    source.partition()

    # SPECIFY ARCHITECTURE
    mod = dn.stack()
    mod.set_arch(arch)
    mod.set_transfn(transfn)
    for i in range(len(arch)):
        # j mirrors i about the middle of the stack: layer i ties to layer j.
        j = len(arch) - i - 1
        if i < len(arch) // 2:
            # Encoder half: fresh weights.
            mod[i].set_weights(weights)
        else:
            # Decoder half: share (tie) the mirrored encoder layer's weights.
            mod[i].set_weights(mod[j].ret_params('weights'), **weights_kwds)

    # SPECIFY NETWORK
    net = dn.network(net_name)
    net.set_subnets(mod)
    net.set_inputs(input_dims)

    # SPECIFY SUPERVISOR AND TRAINING
    sup = dn.supervisor()
    sup.set_labels(dtype='float32')
    sup.set_errorq('mse')
    sup.set_work(net)
    sup.set_optimiser(optimiser, **optimiser_kwds)
    sup.add_schedule(learning_rate)

    # TRAIN AND TEST
    now = datetime.datetime.utcnow().strftime("%Y%m%d%H%M%S")
    t0 = time()
    with sup.call_session(write_dir + net_name + "_" + now):
        for i in range(n_epochs):
            # Drain the training set one mini-batch at a time.
            while True:
                data = source.next_batch('train', batch_size)
                if data:
                    # Autoencoder: flatten images and use input as target.
                    data = data[0].reshape([batch_size, -1])
                    sup.train(data, data)
                else:
                    break
            data = source.next_batch('test')
            data = data[0].reshape([len(data[0]), -1])
            summary_str = sup.test(data, data)
            print("".join(["Epoch {} ({} s): ", summary_str]).format(
                str(i), str(round(time() - t0))))
def main():
    """Train and test an MNIST stack with dropout, normalisation and a
    stepped learning-rate schedule.

    Relies on module-level configuration (``seed``, ``arch``, ``transfn``,
    ``dropout``, ``normal``/``normal_kwds``, ``reguln``/``reguln_kwds``,
    ``optimiser``/``optimiser_kwds``, ``learning_rate``, ``net_name``,
    ``input_dims``, ``write_dir``, ``n_epochs``, ``epochs_per_schedule``,
    ``batch_size``) and the ``dn`` package imported elsewhere in this file.
    """
    # INPUT DATA
    source = dn.loaders.mnist()
    source.read_data()
    source.partition(seed=seed)

    # SPECIFY ARCHITECTURE
    mod = dn.stack()
    mod.set_arch(arch)
    mod.set_transfn(transfn)
    mod.set_dropout(dropout)
    mod.set_normal(normal, **normal_kwds)
    mod.set_reguln(reguln, **reguln_kwds)

    # SPECIFY NETWORK
    net = dn.network(net_name)
    net.set_subnets(mod)
    net.set_inputs(input_dims)

    # SPECIFY SUPERVISOR AND TRAINING
    sup = dn.supervisor()
    sup.set_optimiser(optimiser, **optimiser_kwds)
    sup.set_work(net)
    # Four schedules with a decaying learning rate; the last repeats the
    # final rate with dropout disabled.
    sup.add_schedule(learning_rate)
    sup.add_schedule(0.1*learning_rate)
    sup.add_schedule(0.01*learning_rate)
    index = sup.add_schedule(0.01*learning_rate)
    sup.set_schedule(index, False) # disable dropout

    # TRAIN AND TEST
    now = datetime.datetime.utcnow().strftime("%Y%m%d%H%M%S")
    t0 = time()
    with sup.call_session(write_dir+net_name+"_"+now, seed=seed):
        for i in range(n_epochs):
            # Advance to the next schedule every epochs_per_schedule epochs.
            if not (i % epochs_per_schedule):
                sup.use_schedule(i // epochs_per_schedule)
            while True:
                data = source.next_batch('train', batch_size)
                if data:
                    sup.train(*data)
                else:
                    break
            data = source.next_batch('test')
            summary_str = sup.test(*data)
            print("".join(["Epoch {} ({} s): ", summary_str]).format(
                str(i), str(round(time()-t0))))
def main():
    """Train and test an MNIST model built from a ``dn.level`` with a
    vergence (``set_opverge``) joining its streams.

    Relies on module-level configuration (``arch``, ``transfn``, ``net_name``,
    ``input_dims``, ``learning_rate``, ``write_dir``, ``n_epochs``,
    ``batch_size``) and the ``dn`` package imported elsewhere in this file.
    """
    # INPUT DATA
    source = dn.loaders.mnist()
    source.read_data()
    source.partition()

    # SPECIFY ARCHITECTURE
    mod = dn.level()
    mod.set_arch(arch)
    mod.set_transfn(transfn)
    mod.set_opverge(True)

    # SPECIFY NETWORK
    net = dn.network(net_name)
    net.set_subnets(mod)
    net.set_inputs(input_dims)

    # SPECIFY SUPERVISOR AND TRAINING
    sup = dn.supervisor()
    sup.set_work(net)
    sup.add_schedule(learning_rate)
    sup.set_costfn('mse')

    # TRAIN AND TEST
    now = datetime.datetime.utcnow().strftime("%Y%m%d%H%M%S")
    t0 = time()
    with sup.call_session(write_dir + net_name + "_" + now):
        for i in range(n_epochs):
            # Drain the training set one mini-batch at a time.
            while True:
                data = source.next_batch('train', batch_size)
                if data:
                    sup.train(*data)
                else:
                    break
            data = source.next_batch('test')
            summary_str = sup.test(*data)
            print("".join(["Epoch {} ({} s): ", summary_str]).format(
                str(i), str(round(time() - t0))))
def main():
    """Train and test a CIFAR-10 stack with per-epoch learning-rate
    schedules, optional restore from a saved checkpoint, and periodic saves.

    Relies on module-level configuration (``gcn``, ``zca``,
    ``gcn_within_depth``, ``arch``, ``transfn``, ``dropout``,
    ``normal``/``normal_kwds``, ``reguln``/``reguln_kwds``,
    ``optimiser``/``optimiser_kwds``, ``learning_rates`` — a dict mapping
    start-epoch -> rate, ``net_name``, ``input_dims``, ``write_dir``,
    ``n_epochs``, ``batch_size``, ``rand_flip``, ``rand_crop``,
    ``test_split``, ``save_interval``) and the ``dn`` package imported
    elsewhere in this file.
    """
    seed = 42

    # INPUT DATA
    source = dn.loaders.cifar10()
    source.read_data(gcn=gcn, zca=zca, gcn_within_depth=gcn_within_depth)
    source.partition(seed=seed)

    # SPECIFY ARCHITECTURE
    mod = dn.stack()
    mod.set_arch(arch)
    mod.set_transfn(transfn)
    mod.set_dropout(dropout)
    mod.set_normal(normal, **normal_kwds)
    mod.set_reguln(reguln, **reguln_kwds)

    # SPECIFY NETWORK
    net = dn.network(net_name)
    net.set_subnets(mod)
    net.set_inputs(input_dims)

    # SPECIFY SUPERVISOR AND TRAINING
    sup = dn.supervisor()
    sup.set_optimiser(optimiser, **optimiser_kwds)
    sup.set_work(net)
    # One schedule per entry in learning_rates, in start-epoch order; the
    # final schedule has dropout disabled.
    for epoch in sorted(list(learning_rates.keys())):
        index = sup.add_schedule(learning_rates[epoch])
        if index == len(learning_rates) - 1:
            sup.set_schedule(index, False) # disable dropout

    # CHECK FOR RESTOREPOINT
    modfiler = dn.helpers.model_filer(write_dir, net_name)
    restore_point = modfiler.interview()
    if restore_point is not None:
        # NOTE(review): the restore point is reused as the session seed —
        # confirm this matches model_filer.interview()'s return contract.
        seed = restore_point

    # TRAIN AND TEST
    now = datetime.datetime.utcnow().strftime("%Y%m%d%H%M%S")
    log_out = None if write_dir is None else write_dir+net_name+"_"+now
    mod_out = None if write_dir is None else log_out + "/" + net_name
    schedule = -1
    epoch_0 = 0
    t0 = time()
    with sup.call_session(log_out, seed=seed):
        if restore_point is not None:
            # Resume: derive the first epoch from training progress, then
            # fast-forward the schedule state to match.
            epoch_0 = int(np.ceil(float(sup.progress[1]) /
                                  float(source.sets['train']['support'])))
            for i in range(epoch_0):
                if i in learning_rates:
                    schedule += 1
                    sup.use_schedule(schedule)
        for i in range(epoch_0, n_epochs):
            # Switch schedule at each epoch listed in learning_rates.
            if i in learning_rates:
                schedule += 1
                sup.use_schedule(schedule)
            while True:
                data = source.next_batch('train', batch_size,
                                         rand_flip=rand_flip,
                                         rand_crop=rand_crop)
                if not data:
                    break
                sup.train(*data)
            data = source.next_batch('test')
            summary_str = sup.test(*data, split=test_split)
            print("".join(["Epoch {} ({} s): ", summary_str]).format(
                str(i), str(round(time()-t0))))
            # Save periodically and on the final epoch (never at epoch 0).
            if i and mod_out is not None:
                if not(i % save_interval) or i == n_epochs - 1:
                    sup.save(mod_out)