예제 #1
0
def main():
    """Train an MNIST network to reconstruct its own input and report test error.

    The network is a ``dn.stack``. Layers in the first half of ``arch`` (plus
    the middle layer when ``len(arch)`` is odd) are given ``weights``; each
    remaining layer ``i`` is given the ``'weights'`` parameter of its
    mirror-image layer ``len(arch) - i - 1`` together with ``weights_kwds``.
    Training and test targets equal the inputs, and the supervisor error
    quantity is ``'mse'``. After each epoch the test set is evaluated and a
    one-line summary with elapsed wall-clock seconds is printed.

    Reads module-level globals defined elsewhere: arch, transfn, weights,
    weights_kwds, net_name, input_dims, optimiser, optimiser_kwds,
    learning_rate, write_dir, n_epochs, batch_size.
    """

    # INPUT DATA

    source = dn.loaders.mnist()
    source.read_data()
    source.partition()

    # SPECIFY ARCHITECTURE

    mod = dn.stack()
    mod.set_arch(arch)
    mod.set_transfn(transfn)
    n_layers = len(arch)
    for i in range(n_layers):
        j = n_layers - i - 1  # index of the mirror-image layer
        if i <= j:
            # First half, plus the middle layer of an odd-length arch, gets its
            # own weights. Tying a middle layer (i == j) to itself would read
            # weights that have not been set yet.
            mod[i].set_weights(weights)
        else:
            # Second half reuses the weights of the mirror-image layer.
            mod[i].set_weights(mod[j].ret_params('weights'), **weights_kwds)

    # SPECIFY NETWORK

    net = dn.network(net_name)
    net.set_subnets(mod)
    net.set_inputs(input_dims)

    # SPECIFY SUPERVISOR AND TRAINING

    sup = dn.supervisor()
    sup.set_labels(dtype='float32')
    sup.set_errorq('mse')
    sup.set_work(net)
    sup.set_optimiser(optimiser, **optimiser_kwds)
    sup.add_schedule(learning_rate)

    # TRAIN AND TEST

    # utcnow() is deprecated (3.12+); this yields the same UTC timestamp string.
    now = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d%H%M%S")
    t0 = time()
    with sup.call_session(write_dir + net_name + "_" + now):
        for i in range(n_epochs):
            while True:
                data = source.next_batch('train', batch_size)
                if not data:
                    break
                # Flatten each sample, sizing by the actual batch length so a
                # short final batch (if the loader yields one) is not mangled.
                inputs = data[0].reshape([len(data[0]), -1])
                sup.train(inputs, inputs)
            data = source.next_batch('test')
            inputs = data[0].reshape([len(data[0]), -1])
            summary_str = sup.test(inputs, inputs)
            # summary_str is passed as an argument, not spliced into the
            # format string, so braces inside it cannot break .format().
            print("Epoch {} ({} s): {}".format(i, round(time() - t0), summary_str))
예제 #2
0
def main():
  """Train an MNIST network with a stepped learning-rate schedule.

  Four supervisor schedules are registered with learning rates
  ``learning_rate``, ``0.1*learning_rate``, ``0.01*learning_rate`` and
  ``0.01*learning_rate``. The last one is configured with
  ``sup.set_schedule(index, False)``, which the inline comment describes as
  disabling dropout. Every ``epochs_per_schedule`` epochs the next schedule
  is activated; once the last schedule is active it stays active for any
  remaining epochs. After each epoch the test set is evaluated and a one-line
  summary with elapsed wall-clock seconds is printed.

  Reads module-level globals defined elsewhere: seed, arch, transfn, dropout,
  normal, normal_kwds, reguln, reguln_kwds, net_name, input_dims, optimiser,
  optimiser_kwds, learning_rate, write_dir, n_epochs, batch_size,
  epochs_per_schedule.
  """

  # INPUT DATA

  source = dn.loaders.mnist()
  source.read_data()
  source.partition(seed=seed)

  # SPECIFY ARCHITECTURE

  mod = dn.stack()
  mod.set_arch(arch)
  mod.set_transfn(transfn)
  mod.set_dropout(dropout)
  mod.set_normal(normal, **normal_kwds)
  mod.set_reguln(reguln, **reguln_kwds)

  # SPECIFY NETWORK

  net = dn.network(net_name)
  net.set_subnets(mod)
  net.set_inputs(input_dims)

  # SPECIFY SUPERVISOR AND TRAINING

  sup = dn.supervisor()
  sup.set_optimiser(optimiser, **optimiser_kwds)
  sup.set_work(net)
  schedule_rates = [learning_rate, 0.1*learning_rate,
                    0.01*learning_rate, 0.01*learning_rate]
  for rate in schedule_rates:
    index = sup.add_schedule(rate)
  sup.set_schedule(index, False) # disable dropout (final schedule)
  n_schedules = len(schedule_rates)

  # TRAIN AND TEST

  # utcnow() is deprecated (3.12+); this yields the same UTC timestamp string.
  now = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d%H%M%S")
  t0 = time()
  with sup.call_session(write_dir+net_name+"_"+now, seed=seed):
    for i in range(n_epochs):
      stage = i // epochs_per_schedule
      # Switch only while a registered schedule remains. Beyond
      # n_schedules*epochs_per_schedule epochs the original requested a
      # schedule index that was never added; now the last one stays active.
      if not (i % epochs_per_schedule) and stage < n_schedules:
        sup.use_schedule(stage)
      while True:
        data = source.next_batch('train', batch_size)
        if not data:
          break
        sup.train(*data)
      data = source.next_batch('test')
      summary_str = sup.test(*data)
      # summary_str is a format argument, so braces in it cannot break .format().
      print("Epoch {} ({} s): {}".format(i, round(time()-t0), summary_str))
예제 #3
0
def main():
    """Train an MNIST network built from a ``dn.level`` module.

    The level has ``set_opverge(True)`` applied, and the supervisor cost
    function is set to ``'mse'``. After each training epoch the test set is
    evaluated and a one-line summary with elapsed wall-clock seconds is
    printed.

    Reads module-level globals defined elsewhere: arch, transfn, net_name,
    input_dims, learning_rate, write_dir, n_epochs, batch_size.
    """

    # INPUT DATA

    source = dn.loaders.mnist()
    source.read_data()
    source.partition()

    # SPECIFY ARCHITECTURE

    mod = dn.level()
    mod.set_arch(arch)
    mod.set_transfn(transfn)
    mod.set_opverge(True)

    # SPECIFY NETWORK

    net = dn.network(net_name)
    net.set_subnets(mod)
    net.set_inputs(input_dims)

    # SPECIFY SUPERVISOR AND TRAINING

    sup = dn.supervisor()
    sup.set_work(net)
    sup.add_schedule(learning_rate)
    sup.set_costfn('mse')

    # TRAIN AND TEST

    # utcnow() is deprecated (3.12+); this yields the same UTC timestamp string.
    now = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d%H%M%S")
    t0 = time()
    with sup.call_session(write_dir + net_name + "_" + now):
        for i in range(n_epochs):
            while True:
                data = source.next_batch('train', batch_size)
                if not data:
                    break
                sup.train(*data)
            data = source.next_batch('test')
            summary_str = sup.test(*data)
            # summary_str is passed as an argument rather than concatenated
            # into the format string, so braces in it cannot break .format().
            print("Epoch {} ({} s): {}".format(i, round(time() - t0), summary_str))
예제 #4
0
def main():
  """Train a CIFAR-10 network on an epoch-keyed learning-rate schedule.

  ``learning_rates`` maps a starting epoch to a learning rate. One supervisor
  schedule is registered per entry in ascending epoch order, and the last one
  is configured with ``sup.set_schedule(index, False)``, which the inline
  comment describes as disabling dropout. During training, schedule
  ``k`` becomes active at the k-th (0-based) epoch key of ``learning_rates``.

  If ``dn.helpers.model_filer(...).interview()`` returns a restore point, it
  is passed to ``call_session`` in place of the seed. Training then resumes
  from an epoch inferred from ``sup.progress[1]``, and earlier schedule
  switches are replayed.

  When ``write_dir`` is not None, logs go to
  ``write_dir + net_name + "_" + <UTC timestamp>``. The model is saved there
  every ``save_interval`` epochs (excluding epoch 0) and after the final
  epoch.

  Reads module-level globals defined elsewhere: gcn, zca, gcn_within_depth,
  arch, transfn, dropout, normal, normal_kwds, reguln, reguln_kwds, net_name,
  input_dims, optimiser, optimiser_kwds, learning_rates, write_dir, n_epochs,
  batch_size, rand_flip, rand_crop, test_split, save_interval.
  """
  seed = 42

  # INPUT DATA

  source = dn.loaders.cifar10()
  source.read_data(gcn=gcn, zca=zca, gcn_within_depth=gcn_within_depth)
  source.partition(seed=seed)

  # SPECIFY ARCHITECTURE

  mod = dn.stack()
  mod.set_arch(arch)
  mod.set_transfn(transfn)
  mod.set_dropout(dropout)
  mod.set_normal(normal, **normal_kwds)
  mod.set_reguln(reguln, **reguln_kwds)

  # SPECIFY NETWORK

  net = dn.network(net_name)
  net.set_subnets(mod)
  net.set_inputs(input_dims)

  # SPECIFY SUPERVISOR AND TRAINING

  sup = dn.supervisor()
  sup.set_optimiser(optimiser, **optimiser_kwds)
  sup.set_work(net)
  for epoch in sorted(learning_rates):
    index = sup.add_schedule(learning_rates[epoch])
    if index == len(learning_rates) - 1:
      sup.set_schedule(index, False) # disable dropout

  # CHECK FOR RESTOREPOINT

  modfiler = dn.helpers.model_filer(write_dir, net_name)
  restore_point = modfiler.interview()
  if restore_point is not None:
    # NOTE(review): the restore point replaces the integer seed handed to
    # call_session below; presumably call_session restores from it - confirm.
    seed = restore_point

  # TRAIN AND TEST

  # utcnow() is deprecated (3.12+); this yields the same UTC timestamp string.
  now = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d%H%M%S")
  log_out = None if write_dir is None else write_dir+net_name+"_"+now
  mod_out = None if write_dir is None else log_out + "/" + net_name
  schedule = -1
  epoch_0 = 0
  t0 = time()

  with sup.call_session(log_out, seed=seed):
    if restore_point is not None:
      # Starting epoch = ceil(progress[1] / training-set 'support').
      # Assumes progress[1] counts training samples seen - TODO confirm.
      epoch_0 = int(np.ceil(float(sup.progress[1])/float(source.sets['train']['support'])))
      # Replay the schedule switches of the epochs already completed.
      for i in range(epoch_0):
        if i in learning_rates:
          schedule += 1
          sup.use_schedule(schedule)
    for i in range(epoch_0, n_epochs):
      if i in learning_rates:
        schedule += 1
        sup.use_schedule(schedule)
      while True:
        data = source.next_batch('train', batch_size,
                                 rand_flip=rand_flip, rand_crop=rand_crop)
        if not data:
          break
        sup.train(*data)
      data = source.next_batch('test')
      summary_str = sup.test(*data, split=test_split)
      # summary_str is a format argument, so braces in it cannot break .format().
      print("Epoch {} ({} s): {}".format(i, round(time()-t0), summary_str))
      if mod_out is not None:
        # Periodic saves skip epoch 0, but the final epoch is always saved.
        # Previously a single-epoch run (i == 0 == n_epochs-1) never saved.
        periodic = i and not (i % save_interval)
        if periodic or i == n_epochs - 1:
          sup.save(mod_out)