示例#1
0
def train(iterations=80000, summary_interval=100, batch=32):
    """Train networks A and B on samples collected live from a ``Driver``.

    Each iteration first tops up the sample buffer (``array``) through
    ``get_input`` whenever it holds fewer than ``batch * 10`` entries. It then
    runs one optimisation step on network A and one on network B. Each step
    uses its own feed from ``get_batch_feed(array, placeholders, batch,
    batch // 2)``.

    Every 10 global steps the two losses and a smoothed per-iteration time
    are printed. Every ``summary_interval`` global steps a summary is written.
    The networks are saved whenever more than 1800 ``timer()`` units have
    elapsed since the last save. Ctrl-C or ``StopIteration`` stops training
    and prints a message.

    Args:
        iterations: number of training iterations to run.
        summary_interval: global-step interval between summaries.
        batch: batch size handed to ``get_batch_feed``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    placeholders = create_placeholders()
    global_step, network_a, network_b = get_network(*placeholders[-4:], placeholders[5])
    with Session(True, True, global_step) as sess:
        with Driver() as driver:
            try:
                last_save = timer()
                array = []
                buffer = score_buffer()
                step = 0
                avg_time = 1  # smoothed time per iteration (both networks)
                for _ in range(iterations):
                    if len(array) < batch * 10:
                        get_input(driver, sess.session, network_a, network_b, placeholders, buffer, array)
                    pre = timer()
                    _, aloss, step = sess.session.run(
                        [network_a.trainer, network_a.loss, global_step],
                        feed_dict=get_batch_feed(array, placeholders, batch, batch // 2))
                    _, bloss = sess.session.run(
                        [network_b.trainer, network_b.loss],
                        feed_dict=get_batch_feed(array, placeholders, batch, batch // 2))
                    # Exponential moving average. The weights 0.9 / 0.1 sum to 1,
                    # so the estimate tracks the real mean iteration time (same
                    # smoothing as in learn()).
                    avg_time = 0.9 * avg_time + 0.1 * (timer() - pre)
                    if step % 10 == 0:
                        print("Training step: %i, Loss A: %.3f, Loss B: %.3f (%.2f s)  " % (step, aloss, bloss, avg_time), end='\r')
                    if step % summary_interval == 0:
                        sess.save_summary(step, get_batch_feed(array, placeholders, batch, batch // 2))
                        print()
                    if timer() - last_save > 1800:
                        sess.save_network()
                        last_save = timer()
            except (KeyboardInterrupt, StopIteration):
                print("\nStopping the training")
示例#2
0
def test_fun(dataset, opts):
    """Evaluate the checkpoint ``model<opts.epochs>.ckpt`` on ``dataset``.

    For every test-time setting in ``test_arg_list`` (currently only
    ``(None, None)``), this does the following:

    - store the setting in ``opts.test_arg``;
    - build the evaluation dataset and the network;
    - restore the checkpoint from ``opts.output_path/opts.name``;
    - run the network on every sample;
    - hand the per-sample ``test_result`` values to ``save_test_result``.

    The default graph is reset after each setting.

    When ``opts.temporal`` is set, the network input is the volume
    concatenated with a heatmap prior. The prior is seeded with
    ``first_heatmap_p`` on the first sample and whenever ``data_num`` changes.
    Otherwise it is derived from the previous prediction with
    ``get_heatmap_p``.
    """
    test_arg_list = [(test_rot, test_flip)
                     for test_rot in [None]
                     for test_flip in [None]]
    for test_arg in test_arg_list:
        opts.test_arg = test_arg
        # dataset and iterator
        dataset_val = dataset.get_dataset(opts)
        iterator = dataset_val.make_one_shot_iterator()
        volume, joint_coord, shape, data_num = iterator.get_next()
        inputs = tf.placeholder(
            tf.float32,
            shape=[None, None, None, None, 1 + opts.temporal * opts.nJoint])

        # network
        outputs, _ = get_network(inputs, opts)

        # save and load
        saver = tf.train.Saver(var_list=tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES, scope=opts.network))

        start = time.time()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            print('restore trained model')
            saver.restore(
                sess,
                os.path.join(opts.output_path, opts.name,
                             'model%d.ckpt' % opts.epochs))
            print('test start')
            res = []
            dn_p = None        # data_num of the previous sample
            output_val = None  # previous prediction; None until the first run
            while True:
                try:
                    v, joint, s, dn = sess.run(
                        [volume, joint_coord, shape, data_num])
                except tf.errors.OutOfRangeError:
                    break
                if opts.temporal:
                    dn_cur = np.squeeze(dn)
                    # Seed the prior on the very first sample as well as on a
                    # data_num change; otherwise output_val would be read
                    # before it was ever assigned.
                    if output_val is None or dn_cur != dn_p:
                        output_val = first_heatmap_p(joint, s, opts)
                        dn_p = dn_cur
                    else:
                        output_val = get_heatmap_p(output_val, opts)
                    net_input = np.concatenate([v, output_val], axis=-1)
                else:
                    net_input = v
                output_val = sess.run(outputs[-1],
                                      feed_dict={inputs: net_input})
                res.append(test_result(output_val, joint, s, dn, opts))
            save_test_result(res, opts)
            reset_dict()
        tf.reset_default_graph()
        print("test end, elapsed time: ", time.time() - start)
def test_values(opts):
  """Evaluate the restored network on every ``np_test/*npz`` file.

  The network is applied to the placeholder sample returned by
  ``dataset.get_placeholders()``. For each npz file (in sorted order), the
  fetched similarities and ``true_match`` are scored with ``get_np_losses``.
  One formatted line per file is printed, with the total per-file time
  appended, and the same line is appended to
  ``<opts.save_dir>/test_output.log``.

  Only trainable variables are restored. ``<save_dir>/best-loss-model`` is
  preferred when its ``.meta`` file exists; otherwise the latest checkpoint
  in ``save_dir`` is used.
  """
  dataset = data_util.datasets.get_dataset(opts)
  network = model.get_network(opts, opts.arch)
  sample, placeholders = dataset.get_placeholders()
  output_sim = get_test_output_sim(opts, network(sample))
  fetches = [output_sim, sample['true_match']]

  disp_string = (
      '{idx:06d}: {{'
      'time: {time:.03e}, '
      'l1: {l1:.03e}, '
      'l2: {l2:.03e}, '
      'ssame: {{ m: {ssame_m:.03e}, std: {ssame_std:.03e} }}, '
      'sdiff: {{ m: {sdiff_m:.03e}, std: {sdiff_std:.03e} }}, '
      'roc: {roc:.03e}, '
      'p_r: {p_r:.03e}, '
      '}}')

  npz_files = sorted(glob.glob(
      os.path.join(opts.dataset_params.data_dir, 'np_test', '*npz')))
  saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
  log_path = os.path.join(opts.save_dir, 'test_output.log')
  best_loss_ckpt = os.path.join(os.path.abspath(opts.save_dir),
                                'best-loss-model')

  with open(log_path, 'a') as log_file, build_test_session(opts) as sess:
    if os.path.exists(best_loss_ckpt + '.meta'):
      ckpt_path = best_loss_ckpt
    else:
      ckpt_path = tf.train.latest_checkpoint(opts.save_dir)
    saver.restore(sess, ckpt_path)

    for idx, npz_path in enumerate(npz_files):
      file_start = time.time()
      with open(npz_path, 'rb') as npz_handle:
        npz_data = dict(np.load(npz_handle))
      feed = dataset.get_feed_dict(placeholders, npz_data)
      run_start = time.time()
      sim_val, matches_val = sess.run(fetches, feed_dict=feed)
      run_end = time.time()
      stats = {'idx': idx, 'time': run_end - run_start}
      stats.update(get_np_losses(opts, sim_val, matches_val[0]))
      line = disp_string.format(**stats)
      print(line + ' ({:.03f})'.format(time.time() - file_start))
      log_file.write(line)
      log_file.write('\n')
示例#4
0
def train(opts):
  """Train the network.

  Input: opts (options) - object with all relevant options stored
  Output: None
  Saves all output in opts.save_dir, given by the user. For how to modify the
  training procedure, look at options.py

  The loop runs opts.num_runs times. Each run continues until one of these
  limits from get_intervals is hit:

  - the shared step counter (not reset between runs) reaches the step limit;
  - the run's wall time exceeds the time limit.

  The loss is logged every opts.log_steps steps. The test set is evaluated
  every test_freq_steps steps or after test_freq seconds without a test.
  """
  dataset = data_util.datasets.get_dataset(opts)
  network = model.get_network(opts, opts.arch)
  # Input pipeline is placed on the CPU.
  with tf.device('/cpu:0'):
    batch_fn = dataset.load_batch if opts.load_data else dataset.gen_batch
    sample = batch_fn('train')
  output = network(sample['Laplacian'], sample['InitEmbeddings'])
  loss = get_loss(opts, sample, output, name='train')
  train_op = get_train_op(opts, loss)
  test_data = get_test_dict(opts, dataset, network)

  step = 0
  train_steps, train_time, test_freq_steps, test_freq = get_intervals(opts)
  trainstr = "global step {}: loss = {} ({:.04} sec/step, time {:.04})"
  tf.logging.set_verbosity(tf.logging.INFO)
  with build_session(opts) as sess:
    for _ in range(opts.num_runs):
      run_start = last_test = now = time.time()
      while step != train_steps and now - run_start <= train_time:
        step_start = time.time()
        _, loss_ = sess.run([train_op, loss])
        now = time.time()
        if step % opts.log_steps == 0:
          log(trainstr.format(step, loss_, now - step_start, now - run_start))
        due_by_steps = test_freq_steps and step % test_freq_steps == 0
        if due_by_steps or now - last_test > test_freq:
          run_test(opts, sess.raw_session(), test_data)
          last_test = time.time()
        step += 1
示例#5
0
def test_fun(opts):
    """Probe the joint heatmaps on one volume while revealing more of it.

    Inputs:

    - the hard-coded volume ``../rawdata/022618/022618_000.nii.gz``;
    - the joint annotations in ``../rawdata/022618/022618.mat``, reshaped to
      ``(frames, 3, 15)`` (coordinate axis, then joint);
    - the checkpoint ``model<opts.epochs>.ckpt``, which is restored first.

    For radii ``r = 1 .. 49`` the volume is zeroed outside a ball of radius
    ``r`` around joint ``rc`` of frame 0 and fed to the network. For each
    ``r`` it prints the maximum of output channels ``lc`` and ``rc`` after
    multiplying them by a radius-3 mask around that same joint. Finally it
    prints the Euclidean distance from joint ``rc`` to every joint of frame 0.
    """
    v = read_nifti('../rawdata/022618/022618_000.nii.gz')
    joint_coord = sio.loadmat(
        '../rawdata/022618/022618.mat')['joint_coord'].reshape(
            (-1, 15, 3)).transpose((0, 2, 1))

    lc, rc = 2, 3  # joint indices whose output channels are compared

    # Frame-0 position of joint rc, shifted by -1 (presumably 1-based MATLAB
    # coordinates converted to 0-based indices -- confirm).
    x_r, y_r, z_r = joint_coord[0, 0, rc] - 1, joint_coord[
        0, 1, rc] - 1, joint_coord[0, 2, rc] - 1
    xv, yv, zv = np.meshgrid(np.arange(v.shape[1]), np.arange(v.shape[0]),
                             np.arange(v.shape[2]))

    # Floored squared distance of every voxel to joint rc. It does not
    # depend on r, so it is computed once rather than once per radius.
    sq_dist_r = np.floor((xv - x_r)**2 + (yv - y_r)**2 + (zv - z_r)**2)
    val_mask_r = sq_dist_r <= 3**2

    inputs = tf.placeholder(tf.float32, shape=[None, None, None, None, 1])
    outputs, _ = get_network(inputs, opts)
    saver = tf.train.Saver(var_list=tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope=opts.network))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(
            sess,
            os.path.join(opts.output_path, opts.name,
                         'model%d.ckpt' % opts.epochs))
        for r in range(1, 50):
            mask_r = sq_dist_r <= r**2
            # Add batch and channel axes: (1, ...volume dims..., 1).
            output_val = sess.run(outputs[-1],
                                  feed_dict={
                                      inputs:
                                      np.expand_dims(
                                          np.expand_dims(v * mask_r, 0), -1)
                                  })
            output_val = np.squeeze(output_val)
            print(r, np.max(output_val[:, :, :, lc] * val_mask_r),
                  np.max(output_val[:, :, :, rc] * val_mask_r))

    print(
        np.sqrt(
            sum((joint_coord[0, :, :] -
                 np.reshape(np.array([x_r + 1, y_r + 1, z_r + 1]), (3, 1)))**2,
                0)))
示例#6
0
def test_values(opts):
    """Run testing on the network.

    Input: opts (options) - object with all relevant options stored
    Output: None

    Only opts.save_dir needs to point at the trained model. Per the original
    note, the configuration is expected to come from options.yaml in that
    directory; that loading is not done in this function.

    Each ``np_test/*npz`` file is expanded to a batch of one and fed into the
    placeholders. The fetched loss values and the run time are printed and
    appended to ``test_output.log`` in opts.save_dir. Only the variables in
    the 'weights' and 'biases' collections are restored, from the latest
    checkpoint.
    """
    # Get data and network
    dataset = data_util.datasets.get_dataset(opts)
    network = model.get_network(opts, opts.arch)
    # Sample
    sample = dataset.get_placeholders()
    print(sample)
    output = network(sample['Laplacian'], sample['InitEmbeddings'])
    losses = get_test_losses(opts, sample, output)

    # Tensorflow and logging operations
    disp_string =  '{:06d} Errors: ' \
                   'L1: {:.03e},  L2: {:.03e}, BCE: {:.03e}, ' \
                   'Same sim: {:.03e} +/- {:.03e}, ' \
                   'Diff sim: {:.03e} +/- {:.03e}, ' \
                   'Time: {:.03e}, '

    # Build session
    glob_str = os.path.join(opts.dataset_params.data_dir, 'np_test', '*npz')
    npz_files = sorted(glob.glob(glob_str))
    vars_restore = (list(tf.get_collection('weights')) +
                    list(tf.get_collection('biases')))
    print(vars_restore)
    saver = tf.train.Saver(vars_restore)
    with open(os.path.join(opts.save_dir, 'test_output.log'), 'a') as log_file:
        with build_test_session(opts) as sess:
            saver.restore(sess, tf.train.latest_checkpoint(opts.save_dir))
            for i, npz_file in enumerate(npz_files):
                # NpzFile keeps the archive open; close it as soon as every
                # array has been read.
                with np.load(npz_file) as npz:
                    sample_ = {k: np.expand_dims(v, 0)
                               for k, v in npz.items()}
                start_time = time.time()
                vals = sess.run(losses,
                                {sample[k]: sample_[k] for k in sample.keys()})
                end_time = time.time()
                dstr = disp_string.format(i, *vals, end_time - start_time)
                print(dstr)
                log_file.write(dstr)
                log_file.write('\n')
示例#7
0
def load_model(name, epoch, shape, network_type, ctx):
    """Bind a fresh ``network_type`` graph and fill it from a saved checkpoint.

    The symbol stored in checkpoint ``name``/``epoch`` is discarded; the graph
    is rebuilt with ``get_network(network_type)`` and bound on ``ctx`` with
    ``left``/``right`` inputs of shape ``shape`` and ``grad_req='add'``.

    Arguments are filled as follows:

    - data/label inputs listed in ``data_sign`` are zeroed;
    - other arguments are copied from the checkpoint when present;
    - anything else is passed to ``init``.

    Auxiliary states found in the checkpoint are copied as well. Without
    that, state such as running statistics would be lost on reload.

    Returns:
        (net, executor): the rebuilt symbol and its bound executor.
    """
    data_sign = {'left', 'right', 'left_downsample', 'right_downsample',
                 'label', 'LinearRegression_label', 'gt'}
    _, args, aux = mx.model.load_checkpoint(name, epoch)
    net = get_network(network_type)
    executor = net.simple_bind(ctx=ctx, grad_req='add', left=shape, right=shape)
    for key, arr in executor.arg_dict.items():
        if key in data_sign:
            arr[:] = mx.nd.zeros(arr.shape, ctx)
        elif key in args:
            arr[:] = args[key]
        else:
            init(key, arr)
    for key, arr in executor.aux_dict.items():
        if key in aux:
            arr[:] = aux[key]
    return net, executor
示例#8
0
def learn(image,
          variables,
          example,
          score,
          iterations=10000,
          summary_interval=100):
    """
        Learn to drive from examples

        Networks A and B are trained step by step until the global step
        reaches ``iterations``:

        - a smoothed per-step time is printed every 10 global steps;
        - a summary is written every ``summary_interval`` global steps;
        - the networks are saved once 1800 ``timer()`` units have passed
          since the last save.

        Afterwards the mean of 10 loss evaluations of each network is
        printed. Ctrl-C or ``StopIteration`` end training quietly; the final
        "Stopping" message is always printed.
    """
    try:
        global_step, network_a, network_b = get_network(
            image, variables, example, score, True)
        with Session(True, True, global_step) as sess:
            last_save = timer()
            step, avg_time = 1, 1
            while step < iterations:
                started = timer()
                _, aloss, step = sess.session.run(
                    [network_a.trainer, network_a.loss, global_step])
                _, bloss = sess.session.run(
                    [network_b.trainer, network_b.loss])
                if step % summary_interval == 0:
                    sess.save_summary(step)
                    print()
                # Smoothed time includes any summary written above.
                avg_time = 0.9 * avg_time + 0.1 * (timer() - started)
                if step % 10 == 0:
                    print(
                        "Training step: %i, Loss A: %.3f, Loss B: %.3f (%.2f s)  "
                        % (step, aloss, bloss, avg_time),
                        end='\r')
                if timer() - last_save > 1800:
                    sess.save_network()
                    last_save = timer()
            final_losses = [
                sess.session.run([network_a.loss, network_b.loss])
                for _ in range(10)
            ]
            aloss_mean = sum(a for a, _ in final_losses) / 10
            bloss_mean = sum(b for _, b in final_losses) / 10
            print("\nFinal loss A: %.3f, Final loss B: %.3f" %
                  (aloss_mean, bloss_mean))
    except (KeyboardInterrupt, StopIteration):
        pass
    finally:
        print("\nStopping the training")
def train(opts):
  """Train the network described by ``opts``.

  The loop runs opts.num_runs times. Each run continues until one of these
  limits from get_intervals is hit:

  - the step counter, shared across runs, reaches the step limit;
  - the run's wall time exceeds the time limit.

  The loss is logged every opts.log_steps steps. The test set is evaluated
  every test_freq_steps steps or after test_freq seconds without a test.
  """
  dataset = data_util.datasets.get_dataset(opts)
  network = model.get_network(opts, opts.arch)
  with tf.device('/cpu:0'):
    batch_fn = dataset.load_batch if opts.load_data else dataset.gen_batch
    sample = batch_fn('train')
  output = network(sample['Laplacian'], sample['InitEmbeddings'])
  loss = get_loss(opts, sample, output, name='train')
  train_op = get_train_op(opts, loss)
  test_data = get_test_dict(opts, dataset, network)

  step = 0
  train_steps, train_time, test_freq_steps, test_freq = get_intervals(opts)
  trainstr = "global step {}: loss = {} ({:.04} sec/step, time {:.04})"
  tf.logging.set_verbosity(tf.logging.INFO)
  with build_session(opts) as sess:
    for _ in range(opts.num_runs):
      run_start = last_test = now = time.time()
      while step != train_steps and now - run_start <= train_time:
        step_start = time.time()
        _, loss_ = sess.run([train_op, loss])
        now = time.time()
        if step % opts.log_steps == 0:
          log(trainstr.format(step, loss_, now - step_start, now - run_start))
        due_by_steps = test_freq_steps and step % test_freq_steps == 0
        if due_by_steps or now - last_test > test_freq:
          run_test(opts, sess.raw_session(), test_data)
          last_test = time.time()
        step += 1
示例#10
0
def drive():
    """
        Drive a car, alternating between the networks

        Network A and network B take turns. Each receives the image and
        variables returned by the previous ``driver.drive`` call and produces
        the next ``(h, v)`` command. Runs until Ctrl-C or ``StopIteration``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    image_ph = tf.placeholder(tf.float32, [None, IMAGE_WIDTH * IMAGE_HEIGHT * IMAGE_DEPTH])
    vars_ph = tf.placeholder(tf.float32, [None, VARIABLE_COUNT])
    image_4d = tf.reshape(image_ph, [-1, IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_DEPTH])
    _, neta, netb = get_network(image_4d, vars_ph, training=False)
    with Session(False, False) as sess:
        with Driver() as driver:
            def step_feed(h, v):
                # Apply the command, then build the next feed from what the
                # driver returns.
                print("Driving  |  h: %+.2f  v: %+.2f"%(h,v), end='\r')
                frame, variables, _, _ = driver.drive(h, v)
                return {image_ph: [frame], vars_ph: [variables]}
            try:
                h, v = 0, 1
                while True:
                    for net in (neta, netb):
                        h, v, _ = sess.session.run(net.output, feed_dict=step_feed(h, v))[0]
            except (KeyboardInterrupt, StopIteration):
                pass
示例#11
0
 def train(self):
     """Build the training/test graphs and train until ``keep_training`` says stop.

     Training and test batches are loaded on the CPU. The network from
     ``model.get_network`` is wired to this object's ``get_output_sim``,
     ``get_loss`` and ``get_train_op``, and the test graph comes from
     ``build_test``. Each step's loss is logged every ``opts.log_steps``
     steps. ``run_test`` is called whenever ``test_now`` returns true for the
     step and the time since the last test.
     """
     # NOTE(review): ``opts`` is not a parameter or attribute here. It is
     # presumably a module-level global (or should be ``self.opts``) -- confirm.
     dataset = data_util.datasets.get_dataset(opts)
     with tf.device('/cpu:0'):
         sample = dataset.load_batch('train')
         test_sample = dataset.load_batch('test')
     network = model.get_network(opts, self.arch)
     output_sim = self.get_output_sim(network(sample))
     loss = self.get_loss(sample, output_sim, test_mode=False, name='train')
     train_op = self.get_train_op(loss)
     self.build_test(test_sample, network)
     trainstr = "local step {}: loss = {} ({:.04} sec/step, time {:.04})"
     step = 0
     tf.logging.set_verbosity(tf.logging.INFO)
     with self.build_session() as sess:
         raw_sess = sess.raw_session()
         t_begin = t_now = t_last_test = time.time()
         while self.keep_training(step, t_now - t_begin):
             t_step = time.time()
             _, loss_ = sess.run([train_op, loss])
             t_now = time.time()
             if step % opts.log_steps == 0:
                 self.log(trainstr.format(step, loss_, t_now - t_step,
                                          t_now - t_begin))
             if self.test_now(step, t_now - t_last_test):
                 self.run_test(raw_sess, step)
                 t_last_test = time.time()
             step += 1
示例#12
0
    parser.add_argument('--lr',action='store',dest='lr',type=float)
    parser.add_argument('--l',action='store',dest='low',type=int)
    parser.add_argument('--h',action='store',dest='high',type=int)
    cmd = parser.parse_args()
    # cmd.con is not an epoch index: it counts "rounds", where one round is
    # 200 batches. One epoch over the KITTI dataset takes about 3 hours.
    # NOTE(review): the option that sets cmd.con is defined above this view.

    lr = cmd.lr
    batch_size = 10000
    # Input shape bound to the 'left' / 'right' arguments below.
    s1 = (batch_size,3,13,13)
    # NOTE(review): s2 is not used in the lines shown here.
    s2 = (batch_size,3,7,7)
    # Hard-coded to GPU index 3.
    ctx = mx.gpu(3)
    # Names of data/label arguments (not used in the lines shown here).
    data_sign = ['left','right','left_downsample','right_downsample','label','LinearRegression_label','gt']
    
    if cmd.con == -1:
        # Train from scratch.
        net = get_network('not fully')
        executor = net.simple_bind(ctx=ctx,grad_req='add',left = s1,right= s1)
        keys  = net.list_arguments()
        grads = dict(zip(net.list_arguments(),executor.grad_arrays))
        args  = dict(zip(keys,executor.arg_arrays))
        # NOTE(review): auxs pairs argument names with executor.arg_arrays, so
        # it duplicates args instead of holding auxiliary states. Probably
        # meant dict(zip(net.list_auxiliary_states(), executor.aux_arrays)).
        auxs  = dict(zip(keys,executor.arg_arrays))
        # NOTE(review): this rebinds the dict entry to a new NDArray; the
        # executor's own 'gt' array is not modified. Use args['gt'][:] = 0 if
        # the intent is to zero it in place -- confirm.
        args['gt'] = mx.nd.zeros((batch_size,),ctx)
        logging.info("complete network architecture design")
    
    else:
        # Resume the previous training from checkpoint 'stereo', round cmd.con.
        net,executor = load_model('stereo',cmd.con,s1,'not fully',ctx)
        keys = net.list_arguments()
        grads = dict(zip(keys,executor.grad_arrays))
        args  = dict(zip(keys,executor.arg_arrays))
        # NOTE(review): same as above -- built from arg_arrays, not aux_arrays.
        auxs  = dict(zip(keys,executor.arg_arrays))
示例#13
0
    parser = argparse.ArgumentParser()
    # Checkpoint round to resume from; -1 trains from scratch (see below).
    parser.add_argument('--continue',action='store',dest='con',type=int)
    parser.add_argument('--lr',action='store',dest='lr',type=float)
    cmd = parser.parse_args()
    # cmd.con is not an epoch index: it counts "rounds", where one round is
    # 200 batches. One epoch over the KITTI dataset takes about 3 hours.

    lr = cmd.lr
    batch_size = 1235
    # Input shape bound to the 'left' / 'right' arguments below.
    s1 = (batch_size,3,13,13)
    # NOTE(review): s2 is not used in the lines shown here.
    s2 = (batch_size,3,7,7)
    # Hard-coded to GPU index 3.
    ctx = mx.gpu(3)
    # Names of data/label arguments (not used in the lines shown here).
    data_sign = ['left','right','left_downsample','right_downsample','label','LinearRegression_label','gt']
    
    if cmd.con == -1:
        # Train from scratch.
        net = get_network('fully')
        executor = net.simple_bind(ctx=ctx,grad_req='add',left = s1,right= s1)
        keys  = net.list_arguments()
        grads = dict(zip(net.list_arguments(),executor.grad_arrays))
        args  = dict(zip(keys,executor.arg_arrays))
        # NOTE(review): auxs pairs argument names with executor.arg_arrays, so
        # it duplicates args instead of holding auxiliary states. Probably
        # meant dict(zip(net.list_auxiliary_states(), executor.aux_arrays)).
        auxs  = dict(zip(keys,executor.arg_arrays))
        # NOTE(review): this rebinds the dict entry to a new NDArray; the
        # executor's own 'gt' array is not modified. Use args['gt'][:] = 0 if
        # the intent is to zero it in place -- confirm.
        args['gt'] = mx.nd.zeros((batch_size,),ctx)
        logging.info("complete network architecture design")
    
    else:
        # Resume the previous training from checkpoint 'stereo', round cmd.con.
        net,executor = load_model('stereo',cmd.con,s1,'fully',ctx)
        keys = net.list_arguments()
        grads = dict(zip(keys,executor.grad_arrays))
        args  = dict(zip(keys,executor.arg_arrays))
        # NOTE(review): same as above -- built from arg_arrays, not aux_arrays.
        auxs  = dict(zip(keys,executor.arg_arrays))
示例#14
0
                    weights_name = args.net_type + str(
                        args.depth) + "-" + str(epoch).zfill(5) + ".pth"
                    torch.save(state, os.path.join(args.model_dir,
                                                   weights_name))

    time_elapsed = time.time() - since
    print("Training completed in {:.0f} min {:.0f} sec".format(
        time_elapsed // 60, time_elapsed % 60))
    print("Best validation Acc {:.2f}%".format(best_acc * 100))


if __name__ == '__main__':
    # Script entry point: parse arguments, prepare data, build the model,
    # then train. All names bound here are module globals.
    print("--Phase 0: arguments settings...")
    args = set_args()
    args.use_gpu = torch.cuda.is_available()

    print("--Phase 1: Data prepration")
    data_loader, num_class = prepare_dataset(args)
    # Class count discovered from the data; presumably consumed by
    # reset_classifier below -- confirm.
    args.num_class = num_class

    print("--Phase 2: Model setup")
    model_ft, file_name = get_network(args)
    model_ft = reset_classifier(model_ft, args)
    if args.use_gpu:
        model_ft.cuda(args.device_id)
        import torch.backends.cudnn as cudnn
        # Let cuDNN benchmark and pick convolution algorithms.
        cudnn.benchmark = True

    print("--Phase 3: Training Model")
    train_model(model_ft, data_loader, args)
示例#15
0
    for _ in range(1000):
        img, tag = dataloader.next()
        intermediate = infer(branchyNet, 'client', ep, pp, qb, img)
        result = infer(branchyNet, 'server', ep, pp, qb, intermediate)
        prob = torch.exp(result).detach().numpy().tolist()[0]
        pred = (prob.index(max(prob)), max(prob))


def get_latency(branchyNet, img):
    """Return the seconds spent in one ``infer_main_branch(branchyNet, img)`` call.

    Uses ``time.perf_counter``, a monotonic, high-resolution clock meant for
    timing short intervals. ``time.time`` can jump when the system clock is
    adjusted and has coarse resolution on some platforms.
    """
    start_time = time.perf_counter()
    infer_main_branch(branchyNet, img)
    return time.perf_counter() - start_time


if __name__ == "__main__":
    # Build the network, then call load_model and .testing() on it
    # (presumably weight loading and switching to inference mode -- confirm).
    # All names bound here are module globals.
    branchyNet = get_network()
    load_model(branchyNet)
    branchyNet.testing()

    dataloader = get_test_data()

    # Print the size of one batch; this batch is not timed.
    img, tag = dataloader.next()
    print(img.size())

    # Average main-branch latency over the next 10 batches.
    latency = 0
    for _ in range(10):
        img, tag = dataloader.next()
        latency += get_latency(branchyNet, img)
    print(latency / 10)
示例#16
0
def _run_dataset_pass(sess, fetches, is_training, opts, training, inputs,
                      labels, outputs, volume=None, label=None):
    """Run ``fetches`` on every batch until the active iterator is exhausted.

    ``fetches`` is ``[summary_op, some_op]``. Returns the summary from the
    last batch, or ``None`` when the pass produced no batch.

    In ``opts.train_like_test`` mode, batches are pulled from
    ``volume``/``label`` and fed through the ``inputs``/``labels``
    placeholders. With probability ``opts.train_like_test``, the channels
    after the first are replaced by the network's own prediction (computed
    with ``training`` False) before the batch is fed.
    """
    last_summary = None
    while True:
        try:
            if opts.train_like_test:
                v, l = sess.run([volume, label])
                if random.random() < opts.train_like_test:
                    l_p = sess.run(outputs[-1],
                                   feed_dict={
                                       training: False,
                                       inputs: v
                                   })
                    v = np.concatenate([v[:, :, :, :, :1], l_p], axis=-1)
                feed_dict = {training: is_training, inputs: v, labels: l}
            else:
                feed_dict = {training: is_training}
            last_summary, _ = sess.run(fetches, feed_dict=feed_dict)
        except tf.errors.OutOfRangeError:
            return last_summary


def train_fun(dataset, opts):
    """Train ``get_network`` on ``dataset`` with validation after every epoch.

    Variables whose names appear in the pretrain checkpoint (when
    ``opts.use_pretrain`` is set) are restored from it. Training resumes from
    ``model<opts.epoch_continue>.ckpt`` when ``opts.epoch_continue > 0``.

    Each epoch does the following:

    - one full pass over the training split (``train_op``);
    - one full pass over the validation split (``mean_update_ops``);
    - writes the last summary of each pass to ``logs/train`` and ``logs/val``;
    - saves ``model<epoch>.ckpt`` every ``opts.save_freq`` epochs and after
      the final epoch.

    Sets ``opts.run = 'test'`` on completion.
    """
    # dataset and iterator
    dataset_train, dataset_val = dataset.get_dataset(opts)
    iterator = tf.data.Iterator.from_structure(dataset_train.output_types,
                                               dataset_train.output_shapes)
    volume = label = None
    if opts.train_like_test:
        volume, label = iterator.get_next()
        inputs = tf.placeholder(
            tf.float32,
            shape=[None, None, None, None, 1 + opts.temporal * opts.nJoint])
        labels = tf.placeholder(tf.float32,
                                shape=[None, None, None, None, opts.nJoint])
    else:
        inputs, labels = iterator.get_next()

    # network
    outputs, training = get_network(inputs, opts)

    # loss
    loss, mean_update_ops = mse_loss(outputs, labels, opts)

    # summary
    writer_train = tf.summary.FileWriter(
        os.path.join(opts.output_path, opts.name, 'logs', 'train'),
        tf.get_default_graph())
    writer_val = tf.summary.FileWriter(
        os.path.join(opts.output_path, opts.name, 'logs', 'val'))
    summary_op = tf.summary.merge_all()

    # varlist: split the network's variables into those present in the
    # pretrain checkpoint and those trained from scratch.
    # NOTE(review): the names are read from <output_path>/<use_pretrain>/,
    # but the restore below reads <output_path>/<use_pretrain>pretrain/.
    # One of the two paths is probably wrong -- confirm.
    name_set = set(
        ns[0] for ns in tf.train.list_variables(
            os.path.join(opts.output_path, opts.use_pretrain,
                         'pretrainmodel.ckpt'))
    ) if opts.use_pretrain != '' else set()
    scoped_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                    scope=opts.network)
    # v.name[:-2] strips the ':0' output suffix.
    pretrain_list = [v for v in scoped_vars if v.name[:-2] in name_set]
    newtrain_list = [v for v in scoped_vars if v.name[:-2] not in name_set]
    print('pretrain var: %d, newtrain var: %d' %
          (len(pretrain_list), len(newtrain_list)))

    # optimizer
    optimizer = Optimizer(opts, pretrain_list, newtrain_list,
                          dataset_train.length)
    train_op = optimizer.get_train_op(loss)
    my_update_op = tf.group(mean_update_ops)

    # save and load
    saver = tf.train.Saver(var_list=newtrain_list + pretrain_list)
    if opts.use_pretrain != '':
        saver_pretrain = tf.train.Saver(var_list=pretrain_list)

    # Create the iterator initializers and the local-variable reset op once.
    # Creating them inside the epoch loop would add new ops to the graph on
    # every epoch.
    train_init_op = iterator.make_initializer(dataset_train)
    val_init_op = iterator.make_initializer(dataset_val)
    reset_locals_op = tf.local_variables_initializer()

    # main loop
    with tf.Session(config=tf.ConfigProto(log_device_placement=False,
                                          allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())
        if opts.use_pretrain != '':
            saver_pretrain.restore(
                sess,
                os.path.join(opts.output_path, opts.use_pretrain + 'pretrain',
                             'pretrainmodel.ckpt'))
        if opts.epoch_continue > 0:
            saver.restore(
                sess,
                os.path.join(opts.output_path, opts.use_continue,
                             'model%d.ckpt' % opts.epoch_continue))
        print('training loop start')
        start_train = time.time()
        for epoch in range(opts.epoch_continue + 1, opts.epochs + 1):
            print('epoch: %d' % epoch)
            start_ep = time.time()
            # train
            print('training')
            sess.run(train_init_op)
            sess.run(reset_locals_op)
            summary_train = _run_dataset_pass(
                sess, [summary_op, train_op], True, opts, training, inputs,
                labels, outputs, volume, label)
            # An empty pass yields no summary; skip it instead of crashing
            # or re-logging the previous epoch's summary.
            if summary_train is not None:
                writer_train.add_summary(summary_train, epoch)
            print('step: %d' % optimizer.get_global_step(sess))
            # validation
            print('validation')
            sess.run(val_init_op)
            sess.run(reset_locals_op)
            summary_val = _run_dataset_pass(
                sess, [summary_op, my_update_op], False, opts, training,
                inputs, labels, outputs, volume, label)
            if summary_val is not None:
                writer_val.add_summary(summary_val, epoch)
            # save model
            if epoch % opts.save_freq == 0 or epoch == opts.epochs:
                print('save model')
                saver.save(
                    sess,
                    os.path.join(opts.output_path, opts.name,
                                 'model%d.ckpt' % epoch))
            print("epoch end, elapsed time: %ds, total time: %ds" %
                  (time.time() - start_ep, time.time() - start_train))
        print('training loop end')
        writer_train.close()
        writer_val.close()
    opts.run = 'test'
示例#17
0
def _refine_peak(heatmap, radius=2):
    """Find the argmax of a 3-D heatmap and refine it with a local weighted centroid.

    Args:
        heatmap: 3-D array for a single channel; expected to be non-negative
            (the caller clips negatives to zero beforehand).
        radius: half-width of the cubic window around the argmax; the window
            is clipped to the array bounds, so peaks near a border still work.

    Returns:
        ``(confidence, centroid)``: ``confidence`` is the heatmap value at the
        argmax; ``centroid`` is a tuple of 0-based float coordinates along
        axes 0, 1 and 2, or ``None`` when the window's weights sum to zero.
    """
    peak = np.unravel_index(np.argmax(heatmap), heatmap.shape)
    confidence = heatmap[peak]
    lo = [max(p - radius, 0) for p in peak]
    hi = [min(p + radius + 1, n) for p, n in zip(peak, heatmap.shape)]
    window = heatmap[lo[0]:hi[0], lo[1]:hi[1], lo[2]:hi[2]]
    weights = window.sum()
    if weights == 0:
        return confidence, None
    # For each axis, collapse the other two axes and take the weighted mean
    # of the coordinate along that axis.
    other_axes = ((1, 2), (0, 2), (0, 1))
    centroid = tuple(
        (window.sum(axis=other_axes[a]) * np.arange(lo[a], hi[a])).sum() /
        weights for a in range(3))
    return confidence, centroid


def test_fun2(opts):
    """Predict joint positions for every ``*.nii.gz`` volume under ``../newdata``.

    Restores ``<opts.output_path>/<opts.name>/model<opts.epochs>.ckpt`` and
    processes each sub-folder of ``../newdata`` (sorted).

    For every volume, the steps are:
      * pad each axis up to a multiple of 8 with reflection,
      * run the network and squeeze the output,
      * crop the padding off again and clip negatives to zero,
      * for each of the first 15 channels, take the argmax and refine it with
        a weighted centroid over a radius-2 window.

    One ``../predict/<folder>.mat`` file is written per folder, containing:
      * ``joint_coord``: ``(n_files, 3, 15)``. Rows are (axis-1, axis-0,
        axis-2) positions offset by +1, presumably MATLAB-style 1-based
        indexing; all zeros when the window weights sum to zero.
      * ``confidence``: ``(n_files, 15)``, the heatmap value at the argmax.

    Args:
        opts: options namespace; reads ``network``, ``output_path``, ``name``
            and ``epochs`` here (``get_network`` may read more).
    """
    data_root = '../newdata'
    predict_root = '../predict'
    n_joints = 15
    radius = 2

    inputs = tf.placeholder(tf.float32, shape=[None, None, None, None, 1])
    outputs, _ = get_network(inputs, opts)
    # Only the variables under the network's scope are restored.
    saver = tf.train.Saver(var_list=tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope=opts.network))

    with tf.Session() as sess:
        # Initialise everything, then overwrite the network-scope variables
        # from the checkpoint.
        sess.run(tf.global_variables_initializer())
        saver.restore(
            sess,
            os.path.join(opts.output_path, opts.name,
                         'model%d.ckpt' % opts.epochs))
        # savemat does not create missing directories.
        os.makedirs(predict_root, exist_ok=True)
        for folder_basename in sorted(os.listdir(data_root)):
            folder = os.path.join(data_root, folder_basename)
            print(folder_basename)
            predict_filename = os.path.join(predict_root,
                                            folder_basename + '.mat')
            niinames = [
                os.path.join(folder, f) for f in sorted(os.listdir(folder))
                if f.endswith('.nii.gz')
            ]
            joint_coord = np.zeros((len(niinames), 3, n_joints))
            confidence = np.zeros((len(niinames), n_joints))
            for nf, nii in enumerate(niinames):
                v = read_nifti(nii)  # assumed to be a 3-D array — confirm
                # Pad each axis up to the next multiple of 8, splitting the
                # padding with the odd voxel (if any) at the end; presumably
                # required by the network's down-sampling.
                pad_width = [(int((ceil(s / 8.0) * 8 - s) / 2),
                              int(ceil((ceil(s / 8.0) * 8 - s) / 2)))
                             for s in v.shape]
                v = np.pad(v, pad_width, 'reflect')
                # Add the batch and channel axes expected by the placeholder.
                v = np.expand_dims(np.expand_dims(v, 0), -1)
                output_val = np.squeeze(
                    sess.run(outputs[-1], feed_dict={inputs: v}))
                out_shape = output_val.shape
                # Crop the padding back off the three spatial axes.
                volume = output_val[
                    pad_width[0][0]:out_shape[0] - pad_width[0][1],
                    pad_width[1][0]:out_shape[1] - pad_width[1][1],
                    pad_width[2][0]:out_shape[2] - pad_width[2][1]]
                volume[volume < 0] = 0
                for i in range(n_joints):
                    confidence[nf, i], centroid = _refine_peak(
                        volume[:, :, :, i], radius=radius)
                    if centroid is None:
                        print('weight zero')
                        joint_coord[nf, :, i] = 0
                    else:
                        c0, c1, c2 = centroid
                        # Row order is (axis 1, axis 0, axis 2), each +1.
                        joint_coord[nf, :, i] = (c1 + 1, c0 + 1, c2 + 1)
            sio.savemat(predict_filename, {
                'joint_coord': joint_coord,
                'confidence': confidence
            })
    print('finish')
示例#18
0
                          image_dir=args.SETTING['image_dir'],
                          label_dir=args.SETTING['label_dir'],
                          image_suffix=args.SETTING['image_suffix'],
                          label_suffix=args.SETTING['label_suffix'],
                          transform=False,
                          inference=False)

# Evaluation loader: deterministic order (no shuffling), the final partial
# batch is kept rather than dropped.
test_data = mx.gluon.data.DataLoader(
    sal_test,
    batch_size=batch_size,
    shuffle=False,
    last_batch='keep',
    num_workers=2,
)

# A resumed run (begin_epoch > 0) builds the bare architecture and restores
# every parameter from args.TRAIN['pretrained'], with strict key matching.
# A fresh run loads the pretrained VGG weights on the bottom-up net and
# initialises the remaining sub-networks with MSRAPrelu.
net = get_network(args.NETWORK,
                  input_size=args.TRAIN['input_size'],
                  pretrained=not (args.begin_epoch > 0),
                  with_aspp=True)
if args.begin_epoch > 0:
    net.collect_params().load(args.TRAIN['pretrained'],
                              allow_missing=False,
                              ignore_extra=False)
else:
    for head in (net.aspp, net.refinement1, net.refinement2,
                 net.refinement3):
        head.collect_params().initialize(init=mx.init.MSRAPrelu())
示例#19
0
                              label_dir=args.SETTING['label_dir'],
                              image_set_dir=args.SETTING['image_set_dir'],
                              image_suffix=args.SETTING['image_suffix'],
                              label_suffix=args.SETTING['label_suffix'],
                              transform=False,
                              inference=True)

# Inference loader: deterministic order, a single worker, and the final
# partial batch is kept.
test_loader = mx.gluon.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    last_batch='keep',
    num_workers=1,
)

# Build the architecture without pretrained weights. All parameters come
# from the trained checkpoint: test_type 0 selects the region model, any
# other value selects the contour model.
net = get_network(args.NETWORK,
                  input_size=args.INFERENCE['input_size'],
                  pretrained=False,
                  with_aspp=True)
trained_model = args.INFERENCE[
    'region_model' if args.test_type == 0 else 'contour_model']
net.collect_params().load(trained_model)
net.hybridize()
# Move the loaded parameters onto the inference context.
net.collect_params().reset_ctx(ctx)

predict(test_loader,
        test_dataset.img_info,
        net,
        ctx,
        result_path,
示例#20
0
def pretrain(opts):
    """Train the network with ``cross_entropy_loss`` and checkpoint it.

    Each dataset element is ``(v_a, v_b, label)``; ``v_a`` and ``v_b`` are
    concatenated along axis 0 (presumably the batch axis) and fed through
    ``get_network`` in a single pass.

    Runs ``opts.epochs`` epochs of Adam updates. After every step it writes
    the merged summaries to ``<opts.output_path>/<opts.time>/logs``. Every
    ``opts.save_freq`` epochs, and after the final epoch, it saves the
    variables under the ``opts.network`` scope to
    ``<opts.output_path>/<opts.time>/pretrainmodel.ckpt``.

    Args:
        opts: options namespace; reads ``rawdata_path``, ``network``,
            ``output_path``, ``time``, ``lr``, ``epochs`` and ``save_freq``
            here (``Dataset.get_dataset``, ``get_network`` and
            ``cross_entropy_loss`` receive it too and may read more).
    """
    # dataset and iterator
    dataset = Dataset(opts.rawdata_path)
    dataset_train = dataset.get_dataset(opts)
    iterator = tf.data.Iterator.from_structure(dataset_train.output_types,
                                               dataset_train.output_shapes)
    v_a, v_b, label = iterator.get_next()

    # network
    outputs, training = get_network(tf.concat((v_a, v_b), axis=0), opts)

    # save and load: only the network-scope variables are checkpointed
    saver = tf.train.Saver(var_list=tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope=opts.network))

    # loss (the second return value, accuracy, is not used here)
    loss, _ = cross_entropy_loss(outputs, label, opts)

    # summary
    writer_train = tf.summary.FileWriter(
        os.path.join(opts.output_path, opts.time, 'logs'),
        tf.get_default_graph())
    summary_op = tf.summary.merge_all()

    # optimizer: the learning rate is multiplied by 0.1 every
    # dataset_train.length * 67 steps (i.e. every 67 epochs, assuming
    # dataset_train.length is the number of steps per epoch — confirm)
    global_step = tf.Variable(0, trainable=False)
    lr = tf.train.exponential_decay(opts.lr,
                                    global_step,
                                    dataset_train.length * 67,
                                    0.1,
                                    staircase=True)
    train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(
        loss, global_step=global_step, colocate_gradients_with_ops=True)
    # Run the collected UPDATE_OPS (e.g. batch-norm statistics) with each step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    train_op = tf.group(update_ops + [train_op])

    # main loop
    with tf.Session(config=tf.ConfigProto(log_device_placement=False,
                                          allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())
        print('training loop start')
        # Wall-clock timing: time.clock() measured CPU time on Unix and was
        # removed in Python 3.8.
        start_train = time.time()
        for epoch in range(1, opts.epochs + 1):
            print('epoch: %d' % epoch)
            start_ep = time.time()
            # train: one pass over the dataset, until the iterator is exhausted
            print('training')
            sess.run(iterator.make_initializer(dataset_train))
            while True:
                try:
                    summary_train, _ = sess.run([summary_op, train_op],
                                                feed_dict={training: True})
                except tf.errors.OutOfRangeError:
                    break
                writer_train.add_summary(
                    summary_train, tf.train.global_step(sess, global_step))
            print('step: %d' % tf.train.global_step(sess, global_step))
            # save model
            if epoch % opts.save_freq == 0 or epoch == opts.epochs:
                print('save model')
                saver.save(
                    sess,
                    os.path.join(opts.output_path, opts.time,
                                 'pretrainmodel.ckpt'))
            # "elapsed" is this epoch's duration; "total" is the time since
            # training started.
            print("epoch end, elapsed time: %ds, total time: %ds" %
                  (time.time() - start_ep, time.time() - start_train))
        print('training loop end')
        writer_train.close()