예제 #1
0
파일: train.py 프로젝트: aronsar/hoad
def _agent_dir_name(data_path):
    """Return the '<name>.save' directory name derived from *data_path*.

    The last path component (trailing slashes ignored) is split on '_', its
    final three fields are dropped, and the remaining fields are joined with
    '-' before '.save' is appended.
    """
    # rstrip rather than an ``assert`` with a side effect: asserts are
    # stripped under ``python -O``, which would keep an empty last token.
    last_component = data_path.rstrip('/').split('/')[-1]
    return '-'.join(last_component.split('_')[:-3]) + '.save'


def _latest_checkpoint(dir_ckpt):
    """Return the path of the highest-epoch '.h5' checkpoint in *dir_ckpt*.

    Raises:
        FileNotFoundError: if *dir_ckpt* contains no '.h5' file.
    """
    def sort_key(name):
        # Checkpoint names start with the epoch ('{epoch:02d}-...'). Compare
        # it numerically: a plain string sort ranks '99-...' above '100-...'.
        # Names without a numeric prefix rank lowest; ties fall back to name.
        prefix = name.split('-', 1)[0]
        return (int(prefix) if prefix.isdigit() else -1, name)

    h5s = [name for name in os.listdir(dir_ckpt) if name.endswith('.h5')]
    if not h5s:
        raise FileNotFoundError(
            "No .h5 checkpoint found in {}".format(dir_ckpt))
    return os.path.join(dir_ckpt, max(h5s, key=sort_key))


def main(args):
    """Train, or resume training of, an MLP on the data at ``args.p``.

    Attributes read from *args*:
        p:      path to the training data. A path ending in '.hdf5'
                (case-insensitive) is read through ``h5py_cache``/``Gen4h5``;
                anything else is loaded with ``CV``.
        m:      root directory under which the model directory is created.
        epochs: number of epochs passed to ``train_model``.
        w:      ``workers`` passed to ``train_model``.
        q:      ``max_q_size`` passed to ``train_model``.

    Outputs under ``<m>/<agent>.save/``:
        best.h5       weights of the epoch with the lowest ``val_loss``
        ckpts/        weights saved after every epoch
        training.log  CSV training log (appended to on resume)

    If ``model_exists`` reports a saved model, the weights of the
    highest-epoch checkpoint in ``ckpts/`` are loaded. Otherwise a fresh
    model is built.
    """
    if tf.test.is_gpu_available():
        print(bc.OKGREEN + bc.BOLD + '#' * 9 + ' USING GPU ' + '#' * 9 +
              bc.ENDC)
    else:
        print(bc.FAIL + bc.BOLD + '#' * 9 + ' NOT USING GPU ' + '#' * 9 +
              bc.ENDC)

    # Get agent name
    dir_agent = _agent_dir_name(args.p)
    print(dir_agent)

    # run this first to avoid failing after huge overhead
    model_ok, initial_epoch = model_exists(args.m, dir_agent)

    PATH_DIR_SAVE = os.path.join(args.m, dir_agent)
    PATH_DIR_CKPT = os.path.join(PATH_DIR_SAVE, 'ckpts')

    n_epoch = args.epochs
    hypers = {
        'lr': 0.00015,
        'batch_size': 128,
        'hl_activations': [ReLU, ReLU, ReLU, ReLU, ReLU, ReLU],
        'hl_sizes': [1024, 1024, 512, 512, 512, 256],
        'decay': 0.,
        'bNorm': True,
        'dropout': True,
        'regularizer': None
    }

    # checking input data format.
    h5_file = None
    if args.p.lower().endswith('.hdf5'):
        h5_file = h5py_cache.File(args.p,
                                  'r',
                                  chunk_cache_mem_size=1 * 1024**3,
                                  swmr=True)
        gen_tr = Gen4h5(h5_file['X_tr'], h5_file['Y_tr'],
                        hypers['batch_size'], False)
        gen_va = Gen4h5(h5_file['X_va'], h5_file['Y_va'],
                        hypers['batch_size'], False)
    else:
        X, Y, mask = CV(args.p)
        gen_tr = DataGenerator(X[mask], Y[mask], hypers['batch_size'])
        gen_va = DataGenerator(X[~mask], Y[~mask], 1000)

    # The generators read from h5_file throughout training, so it can only
    # be closed once training has finished (or failed).
    try:
        os.makedirs(PATH_DIR_CKPT, exist_ok=True)

        # Callbacks: save best & latest models.
        callbacks = [
            ModelCheckpoint(os.path.join(PATH_DIR_SAVE, 'best.h5'),
                            monitor='val_loss',
                            verbose=1,
                            save_best_only=True,
                            save_weights_only=True,
                            mode='auto',
                            period=1),
            ModelCheckpoint(os.path.join(PATH_DIR_CKPT,
                                         '{epoch:02d}-{val_accuracy:.2f}.h5'),
                            monitor='val_loss',
                            verbose=1,
                            save_best_only=False,
                            save_weights_only=True,
                            mode='auto',
                            period=1),
            CSVLogger(os.path.join(PATH_DIR_SAVE, 'training.log'),
                      append=True)
        ]

        m = Mlp(io_sizes=(glb.SIZE_OBS_VEC, glb.SIZE_ACT_VEC),
                out_activation=Softmax,
                loss='categorical_crossentropy',
                metrics=['accuracy'],
                **hypers,
                verbose=1)

        if model_ok:
            # continue from previously saved
            msg = "Saved model found. Resuming training."
            print(bc.OKGREEN + bc.BOLD + msg + bc.ENDC)
            m.construct_model(_latest_checkpoint(PATH_DIR_CKPT),
                              weights_only=True)
        else:
            # create new model (PATH_DIR_CKPT was already created above)
            msg = "{} doesn't exist or is empty. Creating new model."
            print(bc.WARNING + bc.BOLD + msg.format(PATH_DIR_SAVE) + bc.ENDC)
            m.construct_model()

        m.train_model(gen_tr,
                      gen_va,
                      n_epoch=n_epoch,
                      callbacks=callbacks,
                      verbose=False,
                      workers=args.w,
                      use_mp=True,
                      max_q_size=args.q,
                      initial_epoch=initial_epoch)
    finally:
        if h5_file is not None:
            h5_file.close()
예제 #2
0
        'hl_activations': [ReLU, ReLU, ReLU, ReLU],
        'hl_sizes': [512, 512, 512, 256],
        'decay': 0.,
        'bNorm': True,
        'dropout': True,
        'regularizer': None
    }

    X_tr, Y_tr = load_data(args.datapath, args.agent, num_games=10)
    X_val, Y_val = load_data(args.datapath, args.agent, num_games=30)

    m = Mlp(io_sizes=(658, 20),
            out_activation=Softmax,
            loss='categorical_crossentropy',
            metrics=['accuracy'],
            **hypers,
            verbose=0)
    m.construct_model()

    m.hist = m.model.fit(X_tr,
                         Y_tr,
                         hypers['batch_size'],
                         args.epochs,
                         validation_data=(X_val, Y_val),
                         verbose=0)

    val_acc_str = str(m.hist.history['val_accuracy'][-1])[:6]
    print("Agent %s got val acc: %s" % (args.agent, val_acc_str))
    savepath = os.path.join(args.savedir, args.agent, val_acc_str + '.h5')
    m.model.save_weights(savepath)