Example #1
0
def main(args):
    """Dispatch to export, training, or evaluation based on parsed arguments.

    Builds the hyper-parameter dict handed to ``export``/``train`` from
    ``args``, then:

    * if ``args.export`` is truthy, calls ``export`` and returns;
    * otherwise makes sure ``args.checkpoint_dir`` exists, and
      - if ``args.worker`` is truthy, reports the run settings through
        ``klclient.update_task_info`` and calls ``train('train', ...)``;
      - else sets ``TF_CONFIG`` to an evaluator task and calls
        ``train('eval', ...)``.

    Args:
        args: Parsed command-line namespace providing the attributes read
            below (``num_pools``, ``lr``, ``checkpoint_dir``, ``export``,
            ``worker``, ...).
    """
    # NOTE: 'resolution' used to be listed twice in this literal; the
    # duplicate was redundant (same value) and has been dropped.
    params = {
        'num_pools': args.num_pools,
        'num_epochs': args.num_epochs,
        'drop_prob': args.drop_prob,
        'num_chans': args.num_chans,
        'batch_size': args.batch_size,
        'lr': args.lr,
        'lr_step_size': args.lr_step_size,
        'lr_gamma': args.lr_gamma,
        'weight_decay': args.weight_decay,
        'checkpoint': str(args.checkpoint_dir),
        'resolution': args.resolution,
        'save_summary_steps': args.save_summary_steps,
        'save_checkpoints_secs': args.save_checkpoints_secs,
        'save_checkpoints_steps': args.save_checkpoints_steps,
        'keep_checkpoint_max': args.keep_checkpoint_max,
        'log_step_count_steps': args.log_step_count_steps,
        'warm_start_from': args.warm_start_from,
        'use_seed': False,
        'data_set': args.data_set,
        'limit': args.limit,
        'unpool': args.unpool,
        'optimizer': args.optimizer,
        'loss': args.loss,
    }

    if args.export:
        export(args.checkpoint_dir, params)
        return

    if not tf.gfile.Exists(args.checkpoint_dir):
        tf.gfile.MakeDirs(args.checkpoint_dir)

    if args.worker:
        # Publish the run's settings to the task tracker before training.
        klclient.update_task_info({
            'num-pools': args.num_pools,
            'drop-prob': args.drop_prob,
            'num-chans': args.num_chans,
            'batch-size': args.batch_size,
            'lr.lr': args.lr,
            'lr.lr-step-size': args.lr_step_size,
            'lr.lr-gamma': args.lr_gamma,
            'weight-decay': args.weight_decay,
            'checkpoint_path': str(args.checkpoint_dir),
            'resolution': args.resolution,
        })
        train('train', args.checkpoint_dir, params)
    else:
        # Placeholder cluster spec: the host names are dummies; the relevant
        # part is the task type 'evaluator' written into TF_CONFIG.
        cluster = {
            'chief': ['fake_worker1:2222'],
            'ps': ['fake_ps:2222'],
            'worker': ['fake_worker2:2222'],
        }
        os.environ['TF_CONFIG'] = json.dumps({
            'cluster': cluster,
            'task': {
                'type': 'evaluator',
                'index': 0,
            },
        })
        train('eval', args.checkpoint_dir, params)
Example #2
0
def export(checkpoint_dir, params):
    """Export a ``BoxUnet`` estimator as a SavedModel and report its path.

    When both the ``TRAINING_DIR`` and ``BASE_TASK_BUILD_ID`` environment
    variables are set to non-empty values, ``checkpoint_dir`` is replaced by
    ``TRAINING_DIR/BASE_TASK_BUILD_ID``. The model is exported with
    ``batch_size`` forced to 1 and a single float32 ``image`` input of shape
    ``[1, None, None, 3]``; the resulting path is published through
    ``klclient.update_task_info`` under the key ``'model_path'``.

    Args:
        checkpoint_dir: Directory used both as the estimator's model dir and
            as the export base directory.
        params: Hyper-parameter dict forwarded to ``BoxUnet``. It is not
            modified; a copy with ``batch_size`` set to 1 is used instead.
    """
    training_dir = os.environ.get('TRAINING_DIR', '')
    base_build_id = os.environ.get('BASE_TASK_BUILD_ID', '')
    if training_dir and base_build_id:
        checkpoint_dir = training_dir + '/' + base_build_id

    conf = tf.estimator.RunConfig(
        model_dir=checkpoint_dir,
    )

    # Copy before overriding so the caller's dict keeps its own batch size.
    export_params = dict(params)
    export_params['batch_size'] = 1

    feature_placeholders = {
        'image': tf.placeholder(tf.float32, [1, None, None, 3], name='image'),
    }
    receiver = tf.estimator.export.build_raw_serving_input_receiver_fn(
        feature_placeholders)

    net = BoxUnet(
        params=export_params,
        model_dir=checkpoint_dir,
        config=conf,
    )
    export_path = net.export_savedmodel(
        checkpoint_dir,
        receiver,
    )
    # The export call's result is decoded to text before being reported.
    export_path = export_path.decode("utf-8")
    klclient.update_task_info({'model_path': export_path})