示例#1
0
def main(_):
    """Train or test a model according to the command-line flags.

    Reads ``FLAGS.mode``, ``FLAGS.gpu``, ``FLAGS.seed``, ``FLAGS.arch``,
    ``FLAGS.exp_dir`` and ``FLAGS.run``. The positional argument is ignored.

    Modes:
        ``'train'``: calls ``TrainOps.train(seed=...)``.
        ``'test'``:  calls ``TrainOps.test()``.

    Raises:
        ValueError: if ``FLAGS.mode`` is neither ``'train'`` nor ``'test'``.
            Raised before the model is built, so a mistyped mode fails
            immediately instead of silently doing nothing.
    """
    mode = FLAGS.mode
    if mode not in ('train', 'test'):
        raise ValueError(
            "Unsupported mode {0!r}; expected 'train' or 'test'".format(mode))

    # Make only the requested GPU visible to this process.
    # see issue #152 on stackoverflow
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(FLAGS.gpu)

    # Converted up front so a malformed seed is reported before any setup,
    # whichever mode is selected.
    seed = int(FLAGS.seed)
    arch = FLAGS.arch

    model = Model(arch=arch)
    train_ops = TrainOps(model, FLAGS.exp_dir, FLAGS.run, arch)
    train_ops.load_exp_config()

    # print(...) with a single string argument behaves the same under
    # Python 2 and Python 3.
    if mode == 'train':
        print('Training')
        train_ops.train(seed=seed)
    else:
        print('Testing')
        train_ops.test()
示例#2
0
def main(_):
    """Dispatch to the training or testing routine selected by ``FLAGS.mode``.

    The positional argument is ignored. Supported modes and the
    ``TrainOps`` call each one triggers:

        ``'train_ERM'``:  ``train()``
        ``'train_RDA'``:  ``train(random_transf=True)``
        ``'train_RSDA'``: ``train_search(search_algorithm='random_search')``
        ``'train_ESDA'``: ``train_search(search_algorithm='evolution_search')``
        ``'test_all'``:   ``test_all()``
        ``'test_RS'``:    ``test_random_search(...)``
        ``'test_ES'``:    ``test_evolution_search(...)``

    For every mode the experiment configuration is loaded with
    ``load_exp_config()`` right before the routine runs. For modes whose
    name contains ``'train'``, ``npr.seed(int(FLAGS.seed))`` is called first.

    Raises:
        ValueError: if ``FLAGS.mode`` is not one of the modes listed above.
            Raised before the model is built, so a mistyped mode fails
            immediately instead of silently doing nothing.
    """

    def run_random_search(ops):
        # Flags are read only when this mode is actually selected.
        ops.test_random_search(
            run=str(FLAGS.run),
            seed=int(FLAGS.seed),
            no_iters=int(FLAGS.search_no_iters),
            string_length=int(FLAGS.transf_string_length))

    def run_evolution_search(ops):
        # Flags are read only when this mode is actually selected.
        ops.test_evolution_search(
            run=str(FLAGS.run),
            seed=int(FLAGS.seed),
            no_iters=int(FLAGS.search_no_iters),
            string_length=int(FLAGS.transf_string_length),
            pop_size=int(FLAGS.pop_size),
            mutation_rate=float(FLAGS.mutation_rate))

    # mode -> (message printed before running, callable taking the TrainOps)
    actions = {
        'train_ERM': ('Training model with standard ERM',
                      lambda ops: ops.train()),
        'train_RDA': ('Training model with RDA',
                      lambda ops: ops.train(random_transf=True)),
        'train_RSDA': ('Training model with RSDA',
                       lambda ops: ops.train_search(
                           search_algorithm='random_search')),
        'train_ESDA': ('Training model with ESDA',
                       lambda ops: ops.train_search(
                           search_algorithm='evolution_search')),
        'test_all': ('Testing all',
                     lambda ops: ops.test_all()),
        'test_RS': ('Random search', run_random_search),
        'test_ES': ('Evolution search', run_evolution_search),
    }

    mode = FLAGS.mode
    if mode not in actions:
        raise ValueError(
            "Unsupported mode {0!r}; expected one of: {1}".format(
                mode, ', '.join(sorted(actions))))

    # Make only the requested GPU visible to this process.
    # see issue #152 on stackoverflow
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(FLAGS.gpu)

    model = Model()
    tr_ops = TrainOps(model, FLAGS.exp_dir)

    # Seeding happens after the model is constructed and only for training
    # modes, matching the original order of operations.
    if 'train' in mode:
        npr.seed(int(FLAGS.seed))

    message, action = actions[mode]
    # print(...) with a single string argument behaves the same under
    # Python 2 and Python 3.
    print(message)
    tr_ops.load_exp_config()
    action(tr_ops)