Ejemplo n.º 1
0
    acc /= times
    return acc, loss


def inference(model, X, device=None):  # Test Process
    """Run ``model`` in evaluation mode on a NumPy batch and return NumPy output.

    Args:
        model: a ``torch.nn.Module``. It is switched to eval mode as a side
            effect (``model.eval()``), and stays in eval mode afterwards.
        X: NumPy array converted with ``torch.from_numpy``; its dtype must
            match what the model expects (presumably float32 -- TODO confirm
            against the data loader).
        device: device to move the input tensor to. When ``None`` (the
            default), the device holding the model's first parameter is used,
            or the CPU if the model has no parameters.

    Returns:
        The model's output copied to host memory as a NumPy array.
    """
    if device is None:
        # Do not depend on a module-level ``device``: that name is only bound
        # when this file runs as a script, so importing the module and calling
        # inference() would otherwise raise NameError. The input has to live
        # on the same device as the weights anyway.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    # No gradients are needed at test time; skipping graph construction
    # saves memory and lets us call .numpy() directly on the result.
    with torch.no_grad():
        pred_ = model(torch.from_numpy(X).to(device))
    return pred_.cpu().numpy()


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not os.path.exists(args.train_dir):
        os.mkdir(args.train_dir)
    if args.is_train:
        X_train, X_test, y_train, y_test = load_cifar_4d(args.data_dir)
        X_val, y_val = X_train[40000:], y_train[40000:]
        X_train, y_train = X_train[:40000], y_train[:40000]
        cnn_model = Model(drop_rate=args.drop_rate)
        cnn_model.to(device)
        print(cnn_model)
        optimizer = optim.Adam(cnn_model.parameters(), lr=args.learning_rate)

        # model_path = os.path.join(args.train_dir, 'checkpoint_%d.pth.tar' % args.inference_version)
        # if os.path.exists(model_path):
        # 	cnn_model = torch.load(model_path)

        pre_losses = [1e18] * 3
        best_val_acc = 0.0
        for epoch in range(1, args.num_epochs + 1):
            start_time = time.time()
Ejemplo n.º 2
0
    return acc, loss, acc_list, loss_list


def inference(model, sess, X):  # Test Process
    """Evaluate the model's prediction tensor for a batch of inputs.

    Args:
        model: graph wrapper exposing an input placeholder ``x_`` and a
            prediction tensor ``pred_val``.
        sess: session used to execute the graph.
        X: batch fed into ``model.x_``.

    Returns:
        The evaluated value of ``model.pred_val`` (the single fetched item).
    """
    feed = {model.x_: X}
    fetched = sess.run([model.pred_val], feed_dict=feed)
    # Exactly one fetch was requested, so unwrap it from the result list.
    return fetched[0]


# Restrict which GPU(s) TensorFlow can see to FLAGS.device; this is set
# before the tf.Session below is created.
os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.device
# Session configuration: log the device each op is placed on, and enable
# GPU-memory allow_growth (see TF GPUOptions docs).
config = tf.ConfigProto(log_device_placement=True)
config.gpu_options.allow_growth = True

with tf.Session(config=config) as sess:
    if not os.path.exists(FLAGS.train_dir):
        os.mkdir(FLAGS.train_dir)
    if FLAGS.is_train:
        X_train, X_test, y_train, y_test = load_cifar_4d(FLAGS.data_dir)
        X_val, y_val = X_train[40000:], y_train[40000:]
        X_train, y_train = X_train[:40000], y_train[:40000]
        cnn_model = Model(dropout=FLAGS.drop_rate, batch_norm=FLAGS.is_BN)
        '''
        if tf.train.get_checkpoint_state(FLAGS.train_dir):
            cnn_model.saver.restore(sess, tf.train.latest_checkpoint(FLAGS.train_dir))
        else:
        '''
        tf.global_variables_initializer().run()

        pre_losses = [1e18] * 3
        best_val_acc = 0.0
        print("training....")
        train_loss_list = []
        train_acc_list = []